mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-06 03:05:31 -05:00
fix(cc): merge parents properly
This commit is contained in:
parent
1328d043e3
commit
bf0171fec1
2 changed files with 53 additions and 30 deletions
|
|
@ -93,6 +93,25 @@ module Make(A: ARG) = struct
|
||||||
} in
|
} in
|
||||||
n
|
n
|
||||||
|
|
||||||
|
let[@inline] is_root (n:node) : bool = n.n_root == n
|
||||||
|
|
||||||
|
(* traverse the equivalence class of [n] *)
|
||||||
|
let iter_class_ (n:node) : node Sequence.t =
|
||||||
|
fun yield ->
|
||||||
|
let rec aux u =
|
||||||
|
yield u;
|
||||||
|
if u.n_next != n then aux u.n_next
|
||||||
|
in
|
||||||
|
aux n
|
||||||
|
|
||||||
|
let iter_class n =
|
||||||
|
assert (is_root n);
|
||||||
|
iter_class_ n
|
||||||
|
|
||||||
|
let[@inline] iter_parents (n:node) : node Sequence.t =
|
||||||
|
assert (is_root n);
|
||||||
|
Bag.to_seq n.n_parents
|
||||||
|
|
||||||
type nonrec payload = payload = ..
|
type nonrec payload = payload = ..
|
||||||
|
|
||||||
let set_payload ?(can_erase=fun _->false) n e =
|
let set_payload ?(can_erase=fun _->false) n e =
|
||||||
|
|
@ -245,7 +264,6 @@ module Make(A: ARG) = struct
|
||||||
several times.
|
several times.
|
||||||
See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *)
|
See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *)
|
||||||
|
|
||||||
let[@inline] is_root_ (n:node) : bool = n.n_root == n
|
|
||||||
let[@inline] size_ (r:repr) = r.n_size
|
let[@inline] size_ (r:repr) = r.n_size
|
||||||
let[@inline] true_ cc = Lazy.force cc.true_
|
let[@inline] true_ cc = Lazy.force cc.true_
|
||||||
let[@inline] false_ cc = Lazy.force cc.false_
|
let[@inline] false_ cc = Lazy.force cc.false_
|
||||||
|
|
@ -268,15 +286,6 @@ module Make(A: ARG) = struct
|
||||||
root
|
root
|
||||||
)
|
)
|
||||||
|
|
||||||
(* traverse the equivalence class of [n] *)
|
|
||||||
let iter_class_ (n:node) : node Sequence.t =
|
|
||||||
fun yield ->
|
|
||||||
let rec aux u =
|
|
||||||
yield u;
|
|
||||||
if u.n_next != n then aux u.n_next
|
|
||||||
in
|
|
||||||
aux n
|
|
||||||
|
|
||||||
(* non-recursive, inlinable function for [find] *)
|
(* non-recursive, inlinable function for [find] *)
|
||||||
let[@inline] find_ (n:node) : repr =
|
let[@inline] find_ (n:node) : repr =
|
||||||
if n == n.n_root then n else find_rec n.n_root
|
if n == n.n_root then n else find_rec n.n_root
|
||||||
|
|
@ -291,7 +300,7 @@ module Make(A: ARG) = struct
|
||||||
let pp_next out n =
|
let pp_next out n =
|
||||||
Fmt.fprintf out "@ :next %a" N.pp n.n_next in
|
Fmt.fprintf out "@ :next %a" N.pp n.n_next in
|
||||||
let pp_root out n =
|
let pp_root out n =
|
||||||
if is_root_ n then Fmt.string out " :is-root" else Fmt.fprintf out "@ :root %a" N.pp n.n_root in
|
if N.is_root n then Fmt.string out " :is-root" else Fmt.fprintf out "@ :root %a" N.pp n.n_root in
|
||||||
let pp_expl out n = match n.n_expl with
|
let pp_expl out n = match n.n_expl with
|
||||||
| FL_none -> ()
|
| FL_none -> ()
|
||||||
| FL_some e ->
|
| FL_some e ->
|
||||||
|
|
@ -367,7 +376,7 @@ module Make(A: ARG) = struct
|
||||||
|
|
||||||
let[@inline] all_classes cc : repr Sequence.t =
|
let[@inline] all_classes cc : repr Sequence.t =
|
||||||
T_tbl.values cc.tbl
|
T_tbl.values cc.tbl
|
||||||
|> Sequence.filter is_root_
|
|> Sequence.filter N.is_root
|
||||||
|
|
||||||
(* TODO: use markers and lockstep iteration instead *)
|
(* TODO: use markers and lockstep iteration instead *)
|
||||||
(* distance from [t] to its root in the proof forest *)
|
(* distance from [t] to its root in the proof forest *)
|
||||||
|
|
@ -464,7 +473,7 @@ module Make(A: ARG) = struct
|
||||||
nodes tagged with [tag]
|
nodes tagged with [tag]
|
||||||
precond: [n] is a representative *)
|
precond: [n] is a representative *)
|
||||||
let add_tag_n cc (n:node) (tag:int) (expl:explanation) : unit =
|
let add_tag_n cc (n:node) (tag:int) (expl:explanation) : unit =
|
||||||
assert (is_root_ n);
|
assert (N.is_root n);
|
||||||
if not (Util.Int_map.mem tag n.n_tags) then (
|
if not (Util.Int_map.mem tag n.n_tags) then (
|
||||||
on_backtrack cc
|
on_backtrack cc
|
||||||
(fun () -> n.n_tags <- Util.Int_map.remove tag n.n_tags);
|
(fun () -> n.n_tags <- Util.Int_map.remove tag n.n_tags);
|
||||||
|
|
@ -545,7 +554,7 @@ module Make(A: ARG) = struct
|
||||||
be merged w.r.t. the theories.
|
be merged w.r.t. the theories.
|
||||||
Side effect: also pushes sub-tasks *)
|
Side effect: also pushes sub-tasks *)
|
||||||
let notify_merge cc (ra:repr) ~into:(rb:repr) (e:explanation): unit =
|
let notify_merge cc (ra:repr) ~into:(rb:repr) (e:explanation): unit =
|
||||||
assert (is_root_ rb);
|
assert (N.is_root rb);
|
||||||
match cc.on_merge with
|
match cc.on_merge with
|
||||||
| Some f -> f ra rb e
|
| Some f -> f ra rb e
|
||||||
| None -> ()
|
| None -> ()
|
||||||
|
|
@ -620,8 +629,8 @@ module Make(A: ARG) = struct
|
||||||
let ra = find_ a in
|
let ra = find_ a in
|
||||||
let rb = find_ b in
|
let rb = find_ b in
|
||||||
if not @@ N.equal ra rb then (
|
if not @@ N.equal ra rb then (
|
||||||
assert (is_root_ ra);
|
assert (N.is_root ra);
|
||||||
assert (is_root_ rb);
|
assert (N.is_root rb);
|
||||||
(* check we're not merging [true] and [false] *)
|
(* check we're not merging [true] and [false] *)
|
||||||
if (N.equal ra (true_ cc) && N.equal rb (false_ cc)) ||
|
if (N.equal ra (true_ cc) && N.equal rb (false_ cc)) ||
|
||||||
(N.equal rb (true_ cc) && N.equal ra (false_ cc)) then (
|
(N.equal rb (true_ cc) && N.equal ra (false_ cc)) then (
|
||||||
|
|
@ -674,23 +683,22 @@ module Make(A: ARG) = struct
|
||||||
merge_bool rb b ra a;
|
merge_bool rb b ra a;
|
||||||
(* perform [union r_from r_into] *)
|
(* perform [union r_from r_into] *)
|
||||||
Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into);
|
Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into);
|
||||||
(* TODO: only iterate on parents of [rb] *)
|
|
||||||
(* TODO: [ra.parents <- ra.parent ++ rb.parents] *)
|
|
||||||
begin
|
begin
|
||||||
(* for each node in [r_from]'s class:
|
(* parents might have a different signature, check for collisions *)
|
||||||
- make it point to [r_into]
|
N.iter_parents r_from
|
||||||
- push it into [st.pending] *)
|
(fun parent -> push_pending cc parent);
|
||||||
iter_class_ r_from
|
(* for each node in [r_from]'s class, make it point to [r_into] *)
|
||||||
|
N.iter_class r_from
|
||||||
(fun u ->
|
(fun u ->
|
||||||
assert (u.n_root == r_from);
|
assert (u.n_root == r_from);
|
||||||
on_backtrack cc (fun () -> u.n_root <- r_from);
|
u.n_root <- r_into);
|
||||||
u.n_root <- r_into;
|
|
||||||
Bag.to_seq u.n_parents
|
|
||||||
(fun parent -> push_pending cc parent));
|
|
||||||
(* now merge the classes *)
|
(* now merge the classes *)
|
||||||
let r_into_old_tags = r_into.n_tags in
|
let r_into_old_tags = r_into.n_tags in
|
||||||
let r_into_old_next = r_into.n_next in
|
let r_into_old_next = r_into.n_next in
|
||||||
let r_from_old_next = r_from.n_next in
|
let r_from_old_next = r_from.n_next in
|
||||||
|
let r_into_old_parents = r_into.n_parents in
|
||||||
|
r_into.n_parents <- Bag.append r_into.n_parents r_from.n_parents;
|
||||||
|
(* on backtrack, unmerge classes and restore the pointers to [r_from] *)
|
||||||
on_backtrack cc
|
on_backtrack cc
|
||||||
(fun () ->
|
(fun () ->
|
||||||
Log.debugf 15
|
Log.debugf 15
|
||||||
|
|
@ -698,7 +706,10 @@ module Make(A: ARG) = struct
|
||||||
N.pp r_from N.pp r_into);
|
N.pp r_from N.pp r_into);
|
||||||
r_into.n_next <- r_into_old_next;
|
r_into.n_next <- r_into_old_next;
|
||||||
r_from.n_next <- r_from_old_next;
|
r_from.n_next <- r_from_old_next;
|
||||||
r_into.n_tags <- r_into_old_tags);
|
r_into.n_tags <- r_into_old_tags;
|
||||||
|
r_into.n_parents <- r_into_old_parents;
|
||||||
|
N.iter_class_ r_from (fun u -> u.n_root <- r_from);
|
||||||
|
);
|
||||||
r_into.n_tags <- new_tags;
|
r_into.n_tags <- new_tags;
|
||||||
(* swap [into.next] and [from.next], merging the classes *)
|
(* swap [into.next] and [from.next], merging the classes *)
|
||||||
r_into.n_next <- r_from_old_next;
|
r_into.n_next <- r_from_old_next;
|
||||||
|
|
@ -752,7 +763,9 @@ module Make(A: ARG) = struct
|
||||||
let expl = explain_unfold cc e_12 in
|
let expl = explain_unfold cc e_12 in
|
||||||
explain_eq_n ~init:expl cc r2 t2
|
explain_eq_n ~init:expl cc r2 t2
|
||||||
) in
|
) in
|
||||||
iter_class_ r1
|
(* TODO: flag per class, `or`-ed on merge, to indicate if the class
|
||||||
|
contains at least one lit *)
|
||||||
|
N.iter_class r1
|
||||||
(fun u1 ->
|
(fun u1 ->
|
||||||
(* propagate if:
|
(* propagate if:
|
||||||
- [u1] is a proper literal
|
- [u1] is a proper literal
|
||||||
|
|
@ -905,10 +918,10 @@ module Make(A: ARG) = struct
|
||||||
(* populate [repr -> value] table *)
|
(* populate [repr -> value] table *)
|
||||||
T_tbl.values cc.tbl
|
T_tbl.values cc.tbl
|
||||||
(fun r ->
|
(fun r ->
|
||||||
if is_root_ r then (
|
if N.is_root r then (
|
||||||
(* find a value in the class, if any *)
|
(* find a value in the class, if any *)
|
||||||
let v =
|
let v =
|
||||||
iter_class_ r
|
N.iter_class r
|
||||||
|> Sequence.find_map (fun n -> Model.eval m n.n_term)
|
|> Sequence.find_map (fun n -> Model.eval m n.n_term)
|
||||||
in
|
in
|
||||||
let v = match v with
|
let v = match v with
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,16 @@ module type S = sig
|
||||||
val hash : t -> int
|
val hash : t -> int
|
||||||
val pp : t Fmt.printer
|
val pp : t Fmt.printer
|
||||||
|
|
||||||
|
val is_root : t -> bool
|
||||||
|
|
||||||
|
val iter_class : t -> t Sequence.t
|
||||||
|
(** Traverse the congruence class.
|
||||||
|
Invariant: [is_root n] (see {!find} below) *)
|
||||||
|
|
||||||
|
val iter_parents : t -> t Sequence.t
|
||||||
|
(** Traverse the parents of the class.
|
||||||
|
Invariant: [is_root n] (see {!find} below) *)
|
||||||
|
|
||||||
type nonrec payload = payload = ..
|
type nonrec payload = payload = ..
|
||||||
|
|
||||||
val payload_find: f:(payload -> 'a option) -> t -> 'a option
|
val payload_find: f:(payload -> 'a option) -> t -> 'a option
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue