From bf0171fec11ca08afc34007bbb1100e4850bcb10 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sat, 9 Feb 2019 22:11:50 -0600 Subject: [PATCH] fix(cc): merge parents properly --- src/cc/Congruence_closure.ml | 73 ++++++++++++++++++------------- src/cc/Congruence_closure_intf.ml | 10 +++++ 2 files changed, 53 insertions(+), 30 deletions(-) diff --git a/src/cc/Congruence_closure.ml b/src/cc/Congruence_closure.ml index 687fdf20..a48c41ce 100644 --- a/src/cc/Congruence_closure.ml +++ b/src/cc/Congruence_closure.ml @@ -93,6 +93,25 @@ module Make(A: ARG) = struct } in n + let[@inline] is_root (n:node) : bool = n.n_root == n + + (* traverse the equivalence class of [n] *) + let iter_class_ (n:node) : node Sequence.t = + fun yield -> + let rec aux u = + yield u; + if u.n_next != n then aux u.n_next + in + aux n + + let iter_class n = + assert (is_root n); + iter_class_ n + + let[@inline] iter_parents (n:node) : node Sequence.t = + assert (is_root n); + Bag.to_seq n.n_parents + type nonrec payload = payload = .. let set_payload ?(can_erase=fun _->false) n e = @@ -245,7 +264,6 @@ module Make(A: ARG) = struct several times. See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *) - let[@inline] is_root_ (n:node) : bool = n.n_root == n let[@inline] size_ (r:repr) = r.n_size let[@inline] true_ cc = Lazy.force cc.true_ let[@inline] false_ cc = Lazy.force cc.false_ @@ -268,15 +286,6 @@ module Make(A: ARG) = struct root ) - (* traverse the equivalence class of [n] *) - let iter_class_ (n:node) : node Sequence.t = - fun yield -> - let rec aux u = - yield u; - if u.n_next != n then aux u.n_next - in - aux n - (* non-recursive, inlinable function for [find] *) let[@inline] find_ (n:node) : repr = if n == n.n_root then n else find_rec n.n_root @@ -291,7 +300,7 @@ module Make(A: ARG) = struct let pp_next out n = Fmt.fprintf out "@ :next %a" N.pp n.n_next in let pp_root out n = - if is_root_ n then Fmt.string out " :is-root" else Fmt.fprintf out "@ :root %a" N.pp n.n_root in + if N.is_root n then Fmt.string out " :is-root" else Fmt.fprintf out "@ :root %a" N.pp n.n_root in let pp_expl out n = match n.n_expl with | FL_none -> () | FL_some e -> @@ -367,7 +376,7 @@ module Make(A: ARG) = struct let[@inline] all_classes cc : repr Sequence.t = T_tbl.values cc.tbl - |> Sequence.filter is_root_ + |> Sequence.filter N.is_root (* TODO: use markers and lockstep iteration instead *) (* distance from [t] to its root in the proof forest *) @@ -464,7 +473,7 @@ module Make(A: ARG) = struct nodes tagged with [tag] precond: [n] is a representative *) let add_tag_n cc (n:node) (tag:int) (expl:explanation) : unit = - assert (is_root_ n); + assert (N.is_root n); if not (Util.Int_map.mem tag n.n_tags) then ( on_backtrack cc (fun () -> n.n_tags <- Util.Int_map.remove tag n.n_tags); @@ -545,7 +554,7 @@ module Make(A: ARG) = struct be merged w.r.t. the theories. Side effect: also pushes sub-tasks *) let notify_merge cc (ra:repr) ~into:(rb:repr) (e:explanation): unit = - assert (is_root_ rb); + assert (N.is_root rb); match cc.on_merge with | Some f -> f ra rb e | None -> () @@ -620,8 +629,8 @@ module Make(A: ARG) = struct let ra = find_ a in let rb = find_ b in if not @@ N.equal ra rb then ( - assert (is_root_ ra); - assert (is_root_ rb); + assert (N.is_root ra); + assert (N.is_root rb); (* check we're not merging [true] and [false] *) if (N.equal ra (true_ cc) && N.equal rb (false_ cc)) || (N.equal rb (true_ cc) && N.equal ra (false_ cc)) then ( @@ -674,23 +683,22 @@ module Make(A: ARG) = struct merge_bool rb b ra a; (* perform [union r_from r_into] *) Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into); - (* TODO: only iterate on parents of [rb] *) - (* TODO: [ra.parents <- ra.parent ++ rb.parents] *) begin - (* for each node in [r_from]'s class: - - make it point to [r_into] - - push it into [st.pending] *) - iter_class_ r_from + (* parents might have a different signature, check for collisions *) + N.iter_parents r_from + (fun parent -> push_pending cc parent); + (* for each node in [r_from]'s class, make it point to [r_into] *) + N.iter_class r_from (fun u -> assert (u.n_root == r_from); - on_backtrack cc (fun () -> u.n_root <- r_from); - u.n_root <- r_into; - Bag.to_seq u.n_parents - (fun parent -> push_pending cc parent)); + u.n_root <- r_into); (* now merge the classes *) let r_into_old_tags = r_into.n_tags in let r_into_old_next = r_into.n_next in let r_from_old_next = r_from.n_next in + let r_into_old_parents = r_into.n_parents in + r_into.n_parents <- Bag.append r_into.n_parents r_from.n_parents; + (* on backtrack, unmerge classes and restore the pointers to [r_from] *) on_backtrack cc (fun () -> Log.debugf 15 @@ -698,7 +706,10 @@ module Make(A: ARG) = struct N.pp r_from N.pp r_into); r_into.n_next <- r_into_old_next; r_from.n_next <- r_from_old_next; - r_into.n_tags <- r_into_old_tags); + r_into.n_tags <- r_into_old_tags; + r_into.n_parents <- r_into_old_parents; + N.iter_class_ r_from (fun u -> u.n_root <- r_from); + ); r_into.n_tags <- new_tags; (* swap [into.next] and [from.next], merging the classes *) r_into.n_next <- r_from_old_next; @@ -752,7 +763,9 @@ module Make(A: ARG) = struct let expl = explain_unfold cc e_12 in explain_eq_n ~init:expl cc r2 t2 ) in - iter_class_ r1 + (* TODO: flag per class, `or`-ed on merge, to indicate if the class + contains at least one lit *) + N.iter_class r1 (fun u1 -> (* propagate if: - [u1] is a proper literal @@ -905,10 +918,10 @@ module Make(A: ARG) = struct (* populate [repr -> value] table *) T_tbl.values cc.tbl (fun r -> - if is_root_ r then ( + if N.is_root r then ( (* find a value in the class, if any *) let v = - iter_class_ r + N.iter_class r |> Sequence.find_map (fun n -> Model.eval m n.n_term) in let v = match v with diff --git a/src/cc/Congruence_closure_intf.ml b/src/cc/Congruence_closure_intf.ml index 7c94a4b8..9f33287e 100644 --- a/src/cc/Congruence_closure_intf.ml +++ b/src/cc/Congruence_closure_intf.ml @@ -43,6 +43,16 @@ module type S = sig val hash : t -> int val pp : t Fmt.printer + val is_root : t -> bool + + val iter_class : t -> t Sequence.t + (** Traverse the congruence class. + Invariant: [is_root n] (see {!find} below) *) + + val iter_parents : t -> t Sequence.t + (** Traverse the parents of the class. + Invariant: [is_root n] (see {!find} below) *) + type nonrec payload = payload = .. val payload_find: f:(payload -> 'a option) -> t -> 'a option