fix(cc): merge parents properly

This commit is contained in:
Simon Cruanes 2019-02-09 22:11:50 -06:00
parent 1328d043e3
commit bf0171fec1
2 changed files with 53 additions and 30 deletions

View file

@ -93,6 +93,25 @@ module Make(A: ARG) = struct
} in
n
let[@inline] is_root (n:node) : bool = n.n_root == n
(* traverse the equivalence class of [n] *)
let iter_class_ (n:node) : node Sequence.t =
fun yield ->
let rec aux u =
yield u;
if u.n_next != n then aux u.n_next
in
aux n
let iter_class n =
assert (is_root n);
iter_class_ n
let[@inline] iter_parents (n:node) : node Sequence.t =
assert (is_root n);
Bag.to_seq n.n_parents
type nonrec payload = payload = ..
let set_payload ?(can_erase=fun _->false) n e =
@ -245,7 +264,6 @@ module Make(A: ARG) = struct
several times.
See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *)
let[@inline] is_root_ (n:node) : bool = n.n_root == n
let[@inline] size_ (r:repr) = r.n_size
let[@inline] true_ cc = Lazy.force cc.true_
let[@inline] false_ cc = Lazy.force cc.false_
@ -268,15 +286,6 @@ module Make(A: ARG) = struct
root
)
(* traverse the equivalence class of [n] *)
let iter_class_ (n:node) : node Sequence.t =
fun yield ->
let rec aux u =
yield u;
if u.n_next != n then aux u.n_next
in
aux n
(* non-recursive, inlinable function for [find] *)
let[@inline] find_ (n:node) : repr =
if n == n.n_root then n else find_rec n.n_root
@ -291,7 +300,7 @@ module Make(A: ARG) = struct
let pp_next out n =
Fmt.fprintf out "@ :next %a" N.pp n.n_next in
let pp_root out n =
if is_root_ n then Fmt.string out " :is-root" else Fmt.fprintf out "@ :root %a" N.pp n.n_root in
if N.is_root n then Fmt.string out " :is-root" else Fmt.fprintf out "@ :root %a" N.pp n.n_root in
let pp_expl out n = match n.n_expl with
| FL_none -> ()
| FL_some e ->
@ -367,7 +376,7 @@ module Make(A: ARG) = struct
let[@inline] all_classes cc : repr Sequence.t =
T_tbl.values cc.tbl
|> Sequence.filter is_root_
|> Sequence.filter N.is_root
(* TODO: use markers and lockstep iteration instead *)
(* distance from [t] to its root in the proof forest *)
@ -464,7 +473,7 @@ module Make(A: ARG) = struct
nodes tagged with [tag]
precond: [n] is a representative *)
let add_tag_n cc (n:node) (tag:int) (expl:explanation) : unit =
assert (is_root_ n);
assert (N.is_root n);
if not (Util.Int_map.mem tag n.n_tags) then (
on_backtrack cc
(fun () -> n.n_tags <- Util.Int_map.remove tag n.n_tags);
@ -545,7 +554,7 @@ module Make(A: ARG) = struct
be merged w.r.t. the theories.
Side effect: also pushes sub-tasks *)
let notify_merge cc (ra:repr) ~into:(rb:repr) (e:explanation): unit =
assert (is_root_ rb);
assert (N.is_root rb);
match cc.on_merge with
| Some f -> f ra rb e
| None -> ()
@ -620,8 +629,8 @@ module Make(A: ARG) = struct
let ra = find_ a in
let rb = find_ b in
if not @@ N.equal ra rb then (
assert (is_root_ ra);
assert (is_root_ rb);
assert (N.is_root ra);
assert (N.is_root rb);
(* check we're not merging [true] and [false] *)
if (N.equal ra (true_ cc) && N.equal rb (false_ cc)) ||
(N.equal rb (true_ cc) && N.equal ra (false_ cc)) then (
@ -674,23 +683,22 @@ module Make(A: ARG) = struct
merge_bool rb b ra a;
(* perform [union r_from r_into] *)
Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into);
(* TODO: only iterate on parents of [rb] *)
(* TODO: [ra.parents <- ra.parent ++ rb.parents] *)
begin
(* for each node in [r_from]'s class:
- make it point to [r_into]
- push it into [st.pending] *)
iter_class_ r_from
(* parents might have a different signature, check for collisions *)
N.iter_parents r_from
(fun parent -> push_pending cc parent);
(* for each node in [r_from]'s class, make it point to [r_into] *)
N.iter_class r_from
(fun u ->
assert (u.n_root == r_from);
on_backtrack cc (fun () -> u.n_root <- r_from);
u.n_root <- r_into;
Bag.to_seq u.n_parents
(fun parent -> push_pending cc parent));
u.n_root <- r_into);
(* now merge the classes *)
let r_into_old_tags = r_into.n_tags in
let r_into_old_next = r_into.n_next in
let r_from_old_next = r_from.n_next in
let r_into_old_parents = r_into.n_parents in
r_into.n_parents <- Bag.append r_into.n_parents r_from.n_parents;
(* on backtrack, unmerge classes and restore the pointers to [r_from] *)
on_backtrack cc
(fun () ->
Log.debugf 15
@ -698,7 +706,10 @@ module Make(A: ARG) = struct
N.pp r_from N.pp r_into);
r_into.n_next <- r_into_old_next;
r_from.n_next <- r_from_old_next;
r_into.n_tags <- r_into_old_tags);
r_into.n_tags <- r_into_old_tags;
r_into.n_parents <- r_into_old_parents;
N.iter_class_ r_from (fun u -> u.n_root <- r_from);
);
r_into.n_tags <- new_tags;
(* swap [into.next] and [from.next], merging the classes *)
r_into.n_next <- r_from_old_next;
@ -752,7 +763,9 @@ module Make(A: ARG) = struct
let expl = explain_unfold cc e_12 in
explain_eq_n ~init:expl cc r2 t2
) in
iter_class_ r1
(* TODO: flag per class, `or`-ed on merge, to indicate if the class
contains at least one lit *)
N.iter_class r1
(fun u1 ->
(* propagate if:
- [u1] is a proper literal
@ -905,10 +918,10 @@ module Make(A: ARG) = struct
(* populate [repr -> value] table *)
T_tbl.values cc.tbl
(fun r ->
if is_root_ r then (
if N.is_root r then (
(* find a value in the class, if any *)
let v =
iter_class_ r
N.iter_class r
|> Sequence.find_map (fun n -> Model.eval m n.n_term)
in
let v = match v with

View file

@ -43,6 +43,16 @@ module type S = sig
val hash : t -> int
val pp : t Fmt.printer
val is_root : t -> bool
val iter_class : t -> t Sequence.t
(** Traverse the congruence class.
Invariant: [is_root n] (see {!find} below) *)
val iter_parents : t -> t Sequence.t
(** Traverse the parents of the class.
Invariant: [is_root n] (see {!find} below) *)
type nonrec payload = payload = ..
val payload_find: f:(payload -> 'a option) -> t -> 'a option