From 8dcb67552ede8304ea2cf0bf5f5824aa966c0942 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Fri, 7 Jun 2019 15:48:20 -0500 Subject: [PATCH] refactor: rewrite production of explanation in CC - use a mutable bit in nodes for finding common ancestor - use fold-like traversal of explanations --- src/cc/Sidekick_cc.ml | 204 +++++++++++++++++++----------------------- 1 file changed, 90 insertions(+), 114 deletions(-) diff --git a/src/cc/Sidekick_cc.ml b/src/cc/Sidekick_cc.ml index cabc3e6a..2dca5508 100644 --- a/src/cc/Sidekick_cc.ml +++ b/src/cc/Sidekick_cc.ml @@ -32,6 +32,9 @@ module Make(CC_A: ARG) = struct let field_is_pending = Bits.mk_field() (** true iff the node is in the [cc.pending] queue *) + let field_marked_explain = Bits.mk_field() + (** used to mark traversed nodes when looking for a common ancestor *) + (** A node of the congruence closure. An equivalence class is represented by its "root" element, the representative. *) @@ -214,9 +217,6 @@ module Make(CC_A: ARG) = struct mutable on_new_term: ev_on_new_term list; mutable on_conflict: ev_on_conflict list; mutable on_propagate: ev_on_propagate list; - mutable ps_lits: lit list; (* TODO: thread it around instead? *) - (* proof state *) - ps_queue: (node*node) Vec.t; (* pairs to explain *) true_ : node lazy_t; false_ : node lazy_t; @@ -248,21 +248,11 @@ module Make(CC_A: ARG) = struct Invariant: [in_cc t ∧ do_cc t => forall u subterm t, in_cc u] *) let[@inline] mem (cc:t) (t:term): bool = T_tbl.mem cc.tbl t - (* find representative, recursively *) - let[@unroll 2] rec find_rec (n:node) : repr = - if n==n.n_root then ( - n - ) else ( - let root = find_rec n.n_root in - if root != n.n_root then ( - n.n_root <- root; (* path compression *) - ); - root - ) - (* non-recursive, inlinable function for [find] *) let[@inline] find_ (n:node) : repr = - if n == n.n_root then n else find_rec n.n_root + let n2 = n.n_root in + assert (n2.n_root == n2); + n2 let[@inline] same_class (n1:node)(n2:node): bool = N.equal (find_ n1) (find_ n2) @@ -330,11 +320,11 @@ module Make(CC_A: ARG) = struct (* re-root the explanation tree of the equivalence class of [n] so that it points to [n]. postcondition: [n.n_expl = None] *) - let rec reroot_expl (cc:t) (n:node): unit = - let old_expl = n.n_expl in - begin match old_expl with + let[@unroll 2] rec reroot_expl (cc:t) (n:node): unit = + begin match n.n_expl with | FL_none -> () (* already root *) | FL_some {next=u; expl=e_n_u} -> + (* reroot to [u], then invert link between [u] and [n] *) reroot_expl cc u; u.n_expl <- FL_some {next=n; expl=e_n_u}; n.n_expl <- FL_none; @@ -353,118 +343,103 @@ module Make(CC_A: ARG) = struct T_tbl.values cc.tbl |> Iter.filter N.is_root - (* TODO: use markers and lockstep iteration instead *) - (* distance from [t] to its root in the proof forest *) - let[@inline][@unroll 2] rec distance_to_root (n:node): int = match n.n_expl with - | FL_none -> 0 - | FL_some {next=t'; _} -> 1 + distance_to_root t' - - (* TODO: new bool flag on nodes + stepwise progress + cleanup *) (* find the closest common ancestor of [a] and [b] in the proof forest *) let find_common_ancestor (a:node) (b:node) : node = - let d_a = distance_to_root a in - let d_b = distance_to_root b in - (* drop [n] nodes in the path from [t] to its root *) - let rec drop_ n t = - if n=0 then t - else match t.n_expl with + (* catch up to the other node *) + let rec find1 a = + if N.get_field field_marked_explain a then a + else ( + match a.n_expl with | FL_none -> assert false - | FL_some {next=t'; _} -> drop_ (n-1) t' + | FL_some r -> find1 r.next + ) in - (* reduce to the problem where [a] and [b] have the same distance to root *) - let a, b = - if d_a > d_b then drop_ (d_a-d_b) a, b - else if d_a < d_b then a, drop_ (d_b-d_a) b - else a, b + let rec find2 a b = + if N.equal a b then a + else if N.get_field field_marked_explain a then a + else if N.get_field field_marked_explain b then b + else ( + N.set_field field_marked_explain true a; + N.set_field field_marked_explain true b; + match a.n_expl, b.n_expl with + | FL_some r1, FL_some r2 -> find2 r1.next r2.next + | FL_some r, FL_none -> find1 r.next + | FL_none, FL_some r -> find1 r.next + | FL_none, FL_none -> assert false (* no common ancestor *) + ) + in - (* traverse stepwise until a==b *) - let rec aux_same_dist a b = - if a==b then a - else match a.n_expl, b.n_expl with - | FL_none, _ | _, FL_none -> assert false - | FL_some {next=a'; _}, FL_some {next=b'; _} -> aux_same_dist a' b' + (* cleanup tags on nodes traversed in [find_] *) + let rec cleanup_ n = + if N.get_field field_marked_explain n then ( + N.set_field field_marked_explain false n; + match n.n_expl with + | FL_none -> () + | FL_some {next;_} -> cleanup_ next; + ) in - aux_same_dist a b + let n = find2 a b in + cleanup_ a; + cleanup_ b; + n - let[@inline] ps_add_obligation (cc:t) a b = Vec.push cc.ps_queue (a,b) - let[@inline] ps_add_lit ps l = ps.ps_lits <- l :: ps.ps_lits - - let ps_clear (cc:t) = - cc.ps_lits <- []; - Vec.clear cc.ps_queue; - () - - (* decompose explanation [e] of why [n1 = n2] *) - let rec decompose_explain cc (e:explanation) : unit = + (* decompose explanation [e] into a list of literals added to [acc] *) + let rec explain_decompose cc (acc:lit list) (e:explanation) : _ list = Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Expl.pp e); match e with - | E_reduction -> () + | E_reduction -> acc | E_congruence (n1, n2) -> begin match n1.n_sig0, n2.n_sig0 with | Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) -> assert (Fun.equal f1 f2); assert (List.length a1 = List.length a2); - List.iter2 (ps_add_obligation cc) a1 a2; + List.fold_left2 (explain_pair cc) acc a1 a2 | Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) -> assert (List.length a1 = List.length a2); - ps_add_obligation cc f1 f2; - List.iter2 (ps_add_obligation cc) a1 a2; + let acc = explain_pair cc acc f1 f2 in + List.fold_left2 (explain_pair cc) acc a1 a2 | Some (If (a1,b1,c1)), Some (If (a2,b2,c2)) -> - ps_add_obligation cc a1 a2; - ps_add_obligation cc b1 b2; - ps_add_obligation cc c1 c2; + let acc = explain_pair cc acc a1 a2 in + let acc = explain_pair cc acc b1 b2 in + explain_pair cc acc c1 c2 | _ -> assert false end - | E_lit lit -> ps_add_lit cc lit - | E_merge (a,b) -> ps_add_obligation cc a b + | E_lit lit -> lit :: acc + | E_merge (a,b) -> explain_pair cc acc a b | E_merge_t (a,b) -> (* find nodes for [a] and [b] on the fly *) begin match T_tbl.find cc.tbl a, T_tbl.find cc.tbl b with - | a, b -> ps_add_obligation cc a b + | a, b -> explain_pair cc acc a b | exception Not_found -> Error.errorf "expl: cannot find node(s) for %a, %a" T.pp a T.pp b end - | E_and (a,b) -> decompose_explain cc a; decompose_explain cc b + | E_and (a,b) -> + let acc = explain_decompose cc acc a in + explain_decompose cc acc b - (* explain why [a = parent_a], where [a -> ... -> parent_a] in the + and explain_pair (cc:t) (acc:lit list) (a:node) (b:node) : _ list = + Log.debugf 5 + (fun k->k "(@[cc.explain_loop.at@ %a@ =?= %a@])" N.pp a N.pp b); + assert (N.equal (find_ a) (find_ b)); + let ancestor = find_common_ancestor a b in + let acc = explain_along_path cc acc a ancestor in + explain_along_path cc acc b ancestor + + (* explain why [a = parent_a], where [a -> ... -> target] in the proof forest *) - let explain_along_path ps (a:node) (parent_a:node) : unit = - let rec aux n = - if n != parent_a then ( + and explain_along_path cc acc (a:node) (target:node) : _ list = + let rec aux acc n = + if n == target then acc + else ( match n.n_expl with | FL_none -> assert false | FL_some {next=next_n; expl=expl} -> - decompose_explain ps expl; - (* now prove [next_n = parent_a] *) - aux next_n + let acc = explain_decompose cc acc expl in + (* now prove [next_n = target] *) + aux acc next_n ) - in aux a - - (* find explanation *) - let explain_loop (cc : t) : lit list = - while not (Vec.is_empty cc.ps_queue) do - let a, b = Vec.pop cc.ps_queue in - Log.debugf 5 - (fun k->k "(@[cc.explain_loop.at@ %a@ =?= %a@])" N.pp a N.pp b); - assert (N.equal (find_ a) (find_ b)); - let c = find_common_ancestor a b in - explain_along_path cc a c; - explain_along_path cc b c; - done; - cc.ps_lits - - let explain_eq_n ?(init=[]) cc (n1:node) (n2:node) : lit list = - ps_clear cc; - cc.ps_lits <- init; - ps_add_obligation cc n1 n2; - explain_loop cc - - let explain_unfold ?(init=[]) cc (e:explanation) : lit list = - ps_clear cc; - cc.ps_lits <- init; - decompose_explain cc e; - explain_loop cc + in aux acc a (* add a term *) let [@inline] rec add_term_rec_ cc t : node = @@ -606,9 +581,9 @@ module Make(CC_A: ARG) = struct Log.debugf 5 (fun k->k "(@[cc.merge.true_false_conflict@ @[:r1 %a@]@ @[:r2 %a@]@ :e_ab %a@])" N.pp ra N.pp rb Expl.pp e_ab); - let lits = explain_unfold cc e_ab in - let lits = explain_eq_n ~init:lits cc a ra in - let lits = explain_eq_n ~init:lits cc b rb in + let lits = explain_decompose cc [] e_ab in + let lits = explain_pair cc lits a ra in + let lits = explain_pair cc lits b rb in raise_conflict cc acts (List.rev_map Lit.neg lits) ); (* We will merge [r_from] into [r_into]. @@ -647,11 +622,15 @@ module Make(CC_A: ARG) = struct (fun u -> assert (u.n_root == r_from); u.n_root <- r_into); - (* now merge the classes *) + (* capture current state *) let r_into_old_next = r_into.n_next in let r_from_old_next = r_from.n_next in let r_into_old_parents = r_into.n_parents in + (* swap [into.next] and [from.next], merging the classes *) + r_into.n_next <- r_from_old_next; + r_from.n_next <- r_into_old_next; r_into.n_parents <- Bag.append r_into.n_parents r_from.n_parents; + r_into.n_size <- r_into.n_size + r_from.n_size; (* on backtrack, unmerge classes and restore the pointers to [r_from] *) on_backtrack cc (fun () -> @@ -661,11 +640,11 @@ module Make(CC_A: ARG) = struct r_into.n_next <- r_into_old_next; r_from.n_next <- r_from_old_next; r_into.n_parents <- r_into_old_parents; + (* NOTE: this must come after the restoration of [next] pointers, + otherwise we'd iterate on too big a class *) N.iter_class_ r_from (fun u -> u.n_root <- r_from); + r_into.n_size <- r_into.n_size - r_from.n_size; ); - (* swap [into.next] and [from.next], merging the classes *) - r_into.n_next <- r_from_old_next; - r_from.n_next <- r_into_old_next; end; (* update explanations (a -> b), arbitrarily. Note that here we merge the classes by adding a bridge between [a] @@ -691,8 +670,8 @@ module Make(CC_A: ARG) = struct and propagate_bools cc acts r1 t1 r2 t2 (e_12:explanation) sign : unit = (* explanation for [t1 =e= t2 = r2] *) let half_expl = lazy ( - let expl = explain_unfold cc e_12 in - explain_eq_n ~init:expl cc r2 t2 + let lits = explain_decompose cc [] e_12 in + explain_pair cc lits r2 t2 ) in (* TODO: flag per class, `or`-ed on merge, to indicate if the class contains at least one lit *) @@ -709,7 +688,7 @@ module Make(CC_A: ARG) = struct Log.debugf 5 (fun k->k "(@[cc.bool_propagate@ %a@])" Lit.pp lit); (* complete explanation with the [u1=t1] chunk *) let reason = - let e = lazy (explain_eq_n ~init:(Lazy.force half_expl) cc u1 t1) in + let e = lazy (explain_pair cc (Lazy.force half_expl) u1 t1) in fun () -> Lazy.force e in List.iter (fun f -> f cc lit reason) cc.on_propagate; @@ -807,9 +786,8 @@ module Make(CC_A: ARG) = struct let raise_conflict_from_expl cc (acts:actions) expl = Log.debugf 5 (fun k->k "(@[cc.theory.raise-conflict@ :expl %a@])" Expl.pp expl); - ps_clear cc; - decompose_explain cc expl; - let lits = List.rev_map Lit.neg cc.ps_lits in + let lits = explain_decompose cc [] expl in + let lits = List.rev_map Lit.neg lits in raise_conflict cc acts lits let merge cc n1 n2 expl = @@ -837,9 +815,7 @@ module Make(CC_A: ARG) = struct on_propagate; pending=Vec.create(); combine=Vec.create(); - ps_lits=[]; undo=Backtrack_stack.create(); - ps_queue=Vec.create(); true_; false_; stat;