diff --git a/src/smt/Congruence_closure.ml b/src/smt/Congruence_closure.ml index 9f8fd49d..d4acd0ff 100644 --- a/src/smt/Congruence_closure.ml +++ b/src/smt/Congruence_closure.ml @@ -40,8 +40,8 @@ type t = { acts: actions; tbl: node Term.Tbl.t; (* internalization [term -> node] *) - signatures_tbl : repr Sig_tbl.t; - (* map a signature to the corresponding term in some equivalence class. + signatures_tbl : node Sig_tbl.t; + (* map a signature to the corresponding node in some equivalence class. A signature is a [term_cell] in which every immediate subterm that participates in the congruence/evaluation relation is normalized (i.e. is its own representative). @@ -80,6 +80,8 @@ let[@inline] size_ (r:repr) = let[@inline] mem (cc:t) (t:term): bool = Term.Tbl.mem cc.tbl t +(* TODO: remove path compression, point to new root explicitely during `union` *) + (* find representative, recursively, and perform path compression *) let rec find_rec cc (n:node) : repr = if n==n.n_root then ( @@ -135,21 +137,16 @@ let find_by_signature cc (t:term) : repr option = match signature cc t with | None -> None | Some s -> Sig_tbl.get cc.signatures_tbl s -let remove_signature cc (t:term): unit = match signature cc t with - | None -> () - | Some s -> - Sig_tbl.remove cc.signatures_tbl s - -let add_signature cc (t:term) (r:repr): unit = match signature cc t with +let add_signature cc (t:term) (r:node): unit = match signature cc t with | None -> () | Some s -> (* add, but only if not present already *) - begin match Sig_tbl.get cc.signatures_tbl s with - | None -> + begin match Sig_tbl.find cc.signatures_tbl s with + | exception Not_found -> on_backtrack cc (fun () -> Sig_tbl.remove cc.signatures_tbl s); Sig_tbl.add cc.signatures_tbl s r; - | Some r' -> - assert (Equiv_class.equal r r'); + | r' -> + assert (same_class cc r r'); end let is_done (cc:t): bool = @@ -246,18 +243,7 @@ let rec decompose_explain cc (e:explanation): unit = | E_lits l -> List.iter (ps_add_lit cc) l | E_merges l -> (* need to explain each merge in [l] *) - List.iter (fun (t,u) -> ps_add_obligation cc t u) l - | E_congruence (t1,t2) -> - (* [t1] and [t2] must be applications of the same symbol to - arguments that are pairwise equal *) - begin match t1.n_term.term_view, t2.n_term.term_view with - | App_cst (f1, a1), App_cst (f2, a2) -> - assert (Cst.equal f1 f2); - assert (IArray.length a1 = IArray.length a2); - IArray.iter2 (ps_add_obligation_t cc) a1 a2 - | If _, _ | App_cst _, _ | Bool _, _ - -> assert false - end + IArray.iter (fun (t,u) -> ps_add_obligation cc t u) l end (* explain why [a = parent_a], where [a -> ... -> parent_a] in the @@ -339,12 +325,24 @@ let rec update_pending (cc:t): unit = (* check if some parent collided *) begin match find_by_signature cc n.n_term with | None -> - (* add to the signature table [n --> n.root] *) - add_signature cc n.n_term (find cc n) + (* add to the signature table [sig(n) --> n] *) + add_signature cc n.n_term n | Some u -> (* must combine [t] with [r] *) - if not @@ Equiv_class.equal n u then ( - push_combine cc n u (Explanation.mk_congruence n u) + if not @@ same_class cc n u then ( + (* [t1] and [t2] must be applications of the same symbol to + arguments that are pairwise equal *) + assert (n != u); + let expl = match n.n_term.term_view, u.n_term.term_view with + | App_cst (f1, a1), App_cst (f2, a2) -> + assert (Cst.equal f1 f2); + assert (IArray.length a1 = IArray.length a2); + Explanation.mk_merges @@ + IArray.map2 (fun u1 u2 -> add cc u1, add cc u2) a1 a2 + | If _, _ | App_cst _, _ | Bool _, _ + -> assert false + in + push_combine cc n u expl ) end; (* FIXME: when to actually evaluate? @@ -390,12 +388,7 @@ and update_combine cc = begin Bag.to_seq (r_from:>node).n_parents |> Sequence.iter - (fun parent -> - (* FIXME: with OCaml's hashtable, we should be able - to keep this entry (and have it become relevant later - once the signature of [parent] is backtracked) *) - remove_signature cc parent.n_term; - push_pending cc parent) + (fun parent -> push_pending cc parent) end; (* perform [union ra rb] *) begin diff --git a/src/smt/Explanation.ml b/src/smt/Explanation.ml index 19b866ff..2248daed 100644 --- a/src/smt/Explanation.ml +++ b/src/smt/Explanation.ml @@ -3,10 +3,9 @@ open Solver_types type t = explanation = | E_reduction (* by pure reduction, tautologically equal *) - | E_merges of (cc_node * cc_node) list (* caused by these merges *) + | E_merges of (cc_node * cc_node) IArray.t (* caused by these merges *) | E_lit of lit (* because of this literal *) | E_lits of lit list (* because of this (true) conjunction *) - | E_congruence of cc_node * cc_node (* these terms are congruent *) let compare = cmp_exp let equal a b = cmp_exp a b = 0 @@ -17,7 +16,6 @@ let mk_merges l : t = E_merges l let mk_lit l : t = E_lit l let mk_lits = function [x] -> mk_lit x | l -> E_lits l let mk_reduction : t = E_reduction -let mk_congruence t u = E_congruence (t,u) let[@inline] lit l : t = E_lit l diff --git a/src/smt/Solver_types.ml b/src/smt/Solver_types.ml index e722b64a..c7fed204 100644 --- a/src/smt/Solver_types.ml +++ b/src/smt/Solver_types.ml @@ -49,10 +49,9 @@ and explanation_forest_link = (* atomic explanation in the congruence closure *) and explanation = | E_reduction (* by pure reduction, tautologically equal *) - | E_merges of (cc_node * cc_node) list (* caused by these merges *) + | E_merges of (cc_node * cc_node) IArray.t (* caused by these merges *) | E_lit of lit (* because of this literal *) | E_lits of lit list (* because of this (true) conjunction *) - | E_congruence of cc_node * cc_node (* these terms are congruent *) (* boolean literal *) and lit = { @@ -152,18 +151,14 @@ let rec cmp_exp a b = let toint = function | E_merges _ -> 0 | E_lit _ -> 1 | E_reduction -> 2 | E_lits _ -> 3 - | E_congruence _ -> 4 in begin match a, b with - | E_congruence (t1,t2), E_congruence (u1,u2) -> - CCOrd.(cmp_cc_node t1 u1 (cmp_cc_node, t2, u2)) | E_merges l1, E_merges l2 -> - CCList.compare (CCOrd.pair cmp_cc_node cmp_cc_node) l1 l2 + IArray.compare (CCOrd.pair cmp_cc_node cmp_cc_node) l1 l2 | E_reduction, E_reduction -> 0 | E_lit l1, E_lit l2 -> cmp_lit l1 l2 | E_lits l1, E_lits l2 -> CCList.compare cmp_lit l1 l2 - | E_merges _, _ | E_lit _, _ | E_lits _, _ - | E_reduction, _ | E_congruence _, _ + | E_merges _, _ | E_lit _, _ | E_lits _, _ | E_reduction, _ -> CCInt.compare (toint a)(toint b) end @@ -215,10 +210,8 @@ let pp_explanation out (e:explanation) = match e with | E_reduction -> Fmt.string out "reduction" | E_lit lit -> pp_lit out lit | E_lits l -> CCFormat.Dump.list pp_lit out l - | E_congruence (a,b) -> - Format.fprintf out "(@[congruence@ %a@ %a@])" pp_cc_node a pp_cc_node b | E_merges l -> Format.fprintf out "(@[merges@ %a@])" - Fmt.(list ~sep:(return "@ ") @@ within "[" "]" @@ hvbox @@ + Fmt.(seq ~sep:(return "@ ") @@ within "[" "]" @@ hvbox @@ pair ~sep:(return "@ <-> ") pp_cc_node pp_cc_node) - l + (IArray.to_seq l)