diff --git a/cC.ml b/cC.ml new file mode 100644 index 00000000..f7b6e005 --- /dev/null +++ b/cC.ml @@ -0,0 +1,494 @@ +(* +Copyright (c) 2013, Simon Cruanes +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. Redistributions in binary +form must reproduce the above copyright notice, this list of conditions and the +following disclaimer in the documentation and/or other materials provided with +the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*) + +(** {1 Functional Congruence Closure} *) + +(** This implementation follows more or less the paper + "fast congruence closure and extensions" by Nieuwenhuis & Oliveras. + It uses semi-persistent data structures but still thrives for efficiency. *) + +(** {2 Curryfied terms} *) + +module type CurryfiedTerm = sig + type symbol + type t = private { + shape : shape; (** Which kind of term is it? *) + tag : int; (** Unique ID *) + } (** A curryfied term *) + and shape = private + | Const of symbol (** Constant *) + | Apply of t * t (** Curryfied application *) + + val mk_const : symbol -> t + val mk_app : t -> t -> t + val get_id : t -> int + val eq : t -> t -> bool + val pp_skel : out_channel -> t -> unit (* print tags recursively *) +end + +module Curryfy(X : Hashtbl.HashedType) = struct + type symbol = X.t + type t = { + shape : shape; (** Which kind of term is it? *) + tag : int; (** Unique ID *) + } + and shape = + | Const of symbol (** Constant *) + | Apply of t * t (** Curryfied application *) + + type term = t + + module WE = Weak.Make(struct + type t = term + let equal a b = match a.shape, b.shape with + | Const ia, Const ib -> X.equal ia ib + | Apply (a1,a2), Apply (b1,b2) -> a1 == b1 && a2 == b2 + | _ -> false + let hash a = match a.shape with + | Const i -> X.hash i + | Apply (a, b) -> a.tag * 65599 + b.tag + end) + + let __table = WE.create 10001 + let count = ref 0 + + let hashcons t = + let t' = WE.merge __table t in + (if t == t' then incr count); + t' + + let mk_const i = + let t = {shape=Const i; tag= !count; } in + hashcons t + + let mk_app a b = + let t = {shape=Apply (a, b); tag= !count; } in + hashcons t + + let get_id t = t.tag + + let eq t1 t2 = t1 == t2 + + let rec pp_skel oc t = match t.shape with + | Const _ -> Printf.fprintf oc "%d" t.tag + | Apply (t1, t2) -> + Printf.fprintf oc "(%a %a):%d" pp_skel t1 pp_skel t2 t.tag +end + +(** {2 Congruence Closure} *) + +module type S = sig + module CT : CurryfiedTerm + + type t + (** Congruence Closure instance *) + + exception Inconsistent of t * CT.t * CT.t * CT.t * CT.t + (** Exception raised when equality and inequality constraints are + inconsistent. [Inconsistent (a, b, a', b')] means that [a=b, a=a', b=b'] in + the congruence closure, but [a' != b'] was asserted before. *) + + val create : int -> t + (** Create an empty CC of given size *) + + val eq : t -> CT.t -> CT.t -> bool + (** Check whether the two terms are equal *) + + val merge : t -> CT.t -> CT.t -> t + (** Assert that the two terms are equal (may raise Inconsistent) *) + + val distinct : t -> CT.t -> CT.t -> t + (** Assert that the two given terms are distinct (may raise Inconsistent) *) + + type action = + | Merge of CT.t * CT.t + | Distinct of CT.t * CT.t + (** Action that can be performed on the CC *) + + val do_action : t -> action -> t + (** Perform the given action (may raise Inconsistent) *) + + val can_eq : t -> CT.t -> CT.t -> bool + (** Check whether the two terms can be equal *) + + val iter_equiv_class : t -> CT.t -> (CT.t -> unit) -> unit + (** Iterate on terms that are congruent to the given term *) + + type explanation = + | ByCongruence of CT.t * CT.t (* direct congruence of terms *) + | ByMerge of CT.t * CT.t (* user merge of terms *) + + val explain : t -> CT.t -> CT.t -> explanation list + (** Explain why those two terms are equal (assuming they are, + otherwise raises Invalid_argument) by returning a list + of merges. *) +end + +module Make(T : CurryfiedTerm) = struct + module CT = T + module BV = Puf.PBitVector + module Puf = Puf.Make(CT) + + module HashedCT = struct + type t = CT.t + let equal t1 t2 = t1.CT.tag = t2.CT.tag + let hash t = t.CT.tag + end + + (* Persistent Hashtable on curryfied terms *) + module THashtbl = PersistentHashtbl.Make(HashedCT) + + (* Persistent Hashtable on pairs of curryfied terms *) + module T2Hashtbl = PersistentHashtbl.Make(struct + type t = CT.t * CT.t + let equal (t1,t1') (t2,t2') = t1.CT.tag = t2.CT.tag && t1'.CT.tag = t2'.CT.tag + let hash (t,t') = t.CT.tag * 65599 + t'.CT.tag + end) + + type t = { + uf : pending_eqn Puf.t; (* representatives for terms *) + defined : BV.t; (* is the term defined? *) + use : eqn list THashtbl.t; (* for all repr a, a -> all a@b=c and b@a=c *) + lookup : eqn T2Hashtbl.t; (* for all reprs a,b, some a@b=c (if any) *) + inconsistent : (CT.t * CT.t) option; + } (** Congruence Closure data structure *) + and eqn = + | EqnSimple of CT.t * CT.t (* t1 = t2 *) + | EqnApply of CT.t * CT.t * CT.t (* (t1 @ t2) = t3 *) + (** Equation between two terms *) + and pending_eqn = + | PendingSimple of eqn + | PendingDouble of eqn * eqn + + exception Inconsistent of t * CT.t * CT.t * CT.t * CT.t + (** Exception raised when equality and inequality constraints are + inconsistent. [Inconsistent (a, b, a', b')] means that [a=b, a=a', b=b'] in + the congruence closure, but [a' != b'] was asserted before. *) + + (** Create an empty CC of given size *) + let create size = + { uf = Puf.create size; + defined = BV.make 3; + use = THashtbl.create size; + lookup = T2Hashtbl.create size; + inconsistent = None; + } + + let mem cc t = + BV.get cc.defined t.CT.tag + + let is_const t = match t.CT.shape with + | CT.Const _ -> true + | CT.Apply _ -> false + + (** Merge equations in the congruence closure structure. [q] is a list + of [eqn], processed in FIFO order. May raise Inconsistent. *) + let rec merge cc eqn = match eqn with + | EqnSimple (a, b) -> + (* a=b, just propagate *) + propagate cc [PendingSimple eqn] + | EqnApply (a1, a2, a) -> + (* (a1 @ a2) = a *) + let a1' = Puf.find cc.uf a1 in + let a2' = Puf.find cc.uf a2 in + begin try + (* eqn' is (b1 @ b2) = b for some b1=a1', b2=a2' *) + let eqn' = T2Hashtbl.find cc.lookup (a1', a2') in + (* merge a and b because of eqn and eqn' *) + propagate cc [PendingDouble (eqn, eqn')] + with Not_found -> + (* remember that a1' @ a2' = a *) + let lookup = T2Hashtbl.replace cc.lookup (a1', a2') eqn in + let use_a1' = try THashtbl.find cc.use a1' with Not_found -> [] in + let use_a2' = try THashtbl.find cc.use a2' with Not_found -> [] in + let use = THashtbl.replace cc.use a1' (eqn::use_a1') in + let use = THashtbl.replace use a2' (eqn::use_a2') in + { cc with use; lookup; } + end + (* propagate: merge pending equations *) + and propagate cc eqns = + let pending = ref eqns in + let uf = ref cc.uf in + let use = ref cc.use in + let lookup = ref cc.lookup in + (* process each pending equation *) + while !pending <> [] do + let eqn = List.hd !pending in + pending := List.tl !pending; + (* extract the two merged terms *) + let a, b = match eqn with + | PendingSimple (EqnSimple (a, b)) -> a, b + | PendingDouble (EqnApply (a1,a2,a), EqnApply (b1,b2,b)) -> a, b + | _ -> assert false + in + let a' = Puf.find !uf a in + let b' = Puf.find !uf b in + if not (CT.eq a' b') then begin + let use_a' = try THashtbl.find !use a' with Not_found -> [] in + let use_b' = try THashtbl.find !use b' with Not_found -> [] in + (* merge a and b's equivalence classes *) + (* Format.printf "merge %d %d@." a.CT.tag b.CT.tag; *) + uf := Puf.union !uf a b eqn; + (* check which of [a'] and [b'] is the new representative. [repr] is + the new representative, and [other] is the former representative *) + let repr = Puf.find !uf a' in + let use_repr = ref (if CT.eq repr a' then use_a' else use_b') in + let use_other = if CT.eq repr a' then use_b' else use_a' in + (* consider all c1@c2=c in use(a') *) + List.iter + (fun eqn -> match eqn with + | EqnSimple _ -> () + | EqnApply (c1, c2, c) -> + let c1' = Puf.find !uf c1 in + let c2' = Puf.find !uf c2 in + begin try + let eqn' = T2Hashtbl.find !lookup (c1', c2') in + (* merge eqn with eqn', by congruence *) + pending := (PendingDouble (eqn,eqn')) :: !pending + with Not_found -> + lookup := T2Hashtbl.replace !lookup (c1', c2') eqn; + use_repr := eqn :: !use_repr; + end) + use_other; + (* update use list of [repr] *) + use := THashtbl.replace !use repr !use_repr; + (* check for inconsistencies *) + match Puf.inconsistent !uf with + | None -> () (* consistent *) + | Some (t1, t2, t1', t2') -> + (* inconsistent *) + let cc = { cc with use= !use; lookup= !lookup; uf= !uf; } in + raise (Inconsistent (cc, t1, t2, t1', t2')) + end + done; + let cc = { cc with use= !use; lookup= !lookup; uf= !uf; } in + cc + + (** Add the given term to the CC *) + let rec add cc t = + match t.CT.shape with + | CT.Const _ -> + cc (* always trivially defined *) + | CT.Apply (t1, t2) -> + if BV.get cc.defined t.CT.tag + then cc (* already defined *) + else begin + (* note that [t] is defined, add it to the UF to avoid GC *) + let defined = BV.set_true cc.defined t.CT.tag in + let cc = {cc with defined; } in + (* recursive add. invariant: if a term is added, then its subterms + also are (hence the base case of constants or already added terms). *) + let cc = add cc t1 in + let cc = add cc t2 in + let cc = merge cc (EqnApply (t1, t2, t)) in + cc + end + + (** Check whether the two terms are equal *) + let eq cc t1 t2 = + let cc = add (add cc t1) t2 in + let t1' = Puf.find cc.uf t1 in + let t2' = Puf.find cc.uf t2 in + CT.eq t1' t2' + + (** Assert that the two terms are equal (may raise Inconsistent) *) + let merge cc t1 t2 = + let cc = add (add cc t1) t2 in + merge cc (EqnSimple (t1, t2)) + + (** Assert that the two given terms are distinct (may raise Inconsistent) *) + let distinct cc t1 t2 = + let cc = add (add cc t1) t2 in + let t1' = Puf.find cc.uf t1 in + let t2' = Puf.find cc.uf t2 in + if CT.eq t1' t2' + then raise (Inconsistent (cc, t1', t2', t1, t2)) (* they are equal, fail *) + else + (* remember that they should not become equal *) + let uf = Puf.distinct cc.uf t1 t2 in + { cc with uf; } + + type action = + | Merge of CT.t * CT.t + | Distinct of CT.t * CT.t + (** Action that can be performed on the CC *) + + let do_action cc action = match action with + | Merge (t1, t2) -> merge cc t1 t2 + | Distinct (t1, t2) -> distinct cc t1 t2 + + (** Check whether the two terms can be equal *) + let can_eq cc t1 t2 = + let cc = add (add cc t1) t2 in + not (Puf.must_be_distinct cc.uf t1 t2) + + (** Iterate on terms that are congruent to the given term *) + let iter_equiv_class cc t f = + Puf.iter_equiv_class cc.uf t f + + (** {3 Auxilliary Union-find for explanations} *) + + module SparseUF = struct + module H = Hashtbl.Make(HashedCT) + + type t = uf_ref H.t + and uf_ref = { + term : CT.t; + mutable parent : CT.t; + mutable highest_node : CT.t; + } (** Union-find reference *) + + let create size = H.create size + + let get_ref uf t = + try H.find uf t + with Not_found -> + let r_t = { term=t; parent=t; highest_node=t; } in + H.add uf t r_t; + r_t + + let rec find_ref uf r_t = + if CT.eq r_t.parent r_t.term + then r_t (* fixpoint *) + else + let r_t' = get_ref uf r_t.parent in + find_ref uf r_t' (* recurse (no path compression) *) + + let find uf t = + try + let r_t = H.find uf t in + (find_ref uf r_t).term + with Not_found -> + t + + let eq uf t1 t2 = + CT.eq (find uf t1) (find uf t2) + + let highest_node uf t = + try + let r_t = H.find uf t in + (find_ref uf r_t).highest_node + with Not_found -> + t + + (* oriented union (t1 -> t2), assuming t2 is "higher" than t1 *) + let union uf t1 t2 = + let r_t1' = find_ref uf (get_ref uf t1) in + let r_t2' = find_ref uf (get_ref uf t2) in + r_t1'.parent <- r_t2'.term + end + + (** {3 Producing explanations} *) + + type explanation = + | ByCongruence of CT.t * CT.t (* direct congruence of terms *) + | ByMerge of CT.t * CT.t (* user merge of terms *) + + (** Explain why those two terms are equal (they must be) *) + let explain cc t1 t2 = + assert (eq cc t1 t2); + (* keeps track of which equalities are already explained *) + let explained = SparseUF.create 5 in + let explanations = ref [] in + (* equations waiting to be explained *) + let pending = Queue.create () in + Queue.push (t1,t2) pending; + (* explain why a=c, where c is the root of the proof forest a belongs to *) + let rec explain_along a c = + let a' = SparseUF.highest_node explained a in + if CT.eq a' c then () + else match Puf.explain_step cc.uf a' with + | None -> assert (CT.eq a' c) + | Some (b, e) -> + (* a->b on the path from a to c *) + begin match e with + | PendingSimple (EqnSimple (a',b')) -> + explanations := ByMerge (a', b') :: !explanations + | PendingDouble (EqnApply (a1, a2, a'), EqnApply (b1, b2, b')) -> + explanations := ByCongruence (a', b') :: !explanations; + Queue.push (a1, b1) pending; + Queue.push (a2, b2) pending; + | _ -> assert false + end; + (* now a' = b is justified *) + SparseUF.union explained a' b; + (* recurse *) + let new_a = SparseUF.highest_node explained b in + explain_along new_a c + in + (* process pending equations *) + while not (Queue.is_empty pending) do + let a, b = Queue.pop pending in + if SparseUF.eq explained a b + then () + else begin + let c = Puf.common_ancestor cc.uf a b in + explain_along a c; + explain_along b c; + end + done; + !explanations +end + +module StrTerm = Curryfy(struct + type t = string + let equal s1 s2 = s1 = s2 + let hash s = Hashtbl.hash s +end) + +module StrCC = Make(StrTerm) + +let lex str = + let lexer = Genlex.make_lexer ["("; ")"] in + lexer (Stream.of_string str) + +let parse str = + let stream = lex str in + let rec parse_term () = + match Stream.peek stream with + | Some (Genlex.Kwd "(") -> + Stream.junk stream; + let t1 = parse_term () in + let t2 = parse_term () in + begin match Stream.peek stream with + | Some (Genlex.Kwd ")") -> + Stream.junk stream; + StrTerm.mk_app t1 t2 (* end apply *) + | _ -> raise (Stream.Error "expected )") + end + | Some (Genlex.Ident s) -> + Stream.junk stream; + StrTerm.mk_const s + | _ -> raise (Stream.Error "expected term") + in + parse_term () + +let rec pp fmt t = + match t.StrTerm.shape with + | StrTerm.Const s -> + Format.fprintf fmt "%s:%d" s t.StrTerm.tag + | StrTerm.Apply (t1, t2) -> + Format.fprintf fmt "(%a %a):%d" pp t1 pp t2 t.StrTerm.tag + diff --git a/cC.mli b/cC.mli new file mode 100644 index 00000000..89a1b031 --- /dev/null +++ b/cC.mli @@ -0,0 +1,105 @@ +(* +Copyright (c) 2013, Simon Cruanes +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. Redistributions in binary +form must reproduce the above copyright notice, this list of conditions and the +following disclaimer in the documentation and/or other materials provided with +the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*) + +(** {1 Functional Congruence Closure} *) + +(** {2 Curryfied terms} *) + +module type CurryfiedTerm = sig + type symbol + type t = private { + shape : shape; (** Which kind of term is it? *) + tag : int; (** Unique ID *) + } (** A curryfied term *) + and shape = private + | Const of symbol (** Constant *) + | Apply of t * t (** Curryfied application *) + + val mk_const : symbol -> t + val mk_app : t -> t -> t + val get_id : t -> int + val eq : t -> t -> bool + val pp_skel : out_channel -> t -> unit (* print tags recursively *) +end + +module Curryfy(X : Hashtbl.HashedType) : CurryfiedTerm with type symbol = X.t + +(** {2 Congruence Closure} *) + +module type S = sig + module CT : CurryfiedTerm + + type t + (** Congruence Closure instance *) + + exception Inconsistent of t * CT.t * CT.t * CT.t * CT.t + (** Exception raised when equality and inequality constraints are + inconsistent. [Inconsistent (a, b, a', b')] means that [a=b, a=a', b=b'] in + the congruence closure, but [a' != b'] was asserted before. *) + + val create : int -> t + (** Create an empty CC of given size *) + + val eq : t -> CT.t -> CT.t -> bool + (** Check whether the two terms are equal *) + + val merge : t -> CT.t -> CT.t -> t + (** Assert that the two terms are equal (may raise Inconsistent) *) + + val distinct : t -> CT.t -> CT.t -> t + (** Assert that the two given terms are distinct (may raise Inconsistent) *) + + type action = + | Merge of CT.t * CT.t + | Distinct of CT.t * CT.t + (** Action that can be performed on the CC *) + + val do_action : t -> action -> t + (** Perform the given action (may raise Inconsistent) *) + + val can_eq : t -> CT.t -> CT.t -> bool + (** Check whether the two terms can be equal *) + + val iter_equiv_class : t -> CT.t -> (CT.t -> unit) -> unit + (** Iterate on terms that are congruent to the given term *) + + type explanation = + | ByCongruence of CT.t * CT.t (* direct congruence of terms *) + | ByMerge of CT.t * CT.t (* user merge of terms *) + + val explain : t -> CT.t -> CT.t -> explanation list + (** Explain why those two terms are equal (assuming they are, + otherwise raises Invalid_argument) by returning a list + of merges. *) +end + +module Make(T : CurryfiedTerm) : S with module CT = T + +module StrTerm : CurryfiedTerm with type symbol = string + +module StrCC : S with module CT = StrTerm + +val parse : string -> StrTerm.t +val pp : Format.formatter -> StrTerm.t -> unit diff --git a/puf.ml b/puf.ml new file mode 100644 index 00000000..7a00564a --- /dev/null +++ b/puf.ml @@ -0,0 +1,519 @@ +(* +Copyright (c) 2013, Simon Cruanes +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. Redistributions in binary +form must reproduce the above copyright notice, this list of conditions and the +following disclaimer in the documentation and/or other materials provided with +the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*) + +(** {1 Functional (persistent) extensible union-find} *) + +(** {2 Persistent array} *) + +module PArray = struct + type 'a t = 'a zipper ref + and 'a zipper = + | Array of 'a array + | Diff of int * 'a * 'a t + + (* XXX maybe having a snapshot of the array from point to point may help? *) + + let make size elt = + let a = Array.create size elt in + ref (Array a) + + let init size f = + let a = Array.init size f in + ref (Array a) + + (** Recover the given version of the shared array. Returns the array + itself. *) + let rec reroot t = + match !t with + | Array a -> a + | Diff (i, v, t') -> + begin + let a = reroot t' in + let v' = a.(i) in + t' := Diff (i, v', t); + a.(i) <- v; + t := Array a; + a + end + + let get t i = + match !t with + | Array a -> a.(i) + | Diff _ -> + let a = reroot t in + a.(i) + + let set t i v = + let a = + match !t with + | Array a -> a + | Diff _ -> reroot t in + let v' = a.(i) in + if v == v' + then t (* no change *) + else begin + let t' = ref (Array a) in + a.(i) <- v; + t := Diff (i, v', t'); + t' (* create new array *) + end + + let rec length t = + match !t with + | Array a -> Array.length a + | Diff (_, _, t') -> length t' + + (** Extend [t] to the given [size], initializing new elements with [elt] *) + let extend t size elt = + let a = match !t with + | Array a -> a + | _ -> reroot t in + if size > Array.length a + then begin (* resize: create bigger array *) + let size = min Sys.max_array_length size in + let a' = Array.make size elt in + (* copy old part *) + Array.blit a 0 a' 0 (Array.length a); + t := Array a' + end + + (** Extend [t] to the given [size], initializing elements with [f] *) + let extend_init t size f = + let a = match !t with + | Array a -> a + | _ -> reroot t in + if size > Array.length a + then begin (* resize: create bigger array *) + let size = min Sys.max_array_length size in + let a' = Array.init size f in + (* copy old part *) + Array.blit a 0 a' 0 (Array.length a); + t := Array a' + end + + let fold_left f acc t = + let a = reroot t in + Array.fold_left f acc a +end + +(** {2 Persistent Bitvector} *) + +module PBitVector = struct + type t = int PArray.t + + let width = Sys.word_size - 1 (* number of usable bits in an integer *) + + let make size = PArray.make size 0 + + let ensure bv offset = + if offset >= PArray.length bv + then + let len = offset + offset/2 + 1 in + PArray.extend bv len 0 + else () + + (** [get bv i] gets the value of the [i]-th element of [bv] *) + let get bv i = + let offset = i / width in + let bit = i mod width in + ensure bv offset; + let bits = PArray.get bv offset in + (bits land (1 lsl bit)) <> 0 + + (** [set bv i v] sets the value of the [i]-th element of [bv] to [v] *) + let set bv i v = + let offset = i / width in + let bit = i mod width in + ensure bv offset; + let bits = PArray.get bv offset in + let bits' = + if v + then bits lor (1 lsl bit) + else bits land (lnot (1 lsl bit)) + in + PArray.set bv offset bits' + + (** Bitvector with all bits set to 0 *) + let clear bv = make 5 + + let set_true bv i = set bv i true + let set_false bv i = set bv i false +end + +(** {2 Type with unique identifier} *) + +module type ID = sig + type t + val get_id : t -> int +end + +(** {2 Persistent Union-Find with explanations} *) + +module type S = sig + type elt + (** Elements of the Union-find *) + + type 'e t + (** An instance of the union-find, ie a set of equivalence classes; It + is parametrized by the type of explanations. *) + + val create : int -> 'e t + (** Create a union-find of the given size. *) + + val find : 'e t -> elt -> elt + (** [find uf a] returns the current representative of [a] in the given + union-find structure [uf]. By default, [find uf a = a]. *) + + val union : 'e t -> elt -> elt -> 'e -> 'e t + (** [union uf a b why] returns an update of [uf] where [find a = find b], + the merge being justified by [why]. *) + + val distinct : 'e t -> elt -> elt -> 'e t + (** Ensure that the two elements are distinct. *) + + val must_be_distinct : _ t -> elt -> elt -> bool + (** Should the two elements be distinct? *) + + val fold_equiv_class : _ t -> elt -> ('a -> elt -> 'a) -> 'a -> 'a + (** [fold_equiv_class uf a f acc] folds on [acc] and every element + that is congruent to [a] with [f]. *) + + val iter_equiv_class : _ t -> elt -> (elt -> unit) -> unit + (** [iter_equiv_class uf a f] calls [f] on every element of [uf] that + is congruent to [a], including [a] itself. *) + + val inconsistent : _ t -> (elt * elt * elt * elt) option + (** Check whether the UF is inconsistent. It returns [Some (a, b, a', b')] + in case of inconsistency, where a = b, a = a' and b = b' by congruence, + and a' != b' was a call to [distinct]. *) + + val common_ancestor : 'e t -> elt -> elt -> elt + (** Closest common ancestor of the two elements in the proof forest *) + + val explain_step : 'e t -> elt -> (elt * 'e) option + (** Edge from the element to its parent in the proof forest; Returns + None if the element is a root of the forest. *) + + val explain : 'e t -> elt -> elt -> 'e list + (** [explain uf a b] returns a list of labels that justify why + [find uf a = find uf b]. Such labels were provided by [union]. *) + + val explain_distinct : 'e t -> elt -> elt -> elt * elt + (** [explain_distinct uf a b] gives the original pair [a', b'] that + made [a] and [b] distinct by calling [distinct a' b'] *) +end + +module IH = Hashtbl.Make(struct type t = int let equal i j = i = j let hash i = i end) + +module Make(X : ID) : S with type elt = X.t = struct + type elt = X.t + + type 'e t = { + mutable parent : int PArray.t; (* idx of the parent, with path compression *) + mutable data : elt_data option PArray.t; (* ID -> data for an element *) + inconsistent : (elt * elt * elt * elt) option; (* is the UF inconsistent? *) + forest : 'e edge PArray.t; (* explanation forest *) + } (** An instance of the union-find, ie a set of equivalence classes *) + and elt_data = { + elt : elt; + size : int; (* number of elements in the class *) + next : int; (* next element in equiv class *) + distinct : (int * elt * elt) list; (* classes distinct from this one, and why *) + } (** Data associated to the element. Most of it is only meaningful for + a representative (ie when elt = parent(elt)). *) + and 'e edge = + | EdgeNone + | EdgeTo of int * 'e + (** Edge of the proof forest, annotated with 'e *) + + let get_data uf id = + match PArray.get uf.data id with + | Some data -> data + | None -> assert false + + (** Create a union-find of the given size. *) + let create size = + { parent = PArray.init size (fun i -> i); + data = PArray.make size None; + inconsistent = None; + forest = PArray.make size EdgeNone; + } + + (* ensure the arrays are big enough for [id], and set [elt.(id) <- elt] *) + let ensure uf id elt = + if id >= PArray.length uf.data then begin + (* resize *) + let len = id + (id / 2) in + PArray.extend_init uf.parent len (fun i -> i); + PArray.extend uf.data len None; + PArray.extend uf.forest len EdgeNone; + end; + match PArray.get uf.data id with + | None -> + let data = { elt; size = 1; next=id; distinct=[]; } in + uf.data <- PArray.set uf.data id (Some data) + | Some _ -> () + + (* Find the ID of the root of the given ID *) + let rec find_root uf id = + let parent_id = PArray.get uf.parent id in + if id = parent_id + then id + else begin (* recurse *) + let root = find_root uf parent_id in + (* path compression *) + (if root <> parent_id then uf.parent <- PArray.set uf.parent id root); + root + end + + (** [find uf a] returns the current representative of [a] in the given + union-find structure [uf]. By default, [find uf a = a]. *) + let find uf elt = + let id = X.get_id elt in + if id >= PArray.length uf.parent + then elt (* not present *) + else + let id' = find_root uf id in + match PArray.get uf.data id' with + | Some data -> data.elt + | None -> assert (id = id'); elt (* not present *) + + (* merge i and j in the forest, with explanation why *) + let rec merge_forest forest i j why = + assert (i <> j); + (* invert path from i to roo, reverting all edges *) + let rec invert_path forest i = + match PArray.get forest i with + | EdgeNone -> forest (* reached root *) + | EdgeTo (i', e) -> + let forest' = invert_path forest i' in + PArray.set forest' i' (EdgeTo (i, e)) + in + let forest = invert_path forest i in + (* root of [j] is the new root of [i] and [j] *) + let forest = PArray.set forest i (EdgeTo (j, why)) in + forest + + (** Merge the class of [a] (whose representative is [ia'] into the class + of [b], whose representative is [ib'] *) + let merge_into uf a ia' b ib' why = + let data_a = get_data uf ia' in + let data_b = get_data uf ib' in + (* merge roots (a -> b, arbitrarily) *) + let parent = PArray.set uf.parent ia' ib' in + (* merge 'distinct' lists: distinct(b) <- distinct(b)+distinct(a) *) + let distinct' = List.rev_append data_a.distinct data_b.distinct in + (* size of the new equivalence class *) + let size' = data_a.size + data_b.size in + (* concatenation of circular linked lists (equivalence classes), + concatenation of distinct lists *) + let data_a' = {data_a with next=data_b.next; } in + let data_b' = {data_b with next=data_a.next; distinct=distinct'; size=size'; } in + let data = PArray.set uf.data ia' (Some data_a') in + let data = PArray.set data ib' (Some data_b') in + (* inconsistency check *) + let inconsistent = + List.fold_left + (fun acc (id, a', b') -> match acc with + | Some _ -> acc + | None when find_root uf id = ib' -> Some (a, b, a', b') (* found! *) + | None -> None) + None data_a.distinct + in + (* update forest *) + let forest = merge_forest uf.forest (X.get_id a) (X.get_id b) why in + { parent; data; inconsistent; forest; } + + (** [union uf a b why] returns an update of [uf] where [find a = find b], + the merge being justified by [why]. *) + let union uf a b why = + (if uf.inconsistent <> None + then raise (Invalid_argument "inconsistent uf")); + let ia = X.get_id a in + let ib = X.get_id b in + (* get sure we can access [ia] and [ib] in [uf] *) + ensure uf ia a; + ensure uf ib b; + (* indexes of roots of [a] and [b] *) + let ia' = find_root uf ia + and ib' = find_root uf ib in + if ia' = ib' + then uf (* no change *) + else + (* data associated to both representatives *) + let data_a = get_data uf ia' in + let data_b = get_data uf ib' in + (* merge the smaller class into the bigger class *) + if data_a.size > data_b.size + then merge_into uf b ib' a ia' why + else merge_into uf a ia' b ib' why + + (** Ensure that the two elements are distinct. May raise Inconsistent *) + let distinct uf a b = + (if uf.inconsistent <> None + then raise (Invalid_argument "inconsistent uf")); + let ia = X.get_id a in + let ib = X.get_id b in + ensure uf ia a; + ensure uf ib b; + (* representatives of a and b *) + let ia' = find_root uf ia in + let ib' = find_root uf ib in + (* update 'distinct' lists *) + let data_a = get_data uf ia' in + let data_a' = {data_a with distinct= (ib',a,b) :: data_a.distinct; } in + let data_b = get_data uf ib' in + let data_b' = {data_b with distinct= (ia',a,b) :: data_b.distinct; } in + let data = PArray.set uf.data ia' (Some data_a') in + let data = PArray.set data ib' (Some data_b') in + (* check inconsistency *) + let inconsistent = if ia' = ib' then Some (data_a.elt, data_b.elt, a, b) else None in + { uf with inconsistent; data; } + + let must_be_distinct uf a b = + let ia = X.get_id a in + let ib = X.get_id b in + let len = PArray.length uf.parent in + if ia >= len || ib >= len + then false (* no chance *) + else + (* representatives *) + let ia' = find_root uf ia in + let ib' = find_root uf ib in + (* list of equiv classes that must be != a *) + match PArray.get uf.data ia' with + | None -> false (* ia' not present *) + | Some data_a -> + List.exists (fun (id,_,_) -> find_root uf id = ib') data_a.distinct + + (** [fold_equiv_class uf a f acc] folds on [acc] and every element + that is congruent to [a] with [f]. *) + let fold_equiv_class uf a f acc = + let ia = X.get_id a in + if ia >= PArray.length uf.parent + then f acc a (* alone. *) + else + let rec traverse acc id = + match PArray.get uf.data id with + | None -> f acc a (* alone. *) + | Some data -> + let acc' = f acc data.elt in + let id' = data.next in + if id' = ia + then acc' (* traversed the whole list *) + else traverse acc' id' + in + traverse acc ia + + (** [iter_equiv_class uf a f] calls [f] on every element of [uf] that + is congruent to [a], including [a] itself. *) + let iter_equiv_class uf a f = + let ia = X.get_id a in + if ia >= PArray.length uf.parent + then f a (* alone. *) + else + let rec traverse id = + match PArray.get uf.data id with + | None -> f a (* alone. *) + | Some data -> + f data.elt; (* yield element *) + let id' = data.next in + if id' = ia + then () (* traversed the whole list *) + else traverse id' + in + traverse ia + + let inconsistent uf = uf.inconsistent + + (** Closest common ancestor of the two elements in the proof forest *) + let common_ancestor uf a b = + let forest = uf.forest in + let explored = IH.create 3 in + let rec recurse i j = + if i = j + then return i (* found *) + else if IH.mem explored i + then return i + else if IH.mem explored j + then return j + else + let i' = match PArray.get forest i with + | EdgeNone -> i + | EdgeTo (i', e) -> + IH.add explored i (); + i' + and j' = match PArray.get forest j with + | EdgeNone -> j + | EdgeTo (j', e) -> + IH.add explored j (); + j' + in + recurse i' j' + and return i = + (get_data uf i).elt (* return the element *) + in + recurse (X.get_id a) (X.get_id b) + + (** Edge from the element to its parent in the proof forest; Returns + None if the element is a root of the forest. *) + let explain_step uf a = + match PArray.get uf.forest (X.get_id a) with + | EdgeNone -> None + | EdgeTo (i, e) -> + let b = (get_data uf i).elt in + Some (b, e) + + (** [explain uf a b] returns a list of labels that justify why + [find uf a = find uf b]. Such labels were provided by [union]. *) + let explain uf a b = + (if find_root uf (X.get_id a) <> find_root uf (X.get_id b) + then failwith "Puf.explain: can only explain equal terms"); + let c = common_ancestor uf a b in + (* path from [x] to [c] *) + let rec build_path path x = + if (X.get_id x) = (X.get_id c) + then path + else match explain_step uf x with + | None -> assert false + | Some (x', e) -> + build_path (e::path) x' + in + build_path (build_path [] a) b + + (** [explain_distinct uf a b] gives the original pair [a', b'] that + made [a] and [b] distinct by calling [distinct a' b']. The + terms must be distinct, otherwise Failure is raised. *) + let explain_distinct uf a b = + let ia' = find_root uf (X.get_id a) in + let ib' = find_root uf (X.get_id b) in + let node_a = get_data uf ia' in + let rec search l = match l with + | [] -> failwith "Puf.explain_distinct: classes are not distinct" + | (ib'', a', b')::_ when ib' = ib'' -> (a', b') (* explanation found *) + | _ :: l' -> search l' + in + search node_a.distinct +end diff --git a/puf.mli b/puf.mli new file mode 100644 index 00000000..c44f4e2b --- /dev/null +++ b/puf.mli @@ -0,0 +1,138 @@ +(* +Copyright (c) 2013, Simon Cruanes +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. Redistributions in binary +form must reproduce the above copyright notice, this list of conditions and the +following disclaimer in the documentation and/or other materials provided with +the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*) + +(** {1 Functional (persistent) extensible union-find} *) + +(** {2 Persistent array} *) + +module PArray : sig + type 'a t + + val make : int -> 'a -> 'a t + + val init : int -> (int -> 'a) -> 'a t + + val get : 'a t -> int -> 'a + + val set : 'a t -> int -> 'a -> 'a t + + val length : 'a t -> int + + val fold_left : ('b -> 'a -> 'b) -> 'b -> 'a t -> 'b + + val extend : 'a t -> int -> 'a -> unit + (** Extend [t] to the given [size], initializing new elements with [elt] *) + + val extend_init : 'a t -> int -> (int -> 'a) -> unit + (** Extend [t] to the given [size], initializing elements with [f] *) +end + +(** {2 Persistent Bitvector} *) + +module PBitVector : sig + type t + + val make : int -> t + (** Create a new bitvector of the given initial size (in words) *) + + val get : t -> int -> bool + (** [get bv i] gets the value of the [i]-th element of [bv] *) + + val set : t -> int -> bool -> t + (** [set bv i v] sets the value of the [i]-th element of [bv] to [v] *) + + val clear : t -> t + (** Bitvector with all bits set to 0 *) + + val set_true : t -> int -> t + val set_false : t -> int -> t +end + +(** {2 Type with unique identifier} *) + +module type ID = sig + type t + val get_id : t -> int + (** Unique integer ID for the element. Must be >= 0. *) +end + +(** {2 Persistent Union-Find with explanations} *) + +module type S = sig + type elt + (** Elements of the Union-find *) + + type 'e t + (** An instance of the union-find, ie a set of equivalence classes; It + is parametrized by the type of explanations. *) + + val create : int -> 'e t + (** Create a union-find of the given size. *) + + val find : 'e t -> elt -> elt + (** [find uf a] returns the current representative of [a] in the given + union-find structure [uf]. By default, [find uf a = a]. *) + + val union : 'e t -> elt -> elt -> 'e -> 'e t + (** [union uf a b why] returns an update of [uf] where [find a = find b], + the merge being justified by [why]. *) + + val distinct : 'e t -> elt -> elt -> 'e t + (** Ensure that the two elements are distinct. *) + + val must_be_distinct : _ t -> elt -> elt -> bool + (** Should the two elements be distinct? *) + + val fold_equiv_class : _ t -> elt -> ('a -> elt -> 'a) -> 'a -> 'a + (** [fold_equiv_class uf a f acc] folds on [acc] and every element + that is congruent to [a] with [f]. *) + + val iter_equiv_class : _ t -> elt -> (elt -> unit) -> unit + (** [iter_equiv_class uf a f] calls [f] on every element of [uf] that + is congruent to [a], including [a] itself. *) + + val inconsistent : _ t -> (elt * elt * elt * elt) option + (** Check whether the UF is inconsistent. It returns [Some (a, b, a', b')] + in case of inconsistency, where a = b, a = a' and b = b' by congruence, + and a' != b' was a call to [distinct]. *) + + val common_ancestor : 'e t -> elt -> elt -> elt + (** Closest common ancestor of the two elements in the proof forest *) + + val explain_step : 'e t -> elt -> (elt * 'e) option + (** Edge from the element to its parent in the proof forest; Returns + None if the element is a root of the forest. *) + + val explain : 'e t -> elt -> elt -> 'e list + (** [explain uf a b] returns a list of labels that justify why + [find uf a = find uf b]. Such labels were provided by [union]. *) + + val explain_distinct : 'e t -> elt -> elt -> elt * elt + (** [explain_distinct uf a b] gives the original pair [a', b'] that + made [a] and [b] distinct by calling [distinct a' b']. The + terms must be distinct, otherwise Failure is raised. *) +end + +module Make(X : ID) : S with type elt = X.t diff --git a/tests/run_tests.ml b/tests/run_tests.ml index 64537a6b..44377efd 100644 --- a/tests/run_tests.ml +++ b/tests/run_tests.ml @@ -6,6 +6,8 @@ let suite = "all_tests" >::: [ Test_pHashtbl.suite; Test_PersistentHashtbl.suite; + Test_cc.suite; + Test_puf.suite; Test_vector.suite; Test_gen.suite; Test_deque.suite; diff --git a/tests/test_cc.ml b/tests/test_cc.ml new file mode 100644 index 00000000..97b40b7a --- /dev/null +++ b/tests/test_cc.ml @@ -0,0 +1,93 @@ +(** Tests for congruence closure *) + +open OUnit + +let parse = CC.parse +let pp = CC.pp + +module CT = CC.StrTerm +module CC = CC.StrCC + +let test_add () = + let cc = CC.create 5 in + let t = parse "((a (b c)) d)" in + OUnit.assert_equal ~cmp:CT.eq t t; + let t2 = parse "(f (g (h x)))" in + OUnit.assert_bool "not eq" (not (CC.eq cc t t2)); + () + +let test_merge () = + let t1 = parse "((f (a b)) c)" in + let t2 = parse "((f (a b2)) c2)" in + (* Format.printf "t1=%a, t2=%a@." pp t1 pp t2; *) + let cc = CC.create 5 in + (* merge b and b2 *) + let cc = CC.merge cc (parse "b") (parse "b2") in + OUnit.assert_bool "not eq" (not (CC.eq cc t1 t2)); + OUnit.assert_bool "eq_sub" (CC.eq cc (parse "b") (parse "b2")); + (* merge c and c2 *) + let cc = CC.merge cc (parse "c") (parse "c2") in + OUnit.assert_bool "eq_sub" (CC.eq cc (parse "c") (parse "c2")); + (* Format.printf "t1=%a, t2=%a@." pp (CC.normalize cc t1) pp (CC.normalize cc t2); *) + OUnit.assert_bool "eq" (CC.eq cc t1 t2); + () + +let test_merge2 () = + let cc = CC.create 5 in + let cc = CC.distinct cc (parse "a") (parse "b") in + let cc = CC.merge cc (parse "(f c)") (parse "a") in + let cc = CC.merge cc (parse "(f d)") (parse "b") in + OUnit.assert_bool "not_eq" (not (CC.can_eq cc (parse "a") (parse "b"))); + OUnit.assert_bool "inconsistent" + (try ignore (CC.merge cc (parse "c") (parse "d")); false + with CC.Inconsistent _ -> true); + () + +let test_merge3 () = + let cc = CC.create 5 in + (* f^3(a) = a *) + let cc = CC.merge cc (parse "a") (parse "(f (f (f a)))") in + OUnit.assert_equal ~cmp:CT.eq (parse "(f (f a))") (parse "(f (f a))"); + (* f^4(a) = a *) + let cc = CC.merge cc (parse "(f (f (f (f (f a)))))") (parse "a") in + (* CC.iter_equiv_class cc (parse "a") (fun t -> Format.printf "a = %a@." pp t); *) + (* hence, f^5(a) = f^2(f^3(a)) = f^2(a), and f^3(a) = f(f^2(a)) = f(a) = a *) + OUnit.assert_bool "eq" (CC.eq cc (parse "a") (parse "(f a)")); + () + +let test_merge4 () = + let cc = CC.create 5 in + let cc = CC.merge cc (parse "true") (parse "(p (f (f (f (f (f (f a)))))))") in + let cc = CC.merge cc (parse "a") (parse "(f b)") in + let cc = CC.merge cc (parse "(f a)") (parse "b") in + OUnit.assert_bool "eq" (CC.eq cc (parse "a") (parse "(f (f (f (f (f (f a))))))")); + () + +let test_explain () = + let cc = CC.create 5 in + (* f^3(a) = a *) + let cc = CC.merge cc (parse "a") (parse "(f (f (f a)))") in + (* f^4(a) = a *) + let cc = CC.merge cc (parse "(f (f (f (f (f a)))))") (parse "a") in + (* Format.printf "t: %a@." pp (parse "(f (f (f (f (f a)))))"); *) + (* hence, f^5(a) = f^2(f^3(a)) = f^2(a), and f^3(a) = f(f^2(a)) = f(a) = a *) + let l = CC.explain cc (parse "a") (parse "(f (f a))") in + (* + List.iter + (function + | CC.ByMerge (a,b) -> Format.printf "merge %a %a@." pp a pp b + | CC.ByCongruence (a,b) -> Format.printf "congruence %a %a@." pp a pp b) + l; + *) + OUnit.assert_equal 4 (List.length l); + () + +let suite = + "test_cc" >::: + [ "test_add" >:: test_add; + "test_merge" >:: test_merge; + "test_merge2" >:: test_merge2; + "test_merge3" >:: test_merge3; + "test_merge4" >:: test_merge4; + "test_explain" >:: test_explain; + ] diff --git a/tests/test_puf.ml b/tests/test_puf.ml new file mode 100644 index 00000000..d5af04d3 --- /dev/null +++ b/tests/test_puf.ml @@ -0,0 +1,102 @@ +(** Tests for persistent union find *) + +open OUnit + +module P = Puf.Make(struct type t = int let get_id i = i end) + +let rec merge_list uf l = match l with + | [] | [_] -> uf + | x::((y::_) as l') -> + merge_list (P.union uf x y (x,y)) l' + +let test_union () = + let uf = P.create 5 in + let uf = merge_list uf [1;2;3] in + let uf = merge_list uf [5;6] in + OUnit.assert_equal (P.find uf 1) (P.find uf 2); + OUnit.assert_equal (P.find uf 1) (P.find uf 3); + OUnit.assert_equal (P.find uf 5) (P.find uf 6); + OUnit.assert_bool "noteq" ((P.find uf 1) <> (P.find uf 5)); + OUnit.assert_equal 10 (P.find uf 10); + let uf = P.union uf 1 5 (1,5) in + OUnit.assert_equal (P.find uf 2) (P.find uf 6); + () + +let test_iter () = + let uf = P.create 5 in + let uf = merge_list uf [1;2;3] in + let uf = merge_list uf [5;6] in + let uf = merge_list uf [10;11;12;13;2] in + (* equiv classes *) + let l1 = ref [] in + P.iter_equiv_class uf 1 (fun x -> l1 := x:: !l1); + let l2 = ref [] in + P.iter_equiv_class uf 5 (fun x -> l2 := x:: !l2); + OUnit.assert_equal [1;2;3;10;11;12;13] (List.sort compare !l1); + OUnit.assert_equal [5;6] (List.sort compare !l2); + () + +let test_distinct () = + let uf = P.create 5 in + let uf = merge_list uf [1;2;3] in + let uf = merge_list uf [5;6] in + let uf = P.distinct uf 1 5 in + OUnit.assert_equal None (P.inconsistent uf); + let uf' = P.union uf 2 6 (2,6) in + OUnit.assert_bool "inconsistent" + (match P.inconsistent uf' with | None -> false | Some _ -> true); + OUnit.assert_equal None (P.inconsistent uf); + let uf = P.union uf 1 10 (1,10) in + OUnit.assert_equal None (P.inconsistent uf); + () + +let test_big () = + let uf = P.create 5 in + let uf = ref uf in + for i = 0 to 100_000 do + uf := P.union !uf 1 i (1,i); + done; + let uf = !uf in + let n = P.fold_equiv_class uf 1 (fun acc _ -> acc+1) 0 in + OUnit.assert_equal ~printer:string_of_int 100_001 n; + () + +let test_explain () = + let uf = P.create 5 in + let uf = P.union uf 1 2 (1,2) in + let uf = P.union uf 1 3 (1,3) in + let uf = P.union uf 5 6 (5,6) in + let uf = P.union uf 4 5 (4,5) in + let uf = P.union uf 5 3 (5,3) in + OUnit.assert_bool "eq" (P.find uf 1 = P.find uf 5); + let l = P.explain uf 1 6 in + OUnit.assert_bool "not empty explanation" (l <> []); + (* List.iter (fun (a,b) -> Format.printf "%d, %d@." a b) l; *) + () + +(* +let bench () = + let run n = + let uf = P.create 5 in + let uf = ref uf in + for i = 0 to n do + uf := P.union !uf 1 i; + done + in + let res = Bench.bench_args run + [ "100", 100; + "10_000", 10_000; + ] + in Bench.summarize 1. res; + () +*) + +let suite = + "test_puf" >::: + [ "test_union" >:: test_union; + "test_iter" >:: test_iter; + "test_distinct" >:: test_distinct; + "test_big" >:: test_big; + "test_explain" >:: test_explain; + (* "bench" >:: bench; *) + ]