refactor: rewrite production of explanation in CC

- use a mutable bit in nodes for finding common ancestor
- use fold-like traversal of explanations
This commit is contained in:
Simon Cruanes 2019-06-07 15:48:20 -05:00
parent 12ea0c3be4
commit 8dcb67552e

View file

@ -32,6 +32,9 @@ module Make(CC_A: ARG) = struct
let field_is_pending = Bits.mk_field()
(** true iff the node is in the [cc.pending] queue *)
let field_marked_explain = Bits.mk_field()
(** used to mark traversed nodes when looking for a common ancestor *)
(** A node of the congruence closure.
An equivalence class is represented by its "root" element,
the representative. *)
@ -214,9 +217,6 @@ module Make(CC_A: ARG) = struct
mutable on_new_term: ev_on_new_term list;
mutable on_conflict: ev_on_conflict list;
mutable on_propagate: ev_on_propagate list;
mutable ps_lits: lit list; (* TODO: thread it around instead? *)
(* proof state *)
ps_queue: (node*node) Vec.t;
(* pairs to explain *)
true_ : node lazy_t;
false_ : node lazy_t;
@ -248,21 +248,11 @@ module Make(CC_A: ARG) = struct
Invariant: [in_cc t do_cc t => forall u subterm t, in_cc u] *)
let[@inline] mem (cc:t) (t:term): bool = T_tbl.mem cc.tbl t
(* find representative, recursively *)
let[@unroll 2] rec find_rec (n:node) : repr =
if n==n.n_root then (
n
) else (
let root = find_rec n.n_root in
if root != n.n_root then (
n.n_root <- root; (* path compression *)
);
root
)
(* non-recursive, inlinable function for [find] *)
let[@inline] find_ (n:node) : repr =
if n == n.n_root then n else find_rec n.n_root
let n2 = n.n_root in
assert (n2.n_root == n2);
n2
let[@inline] same_class (n1:node)(n2:node): bool =
N.equal (find_ n1) (find_ n2)
@ -330,11 +320,11 @@ module Make(CC_A: ARG) = struct
(* re-root the explanation tree of the equivalence class of [n]
so that it points to [n].
postcondition: [n.n_expl = None] *)
let rec reroot_expl (cc:t) (n:node): unit =
let old_expl = n.n_expl in
begin match old_expl with
let[@unroll 2] rec reroot_expl (cc:t) (n:node): unit =
begin match n.n_expl with
| FL_none -> () (* already root *)
| FL_some {next=u; expl=e_n_u} ->
(* reroot to [u], then invert link between [u] and [n] *)
reroot_expl cc u;
u.n_expl <- FL_some {next=n; expl=e_n_u};
n.n_expl <- FL_none;
@ -353,118 +343,103 @@ module Make(CC_A: ARG) = struct
T_tbl.values cc.tbl
|> Iter.filter N.is_root
(* TODO: use markers and lockstep iteration instead *)
(* distance from [t] to its root in the proof forest *)
let[@inline][@unroll 2] rec distance_to_root (n:node): int = match n.n_expl with
| FL_none -> 0
| FL_some {next=t'; _} -> 1 + distance_to_root t'
(* TODO: new bool flag on nodes + stepwise progress + cleanup *)
(* find the closest common ancestor of [a] and [b] in the proof forest *)
let find_common_ancestor (a:node) (b:node) : node =
let d_a = distance_to_root a in
let d_b = distance_to_root b in
(* drop [n] nodes in the path from [t] to its root *)
let rec drop_ n t =
if n=0 then t
else match t.n_expl with
(* catch up to the other node *)
let rec find1 a =
if N.get_field field_marked_explain a then a
else (
match a.n_expl with
| FL_none -> assert false
| FL_some {next=t'; _} -> drop_ (n-1) t'
| FL_some r -> find1 r.next
)
in
(* reduce to the problem where [a] and [b] have the same distance to root *)
let a, b =
if d_a > d_b then drop_ (d_a-d_b) a, b
else if d_a < d_b then a, drop_ (d_b-d_a) b
else a, b
let rec find2 a b =
if N.equal a b then a
else if N.get_field field_marked_explain a then a
else if N.get_field field_marked_explain b then b
else (
N.set_field field_marked_explain true a;
N.set_field field_marked_explain true b;
match a.n_expl, b.n_expl with
| FL_some r1, FL_some r2 -> find2 r1.next r2.next
| FL_some r, FL_none -> find1 r.next
| FL_none, FL_some r -> find1 r.next
| FL_none, FL_none -> assert false (* no common ancestor *)
)
in
(* traverse stepwise until a==b *)
let rec aux_same_dist a b =
if a==b then a
else match a.n_expl, b.n_expl with
| FL_none, _ | _, FL_none -> assert false
| FL_some {next=a'; _}, FL_some {next=b'; _} -> aux_same_dist a' b'
(* cleanup tags on nodes traversed in [find_] *)
let rec cleanup_ n =
if N.get_field field_marked_explain n then (
N.set_field field_marked_explain false n;
match n.n_expl with
| FL_none -> ()
| FL_some {next;_} -> cleanup_ next;
)
in
aux_same_dist a b
let n = find2 a b in
cleanup_ a;
cleanup_ b;
n
let[@inline] ps_add_obligation (cc:t) a b = Vec.push cc.ps_queue (a,b)
let[@inline] ps_add_lit ps l = ps.ps_lits <- l :: ps.ps_lits
let ps_clear (cc:t) =
cc.ps_lits <- [];
Vec.clear cc.ps_queue;
()
(* decompose explanation [e] of why [n1 = n2] *)
let rec decompose_explain cc (e:explanation) : unit =
(* decompose explanation [e] into a list of literals added to [acc] *)
let rec explain_decompose cc (acc:lit list) (e:explanation) : _ list =
Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Expl.pp e);
match e with
| E_reduction -> ()
| E_reduction -> acc
| E_congruence (n1, n2) ->
begin match n1.n_sig0, n2.n_sig0 with
| Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) ->
assert (Fun.equal f1 f2);
assert (List.length a1 = List.length a2);
List.iter2 (ps_add_obligation cc) a1 a2;
List.fold_left2 (explain_pair cc) acc a1 a2
| Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) ->
assert (List.length a1 = List.length a2);
ps_add_obligation cc f1 f2;
List.iter2 (ps_add_obligation cc) a1 a2;
let acc = explain_pair cc acc f1 f2 in
List.fold_left2 (explain_pair cc) acc a1 a2
| Some (If (a1,b1,c1)), Some (If (a2,b2,c2)) ->
ps_add_obligation cc a1 a2;
ps_add_obligation cc b1 b2;
ps_add_obligation cc c1 c2;
let acc = explain_pair cc acc a1 a2 in
let acc = explain_pair cc acc b1 b2 in
explain_pair cc acc c1 c2
| _ ->
assert false
end
| E_lit lit -> ps_add_lit cc lit
| E_merge (a,b) -> ps_add_obligation cc a b
| E_lit lit -> lit :: acc
| E_merge (a,b) -> explain_pair cc acc a b
| E_merge_t (a,b) ->
(* find nodes for [a] and [b] on the fly *)
begin match T_tbl.find cc.tbl a, T_tbl.find cc.tbl b with
| a, b -> ps_add_obligation cc a b
| a, b -> explain_pair cc acc a b
| exception Not_found ->
Error.errorf "expl: cannot find node(s) for %a, %a" T.pp a T.pp b
end
| E_and (a,b) -> decompose_explain cc a; decompose_explain cc b
| E_and (a,b) ->
let acc = explain_decompose cc acc a in
explain_decompose cc acc b
(* explain why [a = parent_a], where [a -> ... -> parent_a] in the
and explain_pair (cc:t) (acc:lit list) (a:node) (b:node) : _ list =
Log.debugf 5
(fun k->k "(@[cc.explain_loop.at@ %a@ =?= %a@])" N.pp a N.pp b);
assert (N.equal (find_ a) (find_ b));
let ancestor = find_common_ancestor a b in
let acc = explain_along_path cc acc a ancestor in
explain_along_path cc acc b ancestor
(* explain why [a = parent_a], where [a -> ... -> target] in the
proof forest *)
let explain_along_path ps (a:node) (parent_a:node) : unit =
let rec aux n =
if n != parent_a then (
and explain_along_path cc acc (a:node) (target:node) : _ list =
let rec aux acc n =
if n == target then acc
else (
match n.n_expl with
| FL_none -> assert false
| FL_some {next=next_n; expl=expl} ->
decompose_explain ps expl;
(* now prove [next_n = parent_a] *)
aux next_n
let acc = explain_decompose cc acc expl in
(* now prove [next_n = target] *)
aux acc next_n
)
in aux a
(* find explanation *)
let explain_loop (cc : t) : lit list =
while not (Vec.is_empty cc.ps_queue) do
let a, b = Vec.pop cc.ps_queue in
Log.debugf 5
(fun k->k "(@[cc.explain_loop.at@ %a@ =?= %a@])" N.pp a N.pp b);
assert (N.equal (find_ a) (find_ b));
let c = find_common_ancestor a b in
explain_along_path cc a c;
explain_along_path cc b c;
done;
cc.ps_lits
let explain_eq_n ?(init=[]) cc (n1:node) (n2:node) : lit list =
ps_clear cc;
cc.ps_lits <- init;
ps_add_obligation cc n1 n2;
explain_loop cc
let explain_unfold ?(init=[]) cc (e:explanation) : lit list =
ps_clear cc;
cc.ps_lits <- init;
decompose_explain cc e;
explain_loop cc
in aux acc a
(* add a term *)
let [@inline] rec add_term_rec_ cc t : node =
@ -606,9 +581,9 @@ module Make(CC_A: ARG) = struct
Log.debugf 5
(fun k->k "(@[<hv>cc.merge.true_false_conflict@ @[:r1 %a@]@ @[:r2 %a@]@ :e_ab %a@])"
N.pp ra N.pp rb Expl.pp e_ab);
let lits = explain_unfold cc e_ab in
let lits = explain_eq_n ~init:lits cc a ra in
let lits = explain_eq_n ~init:lits cc b rb in
let lits = explain_decompose cc [] e_ab in
let lits = explain_pair cc lits a ra in
let lits = explain_pair cc lits b rb in
raise_conflict cc acts (List.rev_map Lit.neg lits)
);
(* We will merge [r_from] into [r_into].
@ -647,11 +622,15 @@ module Make(CC_A: ARG) = struct
(fun u ->
assert (u.n_root == r_from);
u.n_root <- r_into);
(* now merge the classes *)
(* capture current state *)
let r_into_old_next = r_into.n_next in
let r_from_old_next = r_from.n_next in
let r_into_old_parents = r_into.n_parents in
(* swap [into.next] and [from.next], merging the classes *)
r_into.n_next <- r_from_old_next;
r_from.n_next <- r_into_old_next;
r_into.n_parents <- Bag.append r_into.n_parents r_from.n_parents;
r_into.n_size <- r_into.n_size + r_from.n_size;
(* on backtrack, unmerge classes and restore the pointers to [r_from] *)
on_backtrack cc
(fun () ->
@ -661,11 +640,11 @@ module Make(CC_A: ARG) = struct
r_into.n_next <- r_into_old_next;
r_from.n_next <- r_from_old_next;
r_into.n_parents <- r_into_old_parents;
(* NOTE: this must come after the restoration of [next] pointers,
otherwise we'd iterate on too big a class *)
N.iter_class_ r_from (fun u -> u.n_root <- r_from);
r_into.n_size <- r_into.n_size - r_from.n_size;
);
(* swap [into.next] and [from.next], merging the classes *)
r_into.n_next <- r_from_old_next;
r_from.n_next <- r_into_old_next;
end;
(* update explanations (a -> b), arbitrarily.
Note that here we merge the classes by adding a bridge between [a]
@ -691,8 +670,8 @@ module Make(CC_A: ARG) = struct
and propagate_bools cc acts r1 t1 r2 t2 (e_12:explanation) sign : unit =
(* explanation for [t1 =e= t2 = r2] *)
let half_expl = lazy (
let expl = explain_unfold cc e_12 in
explain_eq_n ~init:expl cc r2 t2
let lits = explain_decompose cc [] e_12 in
explain_pair cc lits r2 t2
) in
(* TODO: flag per class, `or`-ed on merge, to indicate if the class
contains at least one lit *)
@ -709,7 +688,7 @@ module Make(CC_A: ARG) = struct
Log.debugf 5 (fun k->k "(@[cc.bool_propagate@ %a@])" Lit.pp lit);
(* complete explanation with the [u1=t1] chunk *)
let reason =
let e = lazy (explain_eq_n ~init:(Lazy.force half_expl) cc u1 t1) in
let e = lazy (explain_pair cc (Lazy.force half_expl) u1 t1) in
fun () -> Lazy.force e
in
List.iter (fun f -> f cc lit reason) cc.on_propagate;
@ -807,9 +786,8 @@ module Make(CC_A: ARG) = struct
let raise_conflict_from_expl cc (acts:actions) expl =
Log.debugf 5
(fun k->k "(@[cc.theory.raise-conflict@ :expl %a@])" Expl.pp expl);
ps_clear cc;
decompose_explain cc expl;
let lits = List.rev_map Lit.neg cc.ps_lits in
let lits = explain_decompose cc [] expl in
let lits = List.rev_map Lit.neg lits in
raise_conflict cc acts lits
let merge cc n1 n2 expl =
@ -837,9 +815,7 @@ module Make(CC_A: ARG) = struct
on_propagate;
pending=Vec.create();
combine=Vec.create();
ps_lits=[];
undo=Backtrack_stack.create();
ps_queue=Vec.create();
true_;
false_;
stat;