added CC (congruence closure with curryfied terms);

added Puf (persistent Union-Find, used in CC);
added their unit tests
This commit is contained in:
Simon Cruanes 2013-04-17 15:43:19 +02:00
parent dd434c9ef7
commit 7a0605d96f
7 changed files with 1453 additions and 0 deletions

494
cC.ml Normal file
View file

@ -0,0 +1,494 @@
(*
Copyright (c) 2013, Simon Cruanes
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer. Redistributions in binary
form must reproduce the above copyright notice, this list of conditions and the
following disclaimer in the documentation and/or other materials provided with
the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*)
(** {1 Functional Congruence Closure} *)
(** This implementation follows more or less the paper
"fast congruence closure and extensions" by Nieuwenhuis & Oliveras.
It uses semi-persistent data structures but still thrives for efficiency. *)
(** {2 Curryfied terms} *)
module type CurryfiedTerm = sig
type symbol
type t = private {
shape : shape; (** Which kind of term is it? *)
tag : int; (** Unique ID *)
} (** A curryfied term *)
and shape = private
| Const of symbol (** Constant *)
| Apply of t * t (** Curryfied application *)
val mk_const : symbol -> t
val mk_app : t -> t -> t
val get_id : t -> int
val eq : t -> t -> bool
val pp_skel : out_channel -> t -> unit (* print tags recursively *)
end
module Curryfy(X : Hashtbl.HashedType) = struct
type symbol = X.t
type t = {
shape : shape; (** Which kind of term is it? *)
tag : int; (** Unique ID *)
}
and shape =
| Const of symbol (** Constant *)
| Apply of t * t (** Curryfied application *)
type term = t
module WE = Weak.Make(struct
type t = term
let equal a b = match a.shape, b.shape with
| Const ia, Const ib -> X.equal ia ib
| Apply (a1,a2), Apply (b1,b2) -> a1 == b1 && a2 == b2
| _ -> false
let hash a = match a.shape with
| Const i -> X.hash i
| Apply (a, b) -> a.tag * 65599 + b.tag
end)
let __table = WE.create 10001
let count = ref 0
let hashcons t =
let t' = WE.merge __table t in
(if t == t' then incr count);
t'
let mk_const i =
let t = {shape=Const i; tag= !count; } in
hashcons t
let mk_app a b =
let t = {shape=Apply (a, b); tag= !count; } in
hashcons t
let get_id t = t.tag
let eq t1 t2 = t1 == t2
let rec pp_skel oc t = match t.shape with
| Const _ -> Printf.fprintf oc "%d" t.tag
| Apply (t1, t2) ->
Printf.fprintf oc "(%a %a):%d" pp_skel t1 pp_skel t2 t.tag
end
(** {2 Congruence Closure} *)
module type S = sig
module CT : CurryfiedTerm
type t
(** Congruence Closure instance *)
exception Inconsistent of t * CT.t * CT.t * CT.t * CT.t
(** Exception raised when equality and inequality constraints are
inconsistent. [Inconsistent (a, b, a', b')] means that [a=b, a=a', b=b'] in
the congruence closure, but [a' != b'] was asserted before. *)
val create : int -> t
(** Create an empty CC of given size *)
val eq : t -> CT.t -> CT.t -> bool
(** Check whether the two terms are equal *)
val merge : t -> CT.t -> CT.t -> t
(** Assert that the two terms are equal (may raise Inconsistent) *)
val distinct : t -> CT.t -> CT.t -> t
(** Assert that the two given terms are distinct (may raise Inconsistent) *)
type action =
| Merge of CT.t * CT.t
| Distinct of CT.t * CT.t
(** Action that can be performed on the CC *)
val do_action : t -> action -> t
(** Perform the given action (may raise Inconsistent) *)
val can_eq : t -> CT.t -> CT.t -> bool
(** Check whether the two terms can be equal *)
val iter_equiv_class : t -> CT.t -> (CT.t -> unit) -> unit
(** Iterate on terms that are congruent to the given term *)
type explanation =
| ByCongruence of CT.t * CT.t (* direct congruence of terms *)
| ByMerge of CT.t * CT.t (* user merge of terms *)
val explain : t -> CT.t -> CT.t -> explanation list
(** Explain why those two terms are equal (assuming they are,
otherwise raises Invalid_argument) by returning a list
of merges. *)
end
module Make(T : CurryfiedTerm) = struct
module CT = T
module BV = Puf.PBitVector
module Puf = Puf.Make(CT)
module HashedCT = struct
type t = CT.t
let equal t1 t2 = t1.CT.tag = t2.CT.tag
let hash t = t.CT.tag
end
(* Persistent Hashtable on curryfied terms *)
module THashtbl = PersistentHashtbl.Make(HashedCT)
(* Persistent Hashtable on pairs of curryfied terms *)
module T2Hashtbl = PersistentHashtbl.Make(struct
type t = CT.t * CT.t
let equal (t1,t1') (t2,t2') = t1.CT.tag = t2.CT.tag && t1'.CT.tag = t2'.CT.tag
let hash (t,t') = t.CT.tag * 65599 + t'.CT.tag
end)
type t = {
uf : pending_eqn Puf.t; (* representatives for terms *)
defined : BV.t; (* is the term defined? *)
use : eqn list THashtbl.t; (* for all repr a, a -> all a@b=c and b@a=c *)
lookup : eqn T2Hashtbl.t; (* for all reprs a,b, some a@b=c (if any) *)
inconsistent : (CT.t * CT.t) option;
} (** Congruence Closure data structure *)
and eqn =
| EqnSimple of CT.t * CT.t (* t1 = t2 *)
| EqnApply of CT.t * CT.t * CT.t (* (t1 @ t2) = t3 *)
(** Equation between two terms *)
and pending_eqn =
| PendingSimple of eqn
| PendingDouble of eqn * eqn
exception Inconsistent of t * CT.t * CT.t * CT.t * CT.t
(** Exception raised when equality and inequality constraints are
inconsistent. [Inconsistent (a, b, a', b')] means that [a=b, a=a', b=b'] in
the congruence closure, but [a' != b'] was asserted before. *)
(** Create an empty CC of given size *)
let create size =
{ uf = Puf.create size;
defined = BV.make 3;
use = THashtbl.create size;
lookup = T2Hashtbl.create size;
inconsistent = None;
}
let mem cc t =
BV.get cc.defined t.CT.tag
let is_const t = match t.CT.shape with
| CT.Const _ -> true
| CT.Apply _ -> false
(** Merge equations in the congruence closure structure. [q] is a list
of [eqn], processed in FIFO order. May raise Inconsistent. *)
let rec merge cc eqn = match eqn with
| EqnSimple (a, b) ->
(* a=b, just propagate *)
propagate cc [PendingSimple eqn]
| EqnApply (a1, a2, a) ->
(* (a1 @ a2) = a *)
let a1' = Puf.find cc.uf a1 in
let a2' = Puf.find cc.uf a2 in
begin try
(* eqn' is (b1 @ b2) = b for some b1=a1', b2=a2' *)
let eqn' = T2Hashtbl.find cc.lookup (a1', a2') in
(* merge a and b because of eqn and eqn' *)
propagate cc [PendingDouble (eqn, eqn')]
with Not_found ->
(* remember that a1' @ a2' = a *)
let lookup = T2Hashtbl.replace cc.lookup (a1', a2') eqn in
let use_a1' = try THashtbl.find cc.use a1' with Not_found -> [] in
let use_a2' = try THashtbl.find cc.use a2' with Not_found -> [] in
let use = THashtbl.replace cc.use a1' (eqn::use_a1') in
let use = THashtbl.replace use a2' (eqn::use_a2') in
{ cc with use; lookup; }
end
(* propagate: merge pending equations *)
and propagate cc eqns =
let pending = ref eqns in
let uf = ref cc.uf in
let use = ref cc.use in
let lookup = ref cc.lookup in
(* process each pending equation *)
while !pending <> [] do
let eqn = List.hd !pending in
pending := List.tl !pending;
(* extract the two merged terms *)
let a, b = match eqn with
| PendingSimple (EqnSimple (a, b)) -> a, b
| PendingDouble (EqnApply (a1,a2,a), EqnApply (b1,b2,b)) -> a, b
| _ -> assert false
in
let a' = Puf.find !uf a in
let b' = Puf.find !uf b in
if not (CT.eq a' b') then begin
let use_a' = try THashtbl.find !use a' with Not_found -> [] in
let use_b' = try THashtbl.find !use b' with Not_found -> [] in
(* merge a and b's equivalence classes *)
(* Format.printf "merge %d %d@." a.CT.tag b.CT.tag; *)
uf := Puf.union !uf a b eqn;
(* check which of [a'] and [b'] is the new representative. [repr] is
the new representative, and [other] is the former representative *)
let repr = Puf.find !uf a' in
let use_repr = ref (if CT.eq repr a' then use_a' else use_b') in
let use_other = if CT.eq repr a' then use_b' else use_a' in
(* consider all c1@c2=c in use(a') *)
List.iter
(fun eqn -> match eqn with
| EqnSimple _ -> ()
| EqnApply (c1, c2, c) ->
let c1' = Puf.find !uf c1 in
let c2' = Puf.find !uf c2 in
begin try
let eqn' = T2Hashtbl.find !lookup (c1', c2') in
(* merge eqn with eqn', by congruence *)
pending := (PendingDouble (eqn,eqn')) :: !pending
with Not_found ->
lookup := T2Hashtbl.replace !lookup (c1', c2') eqn;
use_repr := eqn :: !use_repr;
end)
use_other;
(* update use list of [repr] *)
use := THashtbl.replace !use repr !use_repr;
(* check for inconsistencies *)
match Puf.inconsistent !uf with
| None -> () (* consistent *)
| Some (t1, t2, t1', t2') ->
(* inconsistent *)
let cc = { cc with use= !use; lookup= !lookup; uf= !uf; } in
raise (Inconsistent (cc, t1, t2, t1', t2'))
end
done;
let cc = { cc with use= !use; lookup= !lookup; uf= !uf; } in
cc
(** Add the given term to the CC *)
let rec add cc t =
match t.CT.shape with
| CT.Const _ ->
cc (* always trivially defined *)
| CT.Apply (t1, t2) ->
if BV.get cc.defined t.CT.tag
then cc (* already defined *)
else begin
(* note that [t] is defined, add it to the UF to avoid GC *)
let defined = BV.set_true cc.defined t.CT.tag in
let cc = {cc with defined; } in
(* recursive add. invariant: if a term is added, then its subterms
also are (hence the base case of constants or already added terms). *)
let cc = add cc t1 in
let cc = add cc t2 in
let cc = merge cc (EqnApply (t1, t2, t)) in
cc
end
(** Check whether the two terms are equal *)
let eq cc t1 t2 =
let cc = add (add cc t1) t2 in
let t1' = Puf.find cc.uf t1 in
let t2' = Puf.find cc.uf t2 in
CT.eq t1' t2'
(** Assert that the two terms are equal (may raise Inconsistent) *)
let merge cc t1 t2 =
let cc = add (add cc t1) t2 in
merge cc (EqnSimple (t1, t2))
(** Assert that the two given terms are distinct (may raise Inconsistent) *)
let distinct cc t1 t2 =
let cc = add (add cc t1) t2 in
let t1' = Puf.find cc.uf t1 in
let t2' = Puf.find cc.uf t2 in
if CT.eq t1' t2'
then raise (Inconsistent (cc, t1', t2', t1, t2)) (* they are equal, fail *)
else
(* remember that they should not become equal *)
let uf = Puf.distinct cc.uf t1 t2 in
{ cc with uf; }
type action =
| Merge of CT.t * CT.t
| Distinct of CT.t * CT.t
(** Action that can be performed on the CC *)
let do_action cc action = match action with
| Merge (t1, t2) -> merge cc t1 t2
| Distinct (t1, t2) -> distinct cc t1 t2
(** Check whether the two terms can be equal *)
let can_eq cc t1 t2 =
let cc = add (add cc t1) t2 in
not (Puf.must_be_distinct cc.uf t1 t2)
(** Iterate on terms that are congruent to the given term *)
let iter_equiv_class cc t f =
Puf.iter_equiv_class cc.uf t f
(** {3 Auxilliary Union-find for explanations} *)
module SparseUF = struct
module H = Hashtbl.Make(HashedCT)
type t = uf_ref H.t
and uf_ref = {
term : CT.t;
mutable parent : CT.t;
mutable highest_node : CT.t;
} (** Union-find reference *)
let create size = H.create size
let get_ref uf t =
try H.find uf t
with Not_found ->
let r_t = { term=t; parent=t; highest_node=t; } in
H.add uf t r_t;
r_t
let rec find_ref uf r_t =
if CT.eq r_t.parent r_t.term
then r_t (* fixpoint *)
else
let r_t' = get_ref uf r_t.parent in
find_ref uf r_t' (* recurse (no path compression) *)
let find uf t =
try
let r_t = H.find uf t in
(find_ref uf r_t).term
with Not_found ->
t
let eq uf t1 t2 =
CT.eq (find uf t1) (find uf t2)
let highest_node uf t =
try
let r_t = H.find uf t in
(find_ref uf r_t).highest_node
with Not_found ->
t
(* oriented union (t1 -> t2), assuming t2 is "higher" than t1 *)
let union uf t1 t2 =
let r_t1' = find_ref uf (get_ref uf t1) in
let r_t2' = find_ref uf (get_ref uf t2) in
r_t1'.parent <- r_t2'.term
end
(** {3 Producing explanations} *)
type explanation =
| ByCongruence of CT.t * CT.t (* direct congruence of terms *)
| ByMerge of CT.t * CT.t (* user merge of terms *)
(** Explain why those two terms are equal (they must be) *)
let explain cc t1 t2 =
assert (eq cc t1 t2);
(* keeps track of which equalities are already explained *)
let explained = SparseUF.create 5 in
let explanations = ref [] in
(* equations waiting to be explained *)
let pending = Queue.create () in
Queue.push (t1,t2) pending;
(* explain why a=c, where c is the root of the proof forest a belongs to *)
let rec explain_along a c =
let a' = SparseUF.highest_node explained a in
if CT.eq a' c then ()
else match Puf.explain_step cc.uf a' with
| None -> assert (CT.eq a' c)
| Some (b, e) ->
(* a->b on the path from a to c *)
begin match e with
| PendingSimple (EqnSimple (a',b')) ->
explanations := ByMerge (a', b') :: !explanations
| PendingDouble (EqnApply (a1, a2, a'), EqnApply (b1, b2, b')) ->
explanations := ByCongruence (a', b') :: !explanations;
Queue.push (a1, b1) pending;
Queue.push (a2, b2) pending;
| _ -> assert false
end;
(* now a' = b is justified *)
SparseUF.union explained a' b;
(* recurse *)
let new_a = SparseUF.highest_node explained b in
explain_along new_a c
in
(* process pending equations *)
while not (Queue.is_empty pending) do
let a, b = Queue.pop pending in
if SparseUF.eq explained a b
then ()
else begin
let c = Puf.common_ancestor cc.uf a b in
explain_along a c;
explain_along b c;
end
done;
!explanations
end
module StrTerm = Curryfy(struct
type t = string
let equal s1 s2 = s1 = s2
let hash s = Hashtbl.hash s
end)
module StrCC = Make(StrTerm)
let lex str =
let lexer = Genlex.make_lexer ["("; ")"] in
lexer (Stream.of_string str)
let parse str =
let stream = lex str in
let rec parse_term () =
match Stream.peek stream with
| Some (Genlex.Kwd "(") ->
Stream.junk stream;
let t1 = parse_term () in
let t2 = parse_term () in
begin match Stream.peek stream with
| Some (Genlex.Kwd ")") ->
Stream.junk stream;
StrTerm.mk_app t1 t2 (* end apply *)
| _ -> raise (Stream.Error "expected )")
end
| Some (Genlex.Ident s) ->
Stream.junk stream;
StrTerm.mk_const s
| _ -> raise (Stream.Error "expected term")
in
parse_term ()
let rec pp fmt t =
match t.StrTerm.shape with
| StrTerm.Const s ->
Format.fprintf fmt "%s:%d" s t.StrTerm.tag
| StrTerm.Apply (t1, t2) ->
Format.fprintf fmt "(%a %a):%d" pp t1 pp t2 t.StrTerm.tag

105
cC.mli Normal file
View file

@ -0,0 +1,105 @@
(*
Copyright (c) 2013, Simon Cruanes
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer. Redistributions in binary
form must reproduce the above copyright notice, this list of conditions and the
following disclaimer in the documentation and/or other materials provided with
the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*)
(** {1 Functional Congruence Closure} *)
(** {2 Curryfied terms} *)
module type CurryfiedTerm = sig
type symbol
type t = private {
shape : shape; (** Which kind of term is it? *)
tag : int; (** Unique ID *)
} (** A curryfied term *)
and shape = private
| Const of symbol (** Constant *)
| Apply of t * t (** Curryfied application *)
val mk_const : symbol -> t
val mk_app : t -> t -> t
val get_id : t -> int
val eq : t -> t -> bool
val pp_skel : out_channel -> t -> unit (* print tags recursively *)
end
module Curryfy(X : Hashtbl.HashedType) : CurryfiedTerm with type symbol = X.t
(** {2 Congruence Closure} *)
module type S = sig
module CT : CurryfiedTerm
type t
(** Congruence Closure instance *)
exception Inconsistent of t * CT.t * CT.t * CT.t * CT.t
(** Exception raised when equality and inequality constraints are
inconsistent. [Inconsistent (a, b, a', b')] means that [a=b, a=a', b=b'] in
the congruence closure, but [a' != b'] was asserted before. *)
val create : int -> t
(** Create an empty CC of given size *)
val eq : t -> CT.t -> CT.t -> bool
(** Check whether the two terms are equal *)
val merge : t -> CT.t -> CT.t -> t
(** Assert that the two terms are equal (may raise Inconsistent) *)
val distinct : t -> CT.t -> CT.t -> t
(** Assert that the two given terms are distinct (may raise Inconsistent) *)
type action =
| Merge of CT.t * CT.t
| Distinct of CT.t * CT.t
(** Action that can be performed on the CC *)
val do_action : t -> action -> t
(** Perform the given action (may raise Inconsistent) *)
val can_eq : t -> CT.t -> CT.t -> bool
(** Check whether the two terms can be equal *)
val iter_equiv_class : t -> CT.t -> (CT.t -> unit) -> unit
(** Iterate on terms that are congruent to the given term *)
type explanation =
| ByCongruence of CT.t * CT.t (* direct congruence of terms *)
| ByMerge of CT.t * CT.t (* user merge of terms *)
val explain : t -> CT.t -> CT.t -> explanation list
(** Explain why those two terms are equal (assuming they are,
otherwise raises Invalid_argument) by returning a list
of merges. *)
end
module Make(T : CurryfiedTerm) : S with module CT = T
module StrTerm : CurryfiedTerm with type symbol = string
module StrCC : S with module CT = StrTerm
val parse : string -> StrTerm.t
val pp : Format.formatter -> StrTerm.t -> unit

519
puf.ml Normal file
View file

@ -0,0 +1,519 @@
(*
Copyright (c) 2013, Simon Cruanes
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer. Redistributions in binary
form must reproduce the above copyright notice, this list of conditions and the
following disclaimer in the documentation and/or other materials provided with
the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*)
(** {1 Functional (persistent) extensible union-find} *)
(** {2 Persistent array} *)
module PArray = struct
type 'a t = 'a zipper ref
and 'a zipper =
| Array of 'a array
| Diff of int * 'a * 'a t
(* XXX maybe having a snapshot of the array from point to point may help? *)
let make size elt =
let a = Array.create size elt in
ref (Array a)
let init size f =
let a = Array.init size f in
ref (Array a)
(** Recover the given version of the shared array. Returns the array
itself. *)
let rec reroot t =
match !t with
| Array a -> a
| Diff (i, v, t') ->
begin
let a = reroot t' in
let v' = a.(i) in
t' := Diff (i, v', t);
a.(i) <- v;
t := Array a;
a
end
let get t i =
match !t with
| Array a -> a.(i)
| Diff _ ->
let a = reroot t in
a.(i)
let set t i v =
let a =
match !t with
| Array a -> a
| Diff _ -> reroot t in
let v' = a.(i) in
if v == v'
then t (* no change *)
else begin
let t' = ref (Array a) in
a.(i) <- v;
t := Diff (i, v', t');
t' (* create new array *)
end
let rec length t =
match !t with
| Array a -> Array.length a
| Diff (_, _, t') -> length t'
(** Extend [t] to the given [size], initializing new elements with [elt] *)
let extend t size elt =
let a = match !t with
| Array a -> a
| _ -> reroot t in
if size > Array.length a
then begin (* resize: create bigger array *)
let size = min Sys.max_array_length size in
let a' = Array.make size elt in
(* copy old part *)
Array.blit a 0 a' 0 (Array.length a);
t := Array a'
end
(** Extend [t] to the given [size], initializing elements with [f] *)
let extend_init t size f =
let a = match !t with
| Array a -> a
| _ -> reroot t in
if size > Array.length a
then begin (* resize: create bigger array *)
let size = min Sys.max_array_length size in
let a' = Array.init size f in
(* copy old part *)
Array.blit a 0 a' 0 (Array.length a);
t := Array a'
end
let fold_left f acc t =
let a = reroot t in
Array.fold_left f acc a
end
(** {2 Persistent Bitvector} *)
module PBitVector = struct
type t = int PArray.t
let width = Sys.word_size - 1 (* number of usable bits in an integer *)
let make size = PArray.make size 0
let ensure bv offset =
if offset >= PArray.length bv
then
let len = offset + offset/2 + 1 in
PArray.extend bv len 0
else ()
(** [get bv i] gets the value of the [i]-th element of [bv] *)
let get bv i =
let offset = i / width in
let bit = i mod width in
ensure bv offset;
let bits = PArray.get bv offset in
(bits land (1 lsl bit)) <> 0
(** [set bv i v] sets the value of the [i]-th element of [bv] to [v] *)
let set bv i v =
let offset = i / width in
let bit = i mod width in
ensure bv offset;
let bits = PArray.get bv offset in
let bits' =
if v
then bits lor (1 lsl bit)
else bits land (lnot (1 lsl bit))
in
PArray.set bv offset bits'
(** Bitvector with all bits set to 0 *)
let clear bv = make 5
let set_true bv i = set bv i true
let set_false bv i = set bv i false
end
(** {2 Type with unique identifier} *)
module type ID = sig
type t
val get_id : t -> int
end
(** {2 Persistent Union-Find with explanations} *)
module type S = sig
type elt
(** Elements of the Union-find *)
type 'e t
(** An instance of the union-find, ie a set of equivalence classes; It
is parametrized by the type of explanations. *)
val create : int -> 'e t
(** Create a union-find of the given size. *)
val find : 'e t -> elt -> elt
(** [find uf a] returns the current representative of [a] in the given
union-find structure [uf]. By default, [find uf a = a]. *)
val union : 'e t -> elt -> elt -> 'e -> 'e t
(** [union uf a b why] returns an update of [uf] where [find a = find b],
the merge being justified by [why]. *)
val distinct : 'e t -> elt -> elt -> 'e t
(** Ensure that the two elements are distinct. *)
val must_be_distinct : _ t -> elt -> elt -> bool
(** Should the two elements be distinct? *)
val fold_equiv_class : _ t -> elt -> ('a -> elt -> 'a) -> 'a -> 'a
(** [fold_equiv_class uf a f acc] folds on [acc] and every element
that is congruent to [a] with [f]. *)
val iter_equiv_class : _ t -> elt -> (elt -> unit) -> unit
(** [iter_equiv_class uf a f] calls [f] on every element of [uf] that
is congruent to [a], including [a] itself. *)
val inconsistent : _ t -> (elt * elt * elt * elt) option
(** Check whether the UF is inconsistent. It returns [Some (a, b, a', b')]
in case of inconsistency, where a = b, a = a' and b = b' by congruence,
and a' != b' was a call to [distinct]. *)
val common_ancestor : 'e t -> elt -> elt -> elt
(** Closest common ancestor of the two elements in the proof forest *)
val explain_step : 'e t -> elt -> (elt * 'e) option
(** Edge from the element to its parent in the proof forest; Returns
None if the element is a root of the forest. *)
val explain : 'e t -> elt -> elt -> 'e list
(** [explain uf a b] returns a list of labels that justify why
[find uf a = find uf b]. Such labels were provided by [union]. *)
val explain_distinct : 'e t -> elt -> elt -> elt * elt
(** [explain_distinct uf a b] gives the original pair [a', b'] that
made [a] and [b] distinct by calling [distinct a' b'] *)
end
module IH = Hashtbl.Make(struct type t = int let equal i j = i = j let hash i = i end)
module Make(X : ID) : S with type elt = X.t = struct
type elt = X.t
type 'e t = {
mutable parent : int PArray.t; (* idx of the parent, with path compression *)
mutable data : elt_data option PArray.t; (* ID -> data for an element *)
inconsistent : (elt * elt * elt * elt) option; (* is the UF inconsistent? *)
forest : 'e edge PArray.t; (* explanation forest *)
} (** An instance of the union-find, ie a set of equivalence classes *)
and elt_data = {
elt : elt;
size : int; (* number of elements in the class *)
next : int; (* next element in equiv class *)
distinct : (int * elt * elt) list; (* classes distinct from this one, and why *)
} (** Data associated to the element. Most of it is only meaningful for
a representative (ie when elt = parent(elt)). *)
and 'e edge =
| EdgeNone
| EdgeTo of int * 'e
(** Edge of the proof forest, annotated with 'e *)
let get_data uf id =
match PArray.get uf.data id with
| Some data -> data
| None -> assert false
(** Create a union-find of the given size. *)
let create size =
{ parent = PArray.init size (fun i -> i);
data = PArray.make size None;
inconsistent = None;
forest = PArray.make size EdgeNone;
}
(* ensure the arrays are big enough for [id], and set [elt.(id) <- elt] *)
let ensure uf id elt =
if id >= PArray.length uf.data then begin
(* resize *)
let len = id + (id / 2) in
PArray.extend_init uf.parent len (fun i -> i);
PArray.extend uf.data len None;
PArray.extend uf.forest len EdgeNone;
end;
match PArray.get uf.data id with
| None ->
let data = { elt; size = 1; next=id; distinct=[]; } in
uf.data <- PArray.set uf.data id (Some data)
| Some _ -> ()
(* Find the ID of the root of the given ID *)
let rec find_root uf id =
let parent_id = PArray.get uf.parent id in
if id = parent_id
then id
else begin (* recurse *)
let root = find_root uf parent_id in
(* path compression *)
(if root <> parent_id then uf.parent <- PArray.set uf.parent id root);
root
end
(** [find uf a] returns the current representative of [a] in the given
union-find structure [uf]. By default, [find uf a = a]. *)
let find uf elt =
let id = X.get_id elt in
if id >= PArray.length uf.parent
then elt (* not present *)
else
let id' = find_root uf id in
match PArray.get uf.data id' with
| Some data -> data.elt
| None -> assert (id = id'); elt (* not present *)
(* merge i and j in the forest, with explanation why *)
let rec merge_forest forest i j why =
assert (i <> j);
(* invert path from i to roo, reverting all edges *)
let rec invert_path forest i =
match PArray.get forest i with
| EdgeNone -> forest (* reached root *)
| EdgeTo (i', e) ->
let forest' = invert_path forest i' in
PArray.set forest' i' (EdgeTo (i, e))
in
let forest = invert_path forest i in
(* root of [j] is the new root of [i] and [j] *)
let forest = PArray.set forest i (EdgeTo (j, why)) in
forest
(** Merge the class of [a] (whose representative is [ia'] into the class
of [b], whose representative is [ib'] *)
let merge_into uf a ia' b ib' why =
let data_a = get_data uf ia' in
let data_b = get_data uf ib' in
(* merge roots (a -> b, arbitrarily) *)
let parent = PArray.set uf.parent ia' ib' in
(* merge 'distinct' lists: distinct(b) <- distinct(b)+distinct(a) *)
let distinct' = List.rev_append data_a.distinct data_b.distinct in
(* size of the new equivalence class *)
let size' = data_a.size + data_b.size in
(* concatenation of circular linked lists (equivalence classes),
concatenation of distinct lists *)
let data_a' = {data_a with next=data_b.next; } in
let data_b' = {data_b with next=data_a.next; distinct=distinct'; size=size'; } in
let data = PArray.set uf.data ia' (Some data_a') in
let data = PArray.set data ib' (Some data_b') in
(* inconsistency check *)
let inconsistent =
List.fold_left
(fun acc (id, a', b') -> match acc with
| Some _ -> acc
| None when find_root uf id = ib' -> Some (a, b, a', b') (* found! *)
| None -> None)
None data_a.distinct
in
(* update forest *)
let forest = merge_forest uf.forest (X.get_id a) (X.get_id b) why in
{ parent; data; inconsistent; forest; }
(** [union uf a b why] returns an update of [uf] where [find a = find b],
the merge being justified by [why]. *)
let union uf a b why =
(if uf.inconsistent <> None
then raise (Invalid_argument "inconsistent uf"));
let ia = X.get_id a in
let ib = X.get_id b in
(* get sure we can access [ia] and [ib] in [uf] *)
ensure uf ia a;
ensure uf ib b;
(* indexes of roots of [a] and [b] *)
let ia' = find_root uf ia
and ib' = find_root uf ib in
if ia' = ib'
then uf (* no change *)
else
(* data associated to both representatives *)
let data_a = get_data uf ia' in
let data_b = get_data uf ib' in
(* merge the smaller class into the bigger class *)
if data_a.size > data_b.size
then merge_into uf b ib' a ia' why
else merge_into uf a ia' b ib' why
(** Ensure that the two elements are distinct. May raise Inconsistent *)
let distinct uf a b =
(if uf.inconsistent <> None
then raise (Invalid_argument "inconsistent uf"));
let ia = X.get_id a in
let ib = X.get_id b in
ensure uf ia a;
ensure uf ib b;
(* representatives of a and b *)
let ia' = find_root uf ia in
let ib' = find_root uf ib in
(* update 'distinct' lists *)
let data_a = get_data uf ia' in
let data_a' = {data_a with distinct= (ib',a,b) :: data_a.distinct; } in
let data_b = get_data uf ib' in
let data_b' = {data_b with distinct= (ia',a,b) :: data_b.distinct; } in
let data = PArray.set uf.data ia' (Some data_a') in
let data = PArray.set data ib' (Some data_b') in
(* check inconsistency *)
let inconsistent = if ia' = ib' then Some (data_a.elt, data_b.elt, a, b) else None in
{ uf with inconsistent; data; }
let must_be_distinct uf a b =
let ia = X.get_id a in
let ib = X.get_id b in
let len = PArray.length uf.parent in
if ia >= len || ib >= len
then false (* no chance *)
else
(* representatives *)
let ia' = find_root uf ia in
let ib' = find_root uf ib in
(* list of equiv classes that must be != a *)
match PArray.get uf.data ia' with
| None -> false (* ia' not present *)
| Some data_a ->
List.exists (fun (id,_,_) -> find_root uf id = ib') data_a.distinct
(** [fold_equiv_class uf a f acc] folds on [acc] and every element
that is congruent to [a] with [f]. *)
let fold_equiv_class uf a f acc =
let ia = X.get_id a in
if ia >= PArray.length uf.parent
then f acc a (* alone. *)
else
let rec traverse acc id =
match PArray.get uf.data id with
| None -> f acc a (* alone. *)
| Some data ->
let acc' = f acc data.elt in
let id' = data.next in
if id' = ia
then acc' (* traversed the whole list *)
else traverse acc' id'
in
traverse acc ia
(** [iter_equiv_class uf a f] calls [f] on every element of [uf] that
is congruent to [a], including [a] itself. *)
let iter_equiv_class uf a f =
let ia = X.get_id a in
if ia >= PArray.length uf.parent
then f a (* alone. *)
else
let rec traverse id =
match PArray.get uf.data id with
| None -> f a (* alone. *)
| Some data ->
f data.elt; (* yield element *)
let id' = data.next in
if id' = ia
then () (* traversed the whole list *)
else traverse id'
in
traverse ia
let inconsistent uf = uf.inconsistent
(** Closest common ancestor of the two elements in the proof forest *)
let common_ancestor uf a b =
let forest = uf.forest in
let explored = IH.create 3 in
let rec recurse i j =
if i = j
then return i (* found *)
else if IH.mem explored i
then return i
else if IH.mem explored j
then return j
else
let i' = match PArray.get forest i with
| EdgeNone -> i
| EdgeTo (i', e) ->
IH.add explored i ();
i'
and j' = match PArray.get forest j with
| EdgeNone -> j
| EdgeTo (j', e) ->
IH.add explored j ();
j'
in
recurse i' j'
and return i =
(get_data uf i).elt (* return the element *)
in
recurse (X.get_id a) (X.get_id b)
(** Edge from the element to its parent in the proof forest; Returns
None if the element is a root of the forest. *)
let explain_step uf a =
match PArray.get uf.forest (X.get_id a) with
| EdgeNone -> None
| EdgeTo (i, e) ->
let b = (get_data uf i).elt in
Some (b, e)
(** [explain uf a b] returns a list of labels that justify why
[find uf a = find uf b]. Such labels were provided by [union]. *)
let explain uf a b =
(if find_root uf (X.get_id a) <> find_root uf (X.get_id b)
then failwith "Puf.explain: can only explain equal terms");
let c = common_ancestor uf a b in
(* path from [x] to [c] *)
let rec build_path path x =
if (X.get_id x) = (X.get_id c)
then path
else match explain_step uf x with
| None -> assert false
| Some (x', e) ->
build_path (e::path) x'
in
build_path (build_path [] a) b
(** [explain_distinct uf a b] gives the original pair [a', b'] that
made [a] and [b] distinct by calling [distinct a' b']. The
terms must be distinct, otherwise Failure is raised. *)
let explain_distinct uf a b =
let ia' = find_root uf (X.get_id a) in
let ib' = find_root uf (X.get_id b) in
let node_a = get_data uf ia' in
let rec search l = match l with
| [] -> failwith "Puf.explain_distinct: classes are not distinct"
| (ib'', a', b')::_ when ib' = ib'' -> (a', b') (* explanation found *)
| _ :: l' -> search l'
in
search node_a.distinct
end

138
puf.mli Normal file
View file

@ -0,0 +1,138 @@
(*
Copyright (c) 2013, Simon Cruanes
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer. Redistributions in binary
form must reproduce the above copyright notice, this list of conditions and the
following disclaimer in the documentation and/or other materials provided with
the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*)
(** {1 Functional (persistent) extensible union-find} *)
(** {2 Persistent array} *)
module PArray : sig
type 'a t
val make : int -> 'a -> 'a t
val init : int -> (int -> 'a) -> 'a t
val get : 'a t -> int -> 'a
val set : 'a t -> int -> 'a -> 'a t
val length : 'a t -> int
val fold_left : ('b -> 'a -> 'b) -> 'b -> 'a t -> 'b
val extend : 'a t -> int -> 'a -> unit
(** Extend [t] to the given [size], initializing new elements with [elt] *)
val extend_init : 'a t -> int -> (int -> 'a) -> unit
(** Extend [t] to the given [size], initializing elements with [f] *)
end
(** {2 Persistent Bitvector} *)
module PBitVector : sig
type t
val make : int -> t
(** Create a new bitvector of the given initial size (in words) *)
val get : t -> int -> bool
(** [get bv i] gets the value of the [i]-th element of [bv] *)
val set : t -> int -> bool -> t
(** [set bv i v] sets the value of the [i]-th element of [bv] to [v] *)
val clear : t -> t
(** Bitvector with all bits set to 0 *)
val set_true : t -> int -> t
val set_false : t -> int -> t
end
(** {2 Type with unique identifier} *)
module type ID = sig
type t
val get_id : t -> int
(** Unique integer ID for the element. Must be >= 0. *)
end
(** {2 Persistent Union-Find with explanations} *)
module type S = sig
type elt
(** Elements of the Union-find *)
type 'e t
(** An instance of the union-find, ie a set of equivalence classes; It
is parametrized by the type of explanations. *)
val create : int -> 'e t
(** Create a union-find of the given size. *)
val find : 'e t -> elt -> elt
(** [find uf a] returns the current representative of [a] in the given
union-find structure [uf]. By default, [find uf a = a]. *)
val union : 'e t -> elt -> elt -> 'e -> 'e t
(** [union uf a b why] returns an update of [uf] where [find a = find b],
the merge being justified by [why]. *)
val distinct : 'e t -> elt -> elt -> 'e t
(** Ensure that the two elements are distinct. *)
val must_be_distinct : _ t -> elt -> elt -> bool
(** Should the two elements be distinct? *)
val fold_equiv_class : _ t -> elt -> ('a -> elt -> 'a) -> 'a -> 'a
(** [fold_equiv_class uf a f acc] folds on [acc] and every element
that is congruent to [a] with [f]. *)
val iter_equiv_class : _ t -> elt -> (elt -> unit) -> unit
(** [iter_equiv_class uf a f] calls [f] on every element of [uf] that
is congruent to [a], including [a] itself. *)
val inconsistent : _ t -> (elt * elt * elt * elt) option
(** Check whether the UF is inconsistent. It returns [Some (a, b, a', b')]
in case of inconsistency, where a = b, a = a' and b = b' by congruence,
and a' != b' was a call to [distinct]. *)
val common_ancestor : 'e t -> elt -> elt -> elt
(** Closest common ancestor of the two elements in the proof forest *)
val explain_step : 'e t -> elt -> (elt * 'e) option
(** Edge from the element to its parent in the proof forest; Returns
None if the element is a root of the forest. *)
val explain : 'e t -> elt -> elt -> 'e list
(** [explain uf a b] returns a list of labels that justify why
[find uf a = find uf b]. Such labels were provided by [union]. *)
val explain_distinct : 'e t -> elt -> elt -> elt * elt
(** [explain_distinct uf a b] gives the original pair [a', b'] that
made [a] and [b] distinct by calling [distinct a' b']. The
terms must be distinct, otherwise Failure is raised. *)
end
module Make(X : ID) : S with type elt = X.t

View file

@ -6,6 +6,8 @@ let suite =
"all_tests" >:::
[ Test_pHashtbl.suite;
Test_PersistentHashtbl.suite;
Test_cc.suite;
Test_puf.suite;
Test_vector.suite;
Test_gen.suite;
Test_deque.suite;

93
tests/test_cc.ml Normal file
View file

@ -0,0 +1,93 @@
(** Tests for congruence closure *)
open OUnit
let parse = CC.parse
let pp = CC.pp
module CT = CC.StrTerm
module CC = CC.StrCC
let test_add () =
let cc = CC.create 5 in
let t = parse "((a (b c)) d)" in
OUnit.assert_equal ~cmp:CT.eq t t;
let t2 = parse "(f (g (h x)))" in
OUnit.assert_bool "not eq" (not (CC.eq cc t t2));
()
let test_merge () =
let t1 = parse "((f (a b)) c)" in
let t2 = parse "((f (a b2)) c2)" in
(* Format.printf "t1=%a, t2=%a@." pp t1 pp t2; *)
let cc = CC.create 5 in
(* merge b and b2 *)
let cc = CC.merge cc (parse "b") (parse "b2") in
OUnit.assert_bool "not eq" (not (CC.eq cc t1 t2));
OUnit.assert_bool "eq_sub" (CC.eq cc (parse "b") (parse "b2"));
(* merge c and c2 *)
let cc = CC.merge cc (parse "c") (parse "c2") in
OUnit.assert_bool "eq_sub" (CC.eq cc (parse "c") (parse "c2"));
(* Format.printf "t1=%a, t2=%a@." pp (CC.normalize cc t1) pp (CC.normalize cc t2); *)
OUnit.assert_bool "eq" (CC.eq cc t1 t2);
()
let test_merge2 () =
let cc = CC.create 5 in
let cc = CC.distinct cc (parse "a") (parse "b") in
let cc = CC.merge cc (parse "(f c)") (parse "a") in
let cc = CC.merge cc (parse "(f d)") (parse "b") in
OUnit.assert_bool "not_eq" (not (CC.can_eq cc (parse "a") (parse "b")));
OUnit.assert_bool "inconsistent"
(try ignore (CC.merge cc (parse "c") (parse "d")); false
with CC.Inconsistent _ -> true);
()
let test_merge3 () =
let cc = CC.create 5 in
(* f^3(a) = a *)
let cc = CC.merge cc (parse "a") (parse "(f (f (f a)))") in
OUnit.assert_equal ~cmp:CT.eq (parse "(f (f a))") (parse "(f (f a))");
(* f^4(a) = a *)
let cc = CC.merge cc (parse "(f (f (f (f (f a)))))") (parse "a") in
(* CC.iter_equiv_class cc (parse "a") (fun t -> Format.printf "a = %a@." pp t); *)
(* hence, f^5(a) = f^2(f^3(a)) = f^2(a), and f^3(a) = f(f^2(a)) = f(a) = a *)
OUnit.assert_bool "eq" (CC.eq cc (parse "a") (parse "(f a)"));
()
let test_merge4 () =
let cc = CC.create 5 in
let cc = CC.merge cc (parse "true") (parse "(p (f (f (f (f (f (f a)))))))") in
let cc = CC.merge cc (parse "a") (parse "(f b)") in
let cc = CC.merge cc (parse "(f a)") (parse "b") in
OUnit.assert_bool "eq" (CC.eq cc (parse "a") (parse "(f (f (f (f (f (f a))))))"));
()
let test_explain () =
let cc = CC.create 5 in
(* f^3(a) = a *)
let cc = CC.merge cc (parse "a") (parse "(f (f (f a)))") in
(* f^4(a) = a *)
let cc = CC.merge cc (parse "(f (f (f (f (f a)))))") (parse "a") in
(* Format.printf "t: %a@." pp (parse "(f (f (f (f (f a)))))"); *)
(* hence, f^5(a) = f^2(f^3(a)) = f^2(a), and f^3(a) = f(f^2(a)) = f(a) = a *)
let l = CC.explain cc (parse "a") (parse "(f (f a))") in
(*
List.iter
(function
| CC.ByMerge (a,b) -> Format.printf "merge %a %a@." pp a pp b
| CC.ByCongruence (a,b) -> Format.printf "congruence %a %a@." pp a pp b)
l;
*)
OUnit.assert_equal 4 (List.length l);
()
let suite =
"test_cc" >:::
[ "test_add" >:: test_add;
"test_merge" >:: test_merge;
"test_merge2" >:: test_merge2;
"test_merge3" >:: test_merge3;
"test_merge4" >:: test_merge4;
"test_explain" >:: test_explain;
]

102
tests/test_puf.ml Normal file
View file

@ -0,0 +1,102 @@
(** Tests for persistent union find *)
open OUnit
module P = Puf.Make(struct type t = int let get_id i = i end)
let rec merge_list uf l = match l with
| [] | [_] -> uf
| x::((y::_) as l') ->
merge_list (P.union uf x y (x,y)) l'
let test_union () =
let uf = P.create 5 in
let uf = merge_list uf [1;2;3] in
let uf = merge_list uf [5;6] in
OUnit.assert_equal (P.find uf 1) (P.find uf 2);
OUnit.assert_equal (P.find uf 1) (P.find uf 3);
OUnit.assert_equal (P.find uf 5) (P.find uf 6);
OUnit.assert_bool "noteq" ((P.find uf 1) <> (P.find uf 5));
OUnit.assert_equal 10 (P.find uf 10);
let uf = P.union uf 1 5 (1,5) in
OUnit.assert_equal (P.find uf 2) (P.find uf 6);
()
let test_iter () =
let uf = P.create 5 in
let uf = merge_list uf [1;2;3] in
let uf = merge_list uf [5;6] in
let uf = merge_list uf [10;11;12;13;2] in
(* equiv classes *)
let l1 = ref [] in
P.iter_equiv_class uf 1 (fun x -> l1 := x:: !l1);
let l2 = ref [] in
P.iter_equiv_class uf 5 (fun x -> l2 := x:: !l2);
OUnit.assert_equal [1;2;3;10;11;12;13] (List.sort compare !l1);
OUnit.assert_equal [5;6] (List.sort compare !l2);
()
let test_distinct () =
let uf = P.create 5 in
let uf = merge_list uf [1;2;3] in
let uf = merge_list uf [5;6] in
let uf = P.distinct uf 1 5 in
OUnit.assert_equal None (P.inconsistent uf);
let uf' = P.union uf 2 6 (2,6) in
OUnit.assert_bool "inconsistent"
(match P.inconsistent uf' with | None -> false | Some _ -> true);
OUnit.assert_equal None (P.inconsistent uf);
let uf = P.union uf 1 10 (1,10) in
OUnit.assert_equal None (P.inconsistent uf);
()
let test_big () =
let uf = P.create 5 in
let uf = ref uf in
for i = 0 to 100_000 do
uf := P.union !uf 1 i (1,i);
done;
let uf = !uf in
let n = P.fold_equiv_class uf 1 (fun acc _ -> acc+1) 0 in
OUnit.assert_equal ~printer:string_of_int 100_001 n;
()
let test_explain () =
let uf = P.create 5 in
let uf = P.union uf 1 2 (1,2) in
let uf = P.union uf 1 3 (1,3) in
let uf = P.union uf 5 6 (5,6) in
let uf = P.union uf 4 5 (4,5) in
let uf = P.union uf 5 3 (5,3) in
OUnit.assert_bool "eq" (P.find uf 1 = P.find uf 5);
let l = P.explain uf 1 6 in
OUnit.assert_bool "not empty explanation" (l <> []);
(* List.iter (fun (a,b) -> Format.printf "%d, %d@." a b) l; *)
()
(*
let bench () =
let run n =
let uf = P.create 5 in
let uf = ref uf in
for i = 0 to n do
uf := P.union !uf 1 i;
done
in
let res = Bench.bench_args run
[ "100", 100;
"10_000", 10_000;
]
in Bench.summarize 1. res;
()
*)
let suite =
"test_puf" >:::
[ "test_union" >:: test_union;
"test_iter" >:: test_iter;
"test_distinct" >:: test_distinct;
"test_big" >:: test_big;
"test_explain" >:: test_explain;
(* "bench" >:: bench; *)
]