From 543f8a5a99558e4d4f83febddc16f0e2c946bf6c Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Fri, 23 Feb 2018 00:44:23 -0600 Subject: [PATCH] add `distinct` handling to congruence closure --- src/smt/Congruence_closure.ml | 592 ++++++++++++++++++------------- src/smt/Congruence_closure.mli | 18 +- src/smt/Equiv_class.ml | 3 +- src/smt/Equiv_class.mli | 2 +- src/smt/Solver.ml | 5 +- src/smt/Solver.mli | 2 +- src/smt/Solver_types.ml | 1 + src/smt/Theory.ml | 8 +- src/smt/Theory_combine.ml | 11 +- src/smt/th_bool/Dagon_th_bool.ml | 21 +- 10 files changed, 389 insertions(+), 274 deletions(-) diff --git a/src/smt/Congruence_closure.ml b/src/smt/Congruence_closure.ml index 848545d5..f0db0b6d 100644 --- a/src/smt/Congruence_closure.ml +++ b/src/smt/Congruence_closure.ml @@ -28,7 +28,7 @@ type actions = { on_merge:repr -> repr -> explanation -> unit; (** Call this when two classes are merged *) - raise_conflict: 'a. Explanation.t Bag.t -> 'a; + raise_conflict: 'a. Lit.Set.t -> 'a; (** Report a conflict *) propagate: Lit.t -> Explanation.t Bag.t -> unit; @@ -65,6 +65,11 @@ type t = { several times. See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *) +let[@inline] on_backtrack_if_not_lvl_0 cc f : unit = + if not (cc.acts.at_lvl_0 ()) then ( + cc.acts.on_backtrack f + ) + let[@inline] is_root_ (n:node) : bool = n.n_root == n let[@inline] size_ (r:repr) = @@ -85,9 +90,7 @@ let rec find_rec cc (n:node) : repr = let root = find_rec cc old_root in (* path compression *) if (root :> node) != old_root then ( - if not (cc.acts.at_lvl_0 ()) then ( - cc.acts.on_backtrack (fun () -> n.n_root <- old_root); - ); + on_backtrack_if_not_lvl_0 cc (fun () -> n.n_root <- old_root); n.n_root <- (root :> node); ); root @@ -152,10 +155,8 @@ let add_signature cc (t:term) (r:repr): unit = match signature cc t with (* add, but only if not present already *) begin match Sig_tbl.get cc.signatures_tbl s with | None -> - if not (cc.acts.at_lvl_0 ()) then ( - cc.acts.on_backtrack - (fun () -> Sig_tbl.remove cc.signatures_tbl s); - ); + on_backtrack_if_not_lvl_0 cc + (fun () -> Sig_tbl.remove cc.signatures_tbl s); Sig_tbl.add cc.signatures_tbl s r; | Some r' -> assert (Equiv_class.equal r r'); @@ -166,18 +167,18 @@ let is_done (cc:t): bool = Vec.is_empty cc.combine let push_pending cc t : unit = - Log.debugf 5 (fun k->k "(@[push_pending@ %a@])" Equiv_class.pp t); + Log.debugf 5 (fun k->k "(@[cc.push_pending@ %a@])" Equiv_class.pp t); Vec.push cc.pending t let push_combine cc t u e : unit = Log.debugf 5 - (fun k->k "(@[push_combine@ %a@ %a@ expl: %a@])" + (fun k->k "(@[cc.push_combine@ %a@ %a@ expl: %a@])" Equiv_class.pp t Equiv_class.pp u Explanation.pp e); Vec.push cc.combine (t,u,e) let push_propagation cc (lit:lit) (expl:explanation Bag.t): unit = Log.debugf 5 - (fun k->k "(@[push_propagate@ %a@ expl: (@[%a@])@])" + (fun k->k "(@[cc.push_propagate@ %a@ expl: (@[%a@])@])" Lit.pp lit (Util.pp_seq Explanation.pp) @@ Bag.to_seq expl); cc.acts.propagate lit expl @@ -191,9 +192,8 @@ let[@inline] union cc (a:node) (b:node) (e:explanation): unit = postcondition: [n.n_expl = None] *) let rec reroot_expl (cc:t) (n:node): unit = let old_expl = n.n_expl in - if not (cc.acts.at_lvl_0 ()) then ( - cc.acts.on_backtrack (fun () -> n.n_expl <- old_expl); - ); + on_backtrack_if_not_lvl_0 cc + (fun () -> n.n_expl <- old_expl); begin match old_expl with | E_none -> () (* already root *) | E_some {next=u; expl=e_n_u} -> @@ -202,244 +202,13 @@ let rec reroot_expl (cc:t) (n:node): unit = n.n_expl <- E_none; end -let[@inline] raise_conflict (cc:t) (e:explanation Bag.t): _ = +let[@inline] raise_conflict (cc:t) (e:Lit.Set.t): _ = cc.acts.raise_conflict e let[@inline] all_classes cc : repr Sequence.t = Term.Tbl.values cc.tbl |> Sequence.filter is_root_ -(* main CC algo: add terms from [pending] to the signature table, - check for collisions *) -let rec update_pending (cc:t): unit = - (* step 2 deal with pending (parent) terms whose equiv class - might have changed *) - while not (Vec.is_empty cc.pending) do - let n = Vec.pop_last cc.pending in - (* check if some parent collided *) - begin match find_by_signature cc n.n_term with - | None -> - (* add to the signature table [n --> n.root] *) - add_signature cc n.n_term (find cc n) - | Some u -> - (* must combine [t] with [r] *) - push_combine cc n u(E_congruence (n,u)) - end; - (* FIXME: when to actually evaluate? - eval_pending cc; - *) - done; - if not (is_done cc) then ( - update_combine cc (* repeat *) - ) - -(* main CC algo: merge equivalence classes in [st.combine]. - @raise Exn_unsat if merge fails *) -and update_combine cc = - while not (Vec.is_empty cc.combine) do - let a, b, e_ab = Vec.pop_last cc.combine in - let ra = find cc a in - let rb = find cc b in - if not (Equiv_class.equal ra rb) then ( - assert (is_root_ ra); - assert (is_root_ (rb:>node)); - (* We will merge [r_from] into [r_into]. - we try to ensure that [size ra <= size rb] in general, unless - it clashes with the invariant that the representative must - be a normal form if the class contains a normal form *) - let must_solve, r_from, r_into = - match Term.is_semantic ra.n_term, Term.is_semantic rb.n_term with - | true, true -> - if size_ ra > size_ rb then true, rb, ra else true, ra, rb - | false, false -> - if size_ ra > size_ rb then false, rb, ra else false, ra, rb - | true, false -> false, rb, ra (* semantic ==> representative *) - | false, true -> false, ra, rb - in - (* solve the equation, if both [ra] and [rb] are semantic. - The equation is between signatures, so as to canonize w.r.t the - current congruence before solving *) - if must_solve then ( - let t_a = ra.n_term and t_b = rb.n_term in - match signature cc t_a, signature cc t_b with - | Some (Custom t1), Some (Custom t2) -> - begin match t1.tc.tc_t_solve t1.view t2.view with - | Solve_ok {subst=l} -> - Log.debugf 5 - (fun k->k "(@[solve@ (@[= %a %a@])@ :yields (@[%a@])@])" - Term.pp t_a Term.pp t_b - (Util.pp_list @@ Util.pp_pair Equiv_class.pp Term.pp) l); - List.iter (fun (u1,u2) -> push_combine cc u1 (add cc u2) e_ab) l - | Solve_fail {expl} -> - Log.debugf 5 - (fun k->k "(@[solve-fail@ (@[= %a %a@])@ :expl %a@])" - Term.pp t_a Term.pp t_b Explanation.pp expl); - - raise_conflict cc (Bag.return expl) - end - | _ -> assert false - ); - (* remove [ra.parents] from signature, put them into [st.pending] *) - begin - Bag.to_seq (r_from:>node).n_parents - |> Sequence.iter - (fun parent -> - (* FIXME: with OCaml's hashtable, we should be able - to keep this entry (and have it become relevant later - once the signature of [parent] is backtracked) *) - remove_signature cc parent.n_term; - push_pending cc parent) - end; - (* perform [union ra rb] *) - begin - let r_from = (r_from :> node) in - let r_into = (r_into :> node) in - let rb_old_parents = r_into.n_parents in - cc.acts.on_backtrack - (fun () -> - r_from.n_root <- r_from; - r_into.n_parents <- rb_old_parents); - r_from.n_root <- r_into; - r_from.n_parents <- Bag.append rb_old_parents r_from.n_parents; - end; - (* update explanations (a -> b), arbitrarily *) - begin - reroot_expl cc a; - assert (a.n_expl = E_none); - if not (cc.acts.at_lvl_0 ()) then ( - cc.acts.on_backtrack (fun () -> a.n_expl <- E_none); - ); - a.n_expl <- E_some {next=b; expl=e_ab}; - end; - (* notify listeners of the merge *) - notify_merge cc r_from ~into:r_into e_ab; - ) - done; - (* now update pending terms again *) - update_pending cc - -(* Checks if [ra] and [~into] have compatible normal forms and can - be merged w.r.t. the theories. - Side effect: also pushes sub-tasks *) -and notify_merge cc (ra:repr) ~into:(rb:repr) (e:explanation): unit = - assert (is_root_ rb); - cc.acts.on_merge ra rb e - - -(* FIXME: callback? -(* evaluation rules: if, case... *) -and eval_pending (t:term): unit = - List.iter - (fun ((module Theory):repr theory) -> Theory.eval t) - theories - *) - -(* FIXME: remove? -(* main CC algo: add missing terms to the congruence class *) -and update_add (cc:t) terms () = - while not (Queue.is_empty cc.terms_to_add) do - let t = Queue.pop cc.terms_to_add in - add cc t - done -*) - -(* add [t] to [cc] when not present already *) -and add_new_term cc (t:term) : node = - assert (not @@ mem cc t); - let n = Equiv_class.make t in - (* how to add a subterm *) - let add_to_parents_of_sub_node (sub:node) : unit = - let old_parents = sub.n_parents in - if not @@ cc.acts.at_lvl_0 () then ( - cc.acts.on_backtrack (fun () -> sub.n_parents <- old_parents); - ); - sub.n_parents <- Bag.cons n sub.n_parents; - push_pending cc sub - in - (* add sub-term to [cc], and register [n] to its parents *) - let add_sub_t (u:term) : unit = - let n_u = add cc u in - add_to_parents_of_sub_node n_u - in - (* register sub-terms, add [t] to their parent list *) - begin match t.term_cell with - | Bool _-> () - | App_cst (_, a) -> IArray.iter add_sub_t a - | If (a,b,c) -> - add_sub_t a; - add_sub_t b; - add_sub_t c - | Case (u, _) -> add_sub_t u - | Custom {view;tc} -> - (* add relevant subterms to the CC *) - tc.tc_t_relevant view add_sub_t - end; - (* remove term when we backtrack *) - if not (cc.acts.at_lvl_0 ()) then ( - cc.acts.on_backtrack (fun () -> Term.Tbl.remove cc.tbl t); - ); - (* add term to the table *) - Term.Tbl.add cc.tbl t n; - (* [n] might be merged with other equiv classes *) - push_pending cc n; - n - -(* TODO? *) -(* add [t=u] to the congruence closure, unconditionally (reduction relation) *) -and[@inline] add_eqn (cc:t) (eqn:merge_op): unit = - let t, u, expl = eqn in - push_combine cc t u expl - -(* add a term *) -and[@inline] add cc t = - try Term.Tbl.find cc.tbl t - with Not_found -> add_new_term cc t - -let[@inline] add_seq cc seq = seq (fun t -> ignore @@ add cc t) - -(* assert that this boolean literal holds *) -let assert_lit cc lit : unit = match Lit.view lit with - | Lit_fresh _ - | Lit_expanded _ -> () - | Lit_atom t -> - assert (Ty.is_prop t.term_ty); - let sign = Lit.sign lit in - (* equate t and true/false *) - let rhs = if sign then true_ cc else false_ cc in - let n = add cc t in - push_combine cc n rhs (E_lit lit); - () - -let assert_eq cc (t:term) (u:term) expl : unit = - let n1 = add cc t in - let n2 = add cc u in - if not (same_class cc n1 n2) then ( - union cc n1 n2 expl - ) - -let assert_distinct _cc (l:term list) _expl : unit = - assert (match l with[] | [_] -> false | _ -> true); - Util.errorf "unimplemented: CC.distinct" - -let create ?(size=2048) ~actions (tst:Term.state) : t = - assert (actions.at_lvl_0 ()); - let nd = Equiv_class.dummy in - let rec cc = { - tst; - acts=actions; - tbl = Term.Tbl.create size; - signatures_tbl = Sig_tbl.create size; - pending=Vec.make_empty Equiv_class.dummy; - combine= Vec.make_empty (nd,nd,E_reduce_eq(nd,nd)); - ps_lits=Lit.Set.empty; - ps_queue=Vec.make_empty (nd,nd); - true_ = lazy (add cc (Term.true_ tst)); - false_ = lazy (add cc (Term.false_ tst)); - } in - ignore (Lazy.force cc.true_); - ignore (Lazy.force cc.false_); - cc - (* distance from [t] to its root in the proof forest *) let[@inline][@unroll 2] rec distance_to_root (n:node): int = match n.n_expl with | E_none -> 0 @@ -544,20 +313,343 @@ let explain_loop (cc : t) : Lit.Set.t = done; cc.ps_lits -let explain_unfold_seq cc (seq:explanation Sequence.t): Lit.Set.t = +let explain_eq_n ?(init=Lit.Set.empty) cc (n1:node) (n2:node) : Lit.Set.t = ps_clear cc; + cc.ps_lits <- init; + ps_add_obligation cc n1 n2; + explain_loop cc + +let explain_eq_t ?(init=Lit.Set.empty) cc (t1:term) (t2:term) : Lit.Set.t = + ps_clear cc; + cc.ps_lits <- init; + ps_add_obligation_t cc t1 t2; + explain_loop cc + +let explain_unfold ?(init=Lit.Set.empty) cc (e:explanation) : Lit.Set.t = + ps_clear cc; + cc.ps_lits <- init; + decompose_explain cc e; + explain_loop cc + +let explain_unfold_seq ?(init=Lit.Set.empty) cc (seq:explanation Sequence.t): Lit.Set.t = + ps_clear cc; + cc.ps_lits <- init; Sequence.iter (decompose_explain cc) seq; explain_loop cc -let explain_unfold_bag cc (b:explanation Bag.t) : Lit.Set.t = +let explain_unfold_bag ?(init=Lit.Set.empty) cc (b:explanation Bag.t) : Lit.Set.t = match b with - | Bag.E -> Lit.Set.empty - | Bag.L (E_lit lit) -> Lit.Set.singleton lit + | Bag.E -> init + | Bag.L (E_lit lit) -> Lit.Set.add lit init | _ -> ps_clear cc; + cc.ps_lits <- init; Sequence.iter (decompose_explain cc) (Bag.to_seq b); explain_loop cc +(* add [tag] to [n] + precond: [n] is a representative *) +let add_tag_n cc (n:node) (tag:int) (expl:explanation) : unit = + assert (is_root_ n); + if not (Util.Int_map.mem tag n.n_tags) then ( + on_backtrack_if_not_lvl_0 cc + (fun () -> n.n_tags <- Util.Int_map.remove tag n.n_tags); + n.n_tags <- Util.Int_map.add tag expl n.n_tags; + ) + +(* conflict because [expl => t1 ≠ t2] but they are the same *) +let distinct_conflict cc (t1 : term) (t2: term) (expl:explanation Bag.t) : 'a = + let lits = explain_unfold_bag cc expl in + let lits = explain_eq_t ~init:lits cc t1 t2 in + raise_conflict cc lits + +(* main CC algo: add terms from [pending] to the signature table, + check for collisions *) +let rec update_pending (cc:t): unit = + (* step 2 deal with pending (parent) terms whose equiv class + might have changed *) + while not (Vec.is_empty cc.pending) do + let n = Vec.pop_last cc.pending in + (* check if some parent collided *) + begin match find_by_signature cc n.n_term with + | None -> + (* add to the signature table [n --> n.root] *) + add_signature cc n.n_term (find cc n) + | Some u -> + (* must combine [t] with [r] *) + push_combine cc n u(E_congruence (n,u)) + end; + (* FIXME: when to actually evaluate? + eval_pending cc; + *) + done; + if not (is_done cc) then ( + update_combine cc (* repeat *) + ) + +(* main CC algo: merge equivalence classes in [st.combine]. + @raise Exn_unsat if merge fails *) +and update_combine cc = + while not (Vec.is_empty cc.combine) do + let a, b, e_ab = Vec.pop_last cc.combine in + let ra = find cc a in + let rb = find cc b in + if not (Equiv_class.equal ra rb) then ( + assert (is_root_ ra); + assert (is_root_ (rb:>node)); + (* We will merge [r_from] into [r_into]. + we try to ensure that [size ra <= size rb] in general, unless + it clashes with the invariant that the representative must + be a normal form if the class contains a normal form *) + let must_solve, r_from, r_into = + match Term.is_semantic ra.n_term, Term.is_semantic rb.n_term with + | true, true -> + if size_ ra > size_ rb then true, rb, ra else true, ra, rb + | false, false -> + if size_ ra > size_ rb then false, rb, ra else false, ra, rb + | true, false -> false, rb, ra (* semantic ==> representative *) + | false, true -> false, ra, rb + in + let new_tags = + Util.Int_map.union + (fun _i e1 e2 -> + (* both maps contain same tag [_i]. conflict clause: + [e1 & e2 & e_ab] impossible *) + Log.debugf 5 (fun k->k "(cc.merge.distinct_conflict@ :tag %d@])" _i); + let lits = explain_unfold cc e1 in + let lits = explain_unfold ~init:lits cc e2 in + let lits = explain_unfold ~init:lits cc e_ab in + raise_conflict cc lits) + ra.n_tags rb.n_tags + in + (* solve the equation, if both [ra] and [rb] are semantic. + The equation is between signatures, so as to canonize w.r.t the + current congruence before solving *) + if must_solve then ( + let t_a = ra.n_term and t_b = rb.n_term in + match signature cc t_a, signature cc t_b with + | Some (Custom t1), Some (Custom t2) -> + begin match t1.tc.tc_t_solve t1.view t2.view with + | Solve_ok {subst=l} -> + Log.debugf 5 + (fun k->k "(@[solve@ (@[= %a %a@])@ :yields (@[%a@])@])" + Term.pp t_a Term.pp t_b + (Util.pp_list @@ Util.pp_pair Equiv_class.pp Term.pp) l); + List.iter (fun (u1,u2) -> push_combine cc u1 (add cc u2) e_ab) l + | Solve_fail {expl} -> + Log.debugf 5 + (fun k->k "(@[solve-fail@ (@[= %a %a@])@ :expl %a@])" + Term.pp t_a Term.pp t_b Explanation.pp expl); + let lits = explain_unfold cc expl in + raise_conflict cc lits + end + | _ -> assert false + ); + (* remove [ra.parents] from signature, put them into [st.pending] *) + begin + Bag.to_seq (r_from:>node).n_parents + |> Sequence.iter + (fun parent -> + (* FIXME: with OCaml's hashtable, we should be able + to keep this entry (and have it become relevant later + once the signature of [parent] is backtracked) *) + remove_signature cc parent.n_term; + push_pending cc parent) + end; + (* perform [union ra rb] *) + begin + let r_from = (r_from :> node) in + let r_into = (r_into :> node) in + let r_into_old_parents = r_into.n_parents in + let r_into_old_tags = r_into.n_tags in + on_backtrack_if_not_lvl_0 cc + (fun () -> + r_from.n_root <- r_from; + r_into.n_tags <- r_into_old_tags; + r_into.n_parents <- r_into_old_parents); + r_from.n_root <- r_into; + r_into.n_tags <- new_tags; + r_from.n_parents <- Bag.append r_into_old_parents r_from.n_parents; + end; + (* update explanations (a -> b), arbitrarily *) + begin + reroot_expl cc a; + assert (a.n_expl = E_none); + on_backtrack_if_not_lvl_0 cc (fun () -> a.n_expl <- E_none); + a.n_expl <- E_some {next=b; expl=e_ab}; + end; + (* notify listeners of the merge *) + notify_merge cc r_from ~into:r_into e_ab; + ) + done; + (* now update pending terms again *) + update_pending cc + +(* Checks if [ra] and [~into] have compatible normal forms and can + be merged w.r.t. the theories. + Side effect: also pushes sub-tasks *) +and notify_merge cc (ra:repr) ~into:(rb:repr) (e:explanation): unit = + assert (is_root_ rb); + cc.acts.on_merge ra rb e + + +(* FIXME: callback? +(* evaluation rules: if, case... *) +and eval_pending (t:term): unit = + List.iter + (fun ((module Theory):repr theory) -> Theory.eval t) + theories + *) + +(* FIXME: remove? +(* main CC algo: add missing terms to the congruence class *) +and update_add (cc:t) terms () = + while not (Queue.is_empty cc.terms_to_add) do + let t = Queue.pop cc.terms_to_add in + add cc t + done +*) + +(* add [t] to [cc] when not present already *) +and add_new_term cc (t:term) : node = + assert (not @@ mem cc t); + let n = Equiv_class.make t in + (* how to add a subterm *) + let add_to_parents_of_sub_node (sub:node) : unit = + let old_parents = sub.n_parents in + on_backtrack_if_not_lvl_0 cc + (fun () -> sub.n_parents <- old_parents); + sub.n_parents <- Bag.cons n sub.n_parents; + push_pending cc sub + in + (* add sub-term to [cc], and register [n] to its parents *) + let add_sub_t (u:term) : unit = + let n_u = add cc u in + add_to_parents_of_sub_node n_u + in + (* register sub-terms, add [t] to their parent list *) + begin match t.term_cell with + | Bool _-> () + | App_cst (_, a) -> IArray.iter add_sub_t a + | If (a,b,c) -> + add_sub_t a; + add_sub_t b; + add_sub_t c + | Case (u, _) -> add_sub_t u + | Custom {view;tc} -> + (* add relevant subterms to the CC *) + tc.tc_t_relevant view add_sub_t + end; + (* remove term when we backtrack *) + on_backtrack_if_not_lvl_0 cc (fun () -> Term.Tbl.remove cc.tbl t); + (* add term to the table *) + Term.Tbl.add cc.tbl t n; + (* [n] might be merged with other equiv classes *) + push_pending cc n; + n + +(* TODO? *) +(* add [t=u] to the congruence closure, unconditionally (reduction relation) *) +and[@inline] add_eqn (cc:t) (eqn:merge_op): unit = + let t, u, expl = eqn in + push_combine cc t u expl + +(* add a term *) +and[@inline] add cc t = + try Term.Tbl.find cc.tbl t + with Not_found -> add_new_term cc t + +let[@inline] add_seq cc seq = seq (fun t -> ignore @@ add cc t) + +(* assert that this boolean literal holds *) +let assert_lit cc lit : unit = match Lit.view lit with + | Lit_fresh _ + | Lit_expanded _ -> () + | Lit_atom t -> + assert (Ty.is_prop t.term_ty); + let sign = Lit.sign lit in + (* equate t and true/false *) + let rhs = if sign then true_ cc else false_ cc in + let n = add cc t in + push_combine cc n rhs (E_lit lit); + () + +let assert_eq cc (t:term) (u:term) expl : unit = + let n1 = add cc t in + let n2 = add cc u in + if not (same_class cc n1 n2) then ( + union cc n1 n2 expl + ) + +let assert_distinct cc (l:term list) ~neq expl : unit = + assert (match l with[] | [_] -> false | _ -> true); + let tag = Term.id neq in + Log.debugf 5 + (fun k->k "(@[cc.assert_distinct@ (@[%a@])@ :tag %d@])" (Util.pp_list Term.pp) l tag); + let l = List.map (fun t -> t, add cc t |> find cc) l in + let coll = + Sequence.diagonal_l l + |> Sequence.find_pred (fun ((_,n1),(_,n2)) -> Equiv_class.equal n1 n2) + in + begin match coll with + | Some ((t1,_n1),(t2,_n2)) -> + (* two classes are already equal *) + Log.debugf 5 (fun k->k "(@[cc.assert_distinct.conflict@ %a = %a@])" Term.pp t1 Term.pp t2); + let lits = explain_unfold cc expl in + let lits = explain_eq_t ~init:lits cc t1 t2 in + raise_conflict cc lits + | None -> + (* put a tag on all equivalence classes, that will make their merge fail *) + List.iter (fun (_,n) -> add_tag_n cc n tag expl) l + end + +(* handling "distinct" constraints *) +module Distinct_ = struct + module Int_set = Util.Int_set + + type Equiv_class.payload += + | P_dist of { + mutable tags: Int_set.t; + } + + let get (n:Equiv_class.t) : Int_set.t = + Equiv_class.payload_find + ~f:(function + | P_dist {tags} -> Some tags + | _ -> None) + n + |> CCOpt.get_or ~default:Int_set.empty + + let add_tag (tag:int) (n:Equiv_class.t) : unit = + if not @@ + CCList.exists + (function + | P_dist r -> r.tags <- Int_set.add tag r.tags; true + | _ -> false) + (Equiv_class.payload n) + then ( + Equiv_class.set_payload n (P_dist {tags=Int_set.singleton tag}) + ) +end + +let create ?(size=2048) ~actions (tst:Term.state) : t = + assert (actions.at_lvl_0 ()); + let nd = Equiv_class.dummy in + let rec cc = { + tst; + acts=actions; + tbl = Term.Tbl.create size; + signatures_tbl = Sig_tbl.create size; + pending=Vec.make_empty Equiv_class.dummy; + combine= Vec.make_empty (nd,nd,E_reduce_eq(nd,nd)); + ps_lits=Lit.Set.empty; + ps_queue=Vec.make_empty (nd,nd); + true_ = lazy (add cc (Term.true_ tst)); + false_ = lazy (add cc (Term.false_ tst)); + } in + ignore (Lazy.force cc.true_); + ignore (Lazy.force cc.false_); + cc (* check satisfiability, update congruence closure *) let check (cc:t) : unit = Log.debug 5 "(cc.check)"; diff --git a/src/smt/Congruence_closure.mli b/src/smt/Congruence_closure.mli index 989b9dd2..db70adb5 100644 --- a/src/smt/Congruence_closure.mli +++ b/src/smt/Congruence_closure.mli @@ -21,9 +21,10 @@ type actions = { on_merge:repr -> repr -> explanation -> unit; (** Call this when two classes are merged *) - raise_conflict: 'a. Explanation.t Bag.t -> 'a; + raise_conflict: 'a. Lit.Set.t -> 'a; (** Report a conflict *) + (* FIXME: take a delayed Lit.Set.t? *) propagate: Lit.t -> Explanation.t Bag.t -> unit; (** Propagate a literal *) } @@ -65,14 +66,23 @@ val assert_lit : t -> Lit.t -> unit val assert_eq : t -> term -> term -> explanation -> unit -val assert_distinct : t -> term list -> explanation -> unit +val assert_distinct : t -> term list -> neq:term -> explanation -> unit +(** [assert_distinct l ~expl:u e] asserts all elements of [l] are distinct + with explanation [e] + precond: [u = distinct l] *) val check : t -> unit val final_check : t -> unit -val explain_unfold_bag : t -> explanation Bag.t -> Lit.Set.t +val explain_eq_n : ?init:Lit.Set.t -> t -> node -> node -> Lit.Set.t +(** explain why the two nodes are equal *) -val explain_unfold_seq : t -> explanation Sequence.t -> Lit.Set.t +val explain_eq_t : ?init:Lit.Set.t -> t -> term -> term -> Lit.Set.t +(** explain why the two terms are equal *) + +val explain_unfold_bag : ?init:Lit.Set.t -> t -> explanation Bag.t -> Lit.Set.t + +val explain_unfold_seq : ?init:Lit.Set.t -> t -> explanation Sequence.t -> Lit.Set.t (** Unfold those explanations into a complete set of literals implying them *) diff --git a/src/smt/Equiv_class.ml b/src/smt/Equiv_class.ml index 19e63f59..33488346 100644 --- a/src/smt/Equiv_class.ml +++ b/src/smt/Equiv_class.ml @@ -2,7 +2,7 @@ open Solver_types type t = cc_node -type payload = cc_node_payload +type payload = cc_node_payload = .. let field_expanded = Node_bits.mk_field () let field_has_expansion_lit = Node_bits.mk_field () @@ -26,6 +26,7 @@ let make (t:term) : t = n_root=n; n_expl=E_none; n_payload=[]; + n_tags=Util.Int_map.empty; } in n diff --git a/src/smt/Equiv_class.mli b/src/smt/Equiv_class.mli index fe8e2513..482574a7 100644 --- a/src/smt/Equiv_class.mli +++ b/src/smt/Equiv_class.mli @@ -21,7 +21,7 @@ open Solver_types *) type t = cc_node -type payload = cc_node_payload +type payload = cc_node_payload = .. val field_expanded : Node_bits.field (** Term is expanded? *) diff --git a/src/smt/Solver.ml b/src/smt/Solver.ml index ff194d06..3c769653 100644 --- a/src/smt/Solver.ml +++ b/src/smt/Solver.ml @@ -395,9 +395,8 @@ let assume (self:t) (c:Lit.t IArray.t) : unit = let[@inline] assume_eq self t u expl : unit = Congruence_closure.assert_eq (cc self) t u (E_lit expl) -let[@inline] assume_distinct self l expl : unit = - (* FIXME: custom evaluation instead (register to subterms) *) - Congruence_closure.assert_distinct (cc self) l (E_lit expl) +let[@inline] assume_distinct self l ~neq expl : unit = + Congruence_closure.assert_distinct (cc self) l (E_lit expl) ~neq (* type unsat_core = Sat.clause list diff --git a/src/smt/Solver.mli b/src/smt/Solver.mli index ce7fcc75..652abed9 100644 --- a/src/smt/Solver.mli +++ b/src/smt/Solver.mli @@ -51,7 +51,7 @@ val tst : t -> Term.state val assume : t -> Lit.t IArray.t -> unit val assume_eq : t -> Term.t -> Term.t -> Lit.t -> unit -val assume_distinct : t -> Term.t list -> Lit.t -> unit +val assume_distinct : t -> Term.t list -> neq:Term.t -> Lit.t -> unit val solve : ?on_exit:(unit -> unit) list -> diff --git a/src/smt/Solver_types.ml b/src/smt/Solver_types.ml index 9e77b7c4..a5feffc7 100644 --- a/src/smt/Solver_types.ml +++ b/src/smt/Solver_types.ml @@ -88,6 +88,7 @@ and cc_node = { mutable n_root: cc_node; (* representative of congruence class (itself if a representative) *) mutable n_expl: explanation_forest_link; (* the rooted forest for explanations *) mutable n_payload: cc_node_payload list; (* list of theory payloads *) + mutable n_tags: explanation Util.Int_map.t; (* "distinct" tags (i.e. set of `(distinct t1…tn)` terms this belongs to *) } (** Theory-extensible payloads *) diff --git a/src/smt/Theory.ml b/src/smt/Theory.ml index a7d696ff..d44c1b68 100644 --- a/src/smt/Theory.ml +++ b/src/smt/Theory.ml @@ -30,9 +30,8 @@ type state = State : { } -> state (** Unsatisfiable conjunction. - Will be turned into a set of literals, whose negation becomes a - conflict clause *) -type conflict = Explanation.t Bag.t + Its negation will become a conflict clause *) +type conflict = Lit.Set.t (** Actions available to a theory during its lifetime *) type actions = { @@ -48,6 +47,9 @@ type actions = { propagate_eq: Term.t -> Term.t -> Explanation.t -> unit; (** Propagate an equality [t = u] because [e] *) + propagate_distinct: Term.t list -> neq:Term.t -> Explanation.t -> unit; + (** Propagate a [distinct l] because [e] (where [e = atom neq] *) + propagate: Lit.t -> Explanation.t Bag.t -> unit; (** Propagate a boolean using a unit clause. [expl => lit] must be a theory lemma, that is, a T-tautology *) diff --git a/src/smt/Theory_combine.ml b/src/smt/Theory_combine.ml index 1a5acd1e..16191c86 100644 --- a/src/smt/Theory_combine.ml +++ b/src/smt/Theory_combine.ml @@ -16,7 +16,7 @@ module Form = Lit type formula = Lit.t type proof = Proof.t -type conflict = Explanation.t Bag.t +type conflict = Lit.Set.t (* raise upon conflict *) exception Exn_conflict of conflict @@ -62,10 +62,7 @@ let cdcl_return_res (self:t) : _ Sat_solver.res = begin match self.conflict with | None -> Sat_solver.Sat - | Some c -> - let lit_set = - Congruence_closure.explain_unfold_bag (cc self) c - in + | Some lit_set -> let conflict_clause = Lit.Set.to_list lit_set |> IArray.of_list_map Lit.neg @@ -183,6 +180,9 @@ let act_all_classes self = Congruence_closure.all_classes (cc self) let act_propagate_eq self t u guard = Congruence_closure.assert_eq (cc self) t u guard +let act_propagate_distinct self l ~neq guard = + Congruence_closure.assert_distinct (cc self) l ~neq guard + let act_find self t = Congruence_closure.add (cc self) t |> Congruence_closure.find (cc self) @@ -208,6 +208,7 @@ let mk_theory_actions (self:t) : Theory.actions = propagate = act_propagate self; all_classes = act_all_classes self; propagate_eq = act_propagate_eq self; + propagate_distinct = act_propagate_distinct self; add_local_axiom = act_add_local_axiom self; add_persistent_axiom = act_add_persistent_axiom self; find = act_find self; diff --git a/src/smt/th_bool/Dagon_th_bool.ml b/src/smt/th_bool/Dagon_th_bool.ml index d5ec2caf..7af479d2 100644 --- a/src/smt/th_bool/Dagon_th_bool.ml +++ b/src/smt/th_bool/Dagon_th_bool.ml @@ -249,14 +249,23 @@ type t = { acts: Theory.actions; } -let tseitin (self:t) (lit:Lit.t) (b:term builtin) : unit = +let tseitin (self:t) (lit:Lit.t) (lit_t:term) (b:term builtin) : unit = Log.debugf 5 (fun k->k "(@[th_bool.tseitin@ %a@])" Lit.pp lit); match b with | B_not _ -> assert false (* normalized *) | B_eq (t,u) -> - self.acts.Theory.propagate_eq t u (Explanation.lit lit) - | B_distinct _ -> - assert false (* TODO: go to CC, or custom ineq? *) + if Lit.sign lit then ( + self.acts.Theory.propagate_eq t u (Explanation.lit lit) + ) else ( + self.acts.Theory.propagate_distinct [t;u] ~neq:lit_t (Explanation.lit lit) + ) + | B_distinct l -> + if Lit.sign lit then ( + self.acts.Theory.propagate_distinct l ~neq:lit_t (Explanation.lit lit) + ) else ( + (* TODO: propagate pairwise equalities? *) + Util.errorf "cannot process negative distinct lit %a" Lit.pp lit; + ) | B_and subs -> if Lit.sign lit then ( (* propagate [lit => subs_i] *) @@ -304,8 +313,8 @@ let tseitin (self:t) (lit:Lit.t) (b:term builtin) : unit = let on_assert (self:t) (lit:Lit.t) = match Lit.view lit with - | Lit.Lit_atom { Term.term_cell=Term.Custom{view=Builtin {view=b};_}; _ } -> - tseitin self lit b + | Lit.Lit_atom ({ Term.term_cell=Term.Custom{view=Builtin {view=b};_}; _ } as t) -> + tseitin self lit t b | _ -> () let final_check _ _ : unit = ()