mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-06 11:15:43 -05:00
fix(cc): merge parents properly
This commit is contained in:
parent
1328d043e3
commit
bf0171fec1
2 changed files with 53 additions and 30 deletions
|
|
@ -93,6 +93,25 @@ module Make(A: ARG) = struct
|
|||
} in
|
||||
n
|
||||
|
||||
let[@inline] is_root (n:node) : bool = n.n_root == n
|
||||
|
||||
(* traverse the equivalence class of [n] *)
|
||||
let iter_class_ (n:node) : node Sequence.t =
|
||||
fun yield ->
|
||||
let rec aux u =
|
||||
yield u;
|
||||
if u.n_next != n then aux u.n_next
|
||||
in
|
||||
aux n
|
||||
|
||||
let iter_class n =
|
||||
assert (is_root n);
|
||||
iter_class_ n
|
||||
|
||||
let[@inline] iter_parents (n:node) : node Sequence.t =
|
||||
assert (is_root n);
|
||||
Bag.to_seq n.n_parents
|
||||
|
||||
type nonrec payload = payload = ..
|
||||
|
||||
let set_payload ?(can_erase=fun _->false) n e =
|
||||
|
|
@ -245,7 +264,6 @@ module Make(A: ARG) = struct
|
|||
several times.
|
||||
See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *)
|
||||
|
||||
let[@inline] is_root_ (n:node) : bool = n.n_root == n
|
||||
let[@inline] size_ (r:repr) = r.n_size
|
||||
let[@inline] true_ cc = Lazy.force cc.true_
|
||||
let[@inline] false_ cc = Lazy.force cc.false_
|
||||
|
|
@ -268,15 +286,6 @@ module Make(A: ARG) = struct
|
|||
root
|
||||
)
|
||||
|
||||
(* traverse the equivalence class of [n] *)
|
||||
let iter_class_ (n:node) : node Sequence.t =
|
||||
fun yield ->
|
||||
let rec aux u =
|
||||
yield u;
|
||||
if u.n_next != n then aux u.n_next
|
||||
in
|
||||
aux n
|
||||
|
||||
(* non-recursive, inlinable function for [find] *)
|
||||
let[@inline] find_ (n:node) : repr =
|
||||
if n == n.n_root then n else find_rec n.n_root
|
||||
|
|
@ -291,7 +300,7 @@ module Make(A: ARG) = struct
|
|||
let pp_next out n =
|
||||
Fmt.fprintf out "@ :next %a" N.pp n.n_next in
|
||||
let pp_root out n =
|
||||
if is_root_ n then Fmt.string out " :is-root" else Fmt.fprintf out "@ :root %a" N.pp n.n_root in
|
||||
if N.is_root n then Fmt.string out " :is-root" else Fmt.fprintf out "@ :root %a" N.pp n.n_root in
|
||||
let pp_expl out n = match n.n_expl with
|
||||
| FL_none -> ()
|
||||
| FL_some e ->
|
||||
|
|
@ -367,7 +376,7 @@ module Make(A: ARG) = struct
|
|||
|
||||
let[@inline] all_classes cc : repr Sequence.t =
|
||||
T_tbl.values cc.tbl
|
||||
|> Sequence.filter is_root_
|
||||
|> Sequence.filter N.is_root
|
||||
|
||||
(* TODO: use markers and lockstep iteration instead *)
|
||||
(* distance from [t] to its root in the proof forest *)
|
||||
|
|
@ -464,7 +473,7 @@ module Make(A: ARG) = struct
|
|||
nodes tagged with [tag]
|
||||
precond: [n] is a representative *)
|
||||
let add_tag_n cc (n:node) (tag:int) (expl:explanation) : unit =
|
||||
assert (is_root_ n);
|
||||
assert (N.is_root n);
|
||||
if not (Util.Int_map.mem tag n.n_tags) then (
|
||||
on_backtrack cc
|
||||
(fun () -> n.n_tags <- Util.Int_map.remove tag n.n_tags);
|
||||
|
|
@ -545,7 +554,7 @@ module Make(A: ARG) = struct
|
|||
be merged w.r.t. the theories.
|
||||
Side effect: also pushes sub-tasks *)
|
||||
let notify_merge cc (ra:repr) ~into:(rb:repr) (e:explanation): unit =
|
||||
assert (is_root_ rb);
|
||||
assert (N.is_root rb);
|
||||
match cc.on_merge with
|
||||
| Some f -> f ra rb e
|
||||
| None -> ()
|
||||
|
|
@ -620,8 +629,8 @@ module Make(A: ARG) = struct
|
|||
let ra = find_ a in
|
||||
let rb = find_ b in
|
||||
if not @@ N.equal ra rb then (
|
||||
assert (is_root_ ra);
|
||||
assert (is_root_ rb);
|
||||
assert (N.is_root ra);
|
||||
assert (N.is_root rb);
|
||||
(* check we're not merging [true] and [false] *)
|
||||
if (N.equal ra (true_ cc) && N.equal rb (false_ cc)) ||
|
||||
(N.equal rb (true_ cc) && N.equal ra (false_ cc)) then (
|
||||
|
|
@ -674,23 +683,22 @@ module Make(A: ARG) = struct
|
|||
merge_bool rb b ra a;
|
||||
(* perform [union r_from r_into] *)
|
||||
Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into);
|
||||
(* TODO: only iterate on parents of [rb] *)
|
||||
(* TODO: [ra.parents <- ra.parent ++ rb.parents] *)
|
||||
begin
|
||||
(* for each node in [r_from]'s class:
|
||||
- make it point to [r_into]
|
||||
- push it into [st.pending] *)
|
||||
iter_class_ r_from
|
||||
(* parents might have a different signature, check for collisions *)
|
||||
N.iter_parents r_from
|
||||
(fun parent -> push_pending cc parent);
|
||||
(* for each node in [r_from]'s class, make it point to [r_into] *)
|
||||
N.iter_class r_from
|
||||
(fun u ->
|
||||
assert (u.n_root == r_from);
|
||||
on_backtrack cc (fun () -> u.n_root <- r_from);
|
||||
u.n_root <- r_into;
|
||||
Bag.to_seq u.n_parents
|
||||
(fun parent -> push_pending cc parent));
|
||||
u.n_root <- r_into);
|
||||
(* now merge the classes *)
|
||||
let r_into_old_tags = r_into.n_tags in
|
||||
let r_into_old_next = r_into.n_next in
|
||||
let r_from_old_next = r_from.n_next in
|
||||
let r_into_old_parents = r_into.n_parents in
|
||||
r_into.n_parents <- Bag.append r_into.n_parents r_from.n_parents;
|
||||
(* on backtrack, unmerge classes and restore the pointers to [r_from] *)
|
||||
on_backtrack cc
|
||||
(fun () ->
|
||||
Log.debugf 15
|
||||
|
|
@ -698,7 +706,10 @@ module Make(A: ARG) = struct
|
|||
N.pp r_from N.pp r_into);
|
||||
r_into.n_next <- r_into_old_next;
|
||||
r_from.n_next <- r_from_old_next;
|
||||
r_into.n_tags <- r_into_old_tags);
|
||||
r_into.n_tags <- r_into_old_tags;
|
||||
r_into.n_parents <- r_into_old_parents;
|
||||
N.iter_class_ r_from (fun u -> u.n_root <- r_from);
|
||||
);
|
||||
r_into.n_tags <- new_tags;
|
||||
(* swap [into.next] and [from.next], merging the classes *)
|
||||
r_into.n_next <- r_from_old_next;
|
||||
|
|
@ -752,7 +763,9 @@ module Make(A: ARG) = struct
|
|||
let expl = explain_unfold cc e_12 in
|
||||
explain_eq_n ~init:expl cc r2 t2
|
||||
) in
|
||||
iter_class_ r1
|
||||
(* TODO: flag per class, `or`-ed on merge, to indicate if the class
|
||||
contains at least one lit *)
|
||||
N.iter_class r1
|
||||
(fun u1 ->
|
||||
(* propagate if:
|
||||
- [u1] is a proper literal
|
||||
|
|
@ -905,10 +918,10 @@ module Make(A: ARG) = struct
|
|||
(* populate [repr -> value] table *)
|
||||
T_tbl.values cc.tbl
|
||||
(fun r ->
|
||||
if is_root_ r then (
|
||||
if N.is_root r then (
|
||||
(* find a value in the class, if any *)
|
||||
let v =
|
||||
iter_class_ r
|
||||
N.iter_class r
|
||||
|> Sequence.find_map (fun n -> Model.eval m n.n_term)
|
||||
in
|
||||
let v = match v with
|
||||
|
|
|
|||
|
|
@ -43,6 +43,16 @@ module type S = sig
|
|||
val hash : t -> int
|
||||
val pp : t Fmt.printer
|
||||
|
||||
val is_root : t -> bool
|
||||
|
||||
val iter_class : t -> t Sequence.t
|
||||
(** Traverse the congruence class.
|
||||
Invariant: [is_root n] (see {!find} below) *)
|
||||
|
||||
val iter_parents : t -> t Sequence.t
|
||||
(** Traverse the parents of the class.
|
||||
Invariant: [is_root n] (see {!find} below) *)
|
||||
|
||||
type nonrec payload = payload = ..
|
||||
|
||||
val payload_find: f:(payload -> 'a option) -> t -> 'a option
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue