mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-09 12:45:48 -05:00
refactor: rewrite production of explanation in CC
- use a mutable bit in nodes for finding common ancestor - use fold-like traversal of explanations
This commit is contained in:
parent
12ea0c3be4
commit
8dcb67552e
1 changed files with 90 additions and 114 deletions
|
|
@ -32,6 +32,9 @@ module Make(CC_A: ARG) = struct
|
|||
let field_is_pending = Bits.mk_field()
|
||||
(** true iff the node is in the [cc.pending] queue *)
|
||||
|
||||
let field_marked_explain = Bits.mk_field()
|
||||
(** used to mark traversed nodes when looking for a common ancestor *)
|
||||
|
||||
(** A node of the congruence closure.
|
||||
An equivalence class is represented by its "root" element,
|
||||
the representative. *)
|
||||
|
|
@ -214,9 +217,6 @@ module Make(CC_A: ARG) = struct
|
|||
mutable on_new_term: ev_on_new_term list;
|
||||
mutable on_conflict: ev_on_conflict list;
|
||||
mutable on_propagate: ev_on_propagate list;
|
||||
mutable ps_lits: lit list; (* TODO: thread it around instead? *)
|
||||
(* proof state *)
|
||||
ps_queue: (node*node) Vec.t;
|
||||
(* pairs to explain *)
|
||||
true_ : node lazy_t;
|
||||
false_ : node lazy_t;
|
||||
|
|
@ -248,21 +248,11 @@ module Make(CC_A: ARG) = struct
|
|||
Invariant: [in_cc t ∧ do_cc t => forall u subterm t, in_cc u] *)
|
||||
let[@inline] mem (cc:t) (t:term): bool = T_tbl.mem cc.tbl t
|
||||
|
||||
(* find representative, recursively *)
|
||||
let[@unroll 2] rec find_rec (n:node) : repr =
|
||||
if n==n.n_root then (
|
||||
n
|
||||
) else (
|
||||
let root = find_rec n.n_root in
|
||||
if root != n.n_root then (
|
||||
n.n_root <- root; (* path compression *)
|
||||
);
|
||||
root
|
||||
)
|
||||
|
||||
(* non-recursive, inlinable function for [find] *)
|
||||
let[@inline] find_ (n:node) : repr =
|
||||
if n == n.n_root then n else find_rec n.n_root
|
||||
let n2 = n.n_root in
|
||||
assert (n2.n_root == n2);
|
||||
n2
|
||||
|
||||
let[@inline] same_class (n1:node)(n2:node): bool =
|
||||
N.equal (find_ n1) (find_ n2)
|
||||
|
|
@ -330,11 +320,11 @@ module Make(CC_A: ARG) = struct
|
|||
(* re-root the explanation tree of the equivalence class of [n]
|
||||
so that it points to [n].
|
||||
postcondition: [n.n_expl = None] *)
|
||||
let rec reroot_expl (cc:t) (n:node): unit =
|
||||
let old_expl = n.n_expl in
|
||||
begin match old_expl with
|
||||
let[@unroll 2] rec reroot_expl (cc:t) (n:node): unit =
|
||||
begin match n.n_expl with
|
||||
| FL_none -> () (* already root *)
|
||||
| FL_some {next=u; expl=e_n_u} ->
|
||||
(* reroot to [u], then invert link between [u] and [n] *)
|
||||
reroot_expl cc u;
|
||||
u.n_expl <- FL_some {next=n; expl=e_n_u};
|
||||
n.n_expl <- FL_none;
|
||||
|
|
@ -353,118 +343,103 @@ module Make(CC_A: ARG) = struct
|
|||
T_tbl.values cc.tbl
|
||||
|> Iter.filter N.is_root
|
||||
|
||||
(* TODO: use markers and lockstep iteration instead *)
|
||||
(* distance from [t] to its root in the proof forest *)
|
||||
let[@inline][@unroll 2] rec distance_to_root (n:node): int = match n.n_expl with
|
||||
| FL_none -> 0
|
||||
| FL_some {next=t'; _} -> 1 + distance_to_root t'
|
||||
|
||||
(* TODO: new bool flag on nodes + stepwise progress + cleanup *)
|
||||
(* find the closest common ancestor of [a] and [b] in the proof forest *)
|
||||
let find_common_ancestor (a:node) (b:node) : node =
|
||||
let d_a = distance_to_root a in
|
||||
let d_b = distance_to_root b in
|
||||
(* drop [n] nodes in the path from [t] to its root *)
|
||||
let rec drop_ n t =
|
||||
if n=0 then t
|
||||
else match t.n_expl with
|
||||
(* catch up to the other node *)
|
||||
let rec find1 a =
|
||||
if N.get_field field_marked_explain a then a
|
||||
else (
|
||||
match a.n_expl with
|
||||
| FL_none -> assert false
|
||||
| FL_some {next=t'; _} -> drop_ (n-1) t'
|
||||
| FL_some r -> find1 r.next
|
||||
)
|
||||
in
|
||||
(* reduce to the problem where [a] and [b] have the same distance to root *)
|
||||
let a, b =
|
||||
if d_a > d_b then drop_ (d_a-d_b) a, b
|
||||
else if d_a < d_b then a, drop_ (d_b-d_a) b
|
||||
else a, b
|
||||
let rec find2 a b =
|
||||
if N.equal a b then a
|
||||
else if N.get_field field_marked_explain a then a
|
||||
else if N.get_field field_marked_explain b then b
|
||||
else (
|
||||
N.set_field field_marked_explain true a;
|
||||
N.set_field field_marked_explain true b;
|
||||
match a.n_expl, b.n_expl with
|
||||
| FL_some r1, FL_some r2 -> find2 r1.next r2.next
|
||||
| FL_some r, FL_none -> find1 r.next
|
||||
| FL_none, FL_some r -> find1 r.next
|
||||
| FL_none, FL_none -> assert false (* no common ancestor *)
|
||||
)
|
||||
|
||||
in
|
||||
(* traverse stepwise until a==b *)
|
||||
let rec aux_same_dist a b =
|
||||
if a==b then a
|
||||
else match a.n_expl, b.n_expl with
|
||||
| FL_none, _ | _, FL_none -> assert false
|
||||
| FL_some {next=a'; _}, FL_some {next=b'; _} -> aux_same_dist a' b'
|
||||
(* cleanup tags on nodes traversed in [find_] *)
|
||||
let rec cleanup_ n =
|
||||
if N.get_field field_marked_explain n then (
|
||||
N.set_field field_marked_explain false n;
|
||||
match n.n_expl with
|
||||
| FL_none -> ()
|
||||
| FL_some {next;_} -> cleanup_ next;
|
||||
)
|
||||
in
|
||||
aux_same_dist a b
|
||||
let n = find2 a b in
|
||||
cleanup_ a;
|
||||
cleanup_ b;
|
||||
n
|
||||
|
||||
let[@inline] ps_add_obligation (cc:t) a b = Vec.push cc.ps_queue (a,b)
|
||||
let[@inline] ps_add_lit ps l = ps.ps_lits <- l :: ps.ps_lits
|
||||
|
||||
let ps_clear (cc:t) =
|
||||
cc.ps_lits <- [];
|
||||
Vec.clear cc.ps_queue;
|
||||
()
|
||||
|
||||
(* decompose explanation [e] of why [n1 = n2] *)
|
||||
let rec decompose_explain cc (e:explanation) : unit =
|
||||
(* decompose explanation [e] into a list of literals added to [acc] *)
|
||||
let rec explain_decompose cc (acc:lit list) (e:explanation) : _ list =
|
||||
Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Expl.pp e);
|
||||
match e with
|
||||
| E_reduction -> ()
|
||||
| E_reduction -> acc
|
||||
| E_congruence (n1, n2) ->
|
||||
begin match n1.n_sig0, n2.n_sig0 with
|
||||
| Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) ->
|
||||
assert (Fun.equal f1 f2);
|
||||
assert (List.length a1 = List.length a2);
|
||||
List.iter2 (ps_add_obligation cc) a1 a2;
|
||||
List.fold_left2 (explain_pair cc) acc a1 a2
|
||||
| Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) ->
|
||||
assert (List.length a1 = List.length a2);
|
||||
ps_add_obligation cc f1 f2;
|
||||
List.iter2 (ps_add_obligation cc) a1 a2;
|
||||
let acc = explain_pair cc acc f1 f2 in
|
||||
List.fold_left2 (explain_pair cc) acc a1 a2
|
||||
| Some (If (a1,b1,c1)), Some (If (a2,b2,c2)) ->
|
||||
ps_add_obligation cc a1 a2;
|
||||
ps_add_obligation cc b1 b2;
|
||||
ps_add_obligation cc c1 c2;
|
||||
let acc = explain_pair cc acc a1 a2 in
|
||||
let acc = explain_pair cc acc b1 b2 in
|
||||
explain_pair cc acc c1 c2
|
||||
| _ ->
|
||||
assert false
|
||||
end
|
||||
| E_lit lit -> ps_add_lit cc lit
|
||||
| E_merge (a,b) -> ps_add_obligation cc a b
|
||||
| E_lit lit -> lit :: acc
|
||||
| E_merge (a,b) -> explain_pair cc acc a b
|
||||
| E_merge_t (a,b) ->
|
||||
(* find nodes for [a] and [b] on the fly *)
|
||||
begin match T_tbl.find cc.tbl a, T_tbl.find cc.tbl b with
|
||||
| a, b -> ps_add_obligation cc a b
|
||||
| a, b -> explain_pair cc acc a b
|
||||
| exception Not_found ->
|
||||
Error.errorf "expl: cannot find node(s) for %a, %a" T.pp a T.pp b
|
||||
end
|
||||
| E_and (a,b) -> decompose_explain cc a; decompose_explain cc b
|
||||
| E_and (a,b) ->
|
||||
let acc = explain_decompose cc acc a in
|
||||
explain_decompose cc acc b
|
||||
|
||||
(* explain why [a = parent_a], where [a -> ... -> parent_a] in the
|
||||
and explain_pair (cc:t) (acc:lit list) (a:node) (b:node) : _ list =
|
||||
Log.debugf 5
|
||||
(fun k->k "(@[cc.explain_loop.at@ %a@ =?= %a@])" N.pp a N.pp b);
|
||||
assert (N.equal (find_ a) (find_ b));
|
||||
let ancestor = find_common_ancestor a b in
|
||||
let acc = explain_along_path cc acc a ancestor in
|
||||
explain_along_path cc acc b ancestor
|
||||
|
||||
(* explain why [a = parent_a], where [a -> ... -> target] in the
|
||||
proof forest *)
|
||||
let explain_along_path ps (a:node) (parent_a:node) : unit =
|
||||
let rec aux n =
|
||||
if n != parent_a then (
|
||||
and explain_along_path cc acc (a:node) (target:node) : _ list =
|
||||
let rec aux acc n =
|
||||
if n == target then acc
|
||||
else (
|
||||
match n.n_expl with
|
||||
| FL_none -> assert false
|
||||
| FL_some {next=next_n; expl=expl} ->
|
||||
decompose_explain ps expl;
|
||||
(* now prove [next_n = parent_a] *)
|
||||
aux next_n
|
||||
let acc = explain_decompose cc acc expl in
|
||||
(* now prove [next_n = target] *)
|
||||
aux acc next_n
|
||||
)
|
||||
in aux a
|
||||
|
||||
(* find explanation *)
|
||||
let explain_loop (cc : t) : lit list =
|
||||
while not (Vec.is_empty cc.ps_queue) do
|
||||
let a, b = Vec.pop cc.ps_queue in
|
||||
Log.debugf 5
|
||||
(fun k->k "(@[cc.explain_loop.at@ %a@ =?= %a@])" N.pp a N.pp b);
|
||||
assert (N.equal (find_ a) (find_ b));
|
||||
let c = find_common_ancestor a b in
|
||||
explain_along_path cc a c;
|
||||
explain_along_path cc b c;
|
||||
done;
|
||||
cc.ps_lits
|
||||
|
||||
let explain_eq_n ?(init=[]) cc (n1:node) (n2:node) : lit list =
|
||||
ps_clear cc;
|
||||
cc.ps_lits <- init;
|
||||
ps_add_obligation cc n1 n2;
|
||||
explain_loop cc
|
||||
|
||||
let explain_unfold ?(init=[]) cc (e:explanation) : lit list =
|
||||
ps_clear cc;
|
||||
cc.ps_lits <- init;
|
||||
decompose_explain cc e;
|
||||
explain_loop cc
|
||||
in aux acc a
|
||||
|
||||
(* add a term *)
|
||||
let [@inline] rec add_term_rec_ cc t : node =
|
||||
|
|
@ -606,9 +581,9 @@ module Make(CC_A: ARG) = struct
|
|||
Log.debugf 5
|
||||
(fun k->k "(@[<hv>cc.merge.true_false_conflict@ @[:r1 %a@]@ @[:r2 %a@]@ :e_ab %a@])"
|
||||
N.pp ra N.pp rb Expl.pp e_ab);
|
||||
let lits = explain_unfold cc e_ab in
|
||||
let lits = explain_eq_n ~init:lits cc a ra in
|
||||
let lits = explain_eq_n ~init:lits cc b rb in
|
||||
let lits = explain_decompose cc [] e_ab in
|
||||
let lits = explain_pair cc lits a ra in
|
||||
let lits = explain_pair cc lits b rb in
|
||||
raise_conflict cc acts (List.rev_map Lit.neg lits)
|
||||
);
|
||||
(* We will merge [r_from] into [r_into].
|
||||
|
|
@ -647,11 +622,15 @@ module Make(CC_A: ARG) = struct
|
|||
(fun u ->
|
||||
assert (u.n_root == r_from);
|
||||
u.n_root <- r_into);
|
||||
(* now merge the classes *)
|
||||
(* capture current state *)
|
||||
let r_into_old_next = r_into.n_next in
|
||||
let r_from_old_next = r_from.n_next in
|
||||
let r_into_old_parents = r_into.n_parents in
|
||||
(* swap [into.next] and [from.next], merging the classes *)
|
||||
r_into.n_next <- r_from_old_next;
|
||||
r_from.n_next <- r_into_old_next;
|
||||
r_into.n_parents <- Bag.append r_into.n_parents r_from.n_parents;
|
||||
r_into.n_size <- r_into.n_size + r_from.n_size;
|
||||
(* on backtrack, unmerge classes and restore the pointers to [r_from] *)
|
||||
on_backtrack cc
|
||||
(fun () ->
|
||||
|
|
@ -661,11 +640,11 @@ module Make(CC_A: ARG) = struct
|
|||
r_into.n_next <- r_into_old_next;
|
||||
r_from.n_next <- r_from_old_next;
|
||||
r_into.n_parents <- r_into_old_parents;
|
||||
(* NOTE: this must come after the restoration of [next] pointers,
|
||||
otherwise we'd iterate on too big a class *)
|
||||
N.iter_class_ r_from (fun u -> u.n_root <- r_from);
|
||||
r_into.n_size <- r_into.n_size - r_from.n_size;
|
||||
);
|
||||
(* swap [into.next] and [from.next], merging the classes *)
|
||||
r_into.n_next <- r_from_old_next;
|
||||
r_from.n_next <- r_into_old_next;
|
||||
end;
|
||||
(* update explanations (a -> b), arbitrarily.
|
||||
Note that here we merge the classes by adding a bridge between [a]
|
||||
|
|
@ -691,8 +670,8 @@ module Make(CC_A: ARG) = struct
|
|||
and propagate_bools cc acts r1 t1 r2 t2 (e_12:explanation) sign : unit =
|
||||
(* explanation for [t1 =e= t2 = r2] *)
|
||||
let half_expl = lazy (
|
||||
let expl = explain_unfold cc e_12 in
|
||||
explain_eq_n ~init:expl cc r2 t2
|
||||
let lits = explain_decompose cc [] e_12 in
|
||||
explain_pair cc lits r2 t2
|
||||
) in
|
||||
(* TODO: flag per class, `or`-ed on merge, to indicate if the class
|
||||
contains at least one lit *)
|
||||
|
|
@ -709,7 +688,7 @@ module Make(CC_A: ARG) = struct
|
|||
Log.debugf 5 (fun k->k "(@[cc.bool_propagate@ %a@])" Lit.pp lit);
|
||||
(* complete explanation with the [u1=t1] chunk *)
|
||||
let reason =
|
||||
let e = lazy (explain_eq_n ~init:(Lazy.force half_expl) cc u1 t1) in
|
||||
let e = lazy (explain_pair cc (Lazy.force half_expl) u1 t1) in
|
||||
fun () -> Lazy.force e
|
||||
in
|
||||
List.iter (fun f -> f cc lit reason) cc.on_propagate;
|
||||
|
|
@ -807,9 +786,8 @@ module Make(CC_A: ARG) = struct
|
|||
let raise_conflict_from_expl cc (acts:actions) expl =
|
||||
Log.debugf 5
|
||||
(fun k->k "(@[cc.theory.raise-conflict@ :expl %a@])" Expl.pp expl);
|
||||
ps_clear cc;
|
||||
decompose_explain cc expl;
|
||||
let lits = List.rev_map Lit.neg cc.ps_lits in
|
||||
let lits = explain_decompose cc [] expl in
|
||||
let lits = List.rev_map Lit.neg lits in
|
||||
raise_conflict cc acts lits
|
||||
|
||||
let merge cc n1 n2 expl =
|
||||
|
|
@ -837,9 +815,7 @@ module Make(CC_A: ARG) = struct
|
|||
on_propagate;
|
||||
pending=Vec.create();
|
||||
combine=Vec.create();
|
||||
ps_lits=[];
|
||||
undo=Backtrack_stack.create();
|
||||
ps_queue=Vec.create();
|
||||
true_;
|
||||
false_;
|
||||
stat;
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue