diff --git a/src/cc/CC_types.ml b/src/cc/CC_types.ml new file mode 100644 index 00000000..d5e098f0 --- /dev/null +++ b/src/cc/CC_types.ml @@ -0,0 +1,112 @@ + +(** {1 Types used by the congruence closure} *) + +type ('f, 't, 'ts) view = + | Bool of bool + | App_fun of 'f * 'ts + | App_ho of 't * 'ts + | If of 't * 't * 't + | Eq of 't * 't + | Opaque of 't (* do not enter *) + +let[@inline] map_view ~f_f ~f_t ~f_ts (v:_ view) : _ view = + match v with + | Bool b -> Bool b + | App_fun (f, args) -> App_fun (f_f f, f_ts args) + | App_ho (f, args) -> App_ho (f_t f, f_ts args) + | If (a,b,c) -> If (f_t a, f_t b, f_t c) + | Eq (a,b) -> Eq (f_t a, f_t b) + | Opaque t -> Opaque (f_t t) + +let iter_view ~f_f ~f_t ~f_ts (v:_ view) : unit = + match v with + | Bool _ -> () + | App_fun (f, args) -> f_f f; f_ts args + | App_ho (f, args) -> f_t f; f_ts args + | If (a,b,c) -> f_t a; f_t b; f_t c; + | Eq (a,b) -> f_t a; f_t b + | Opaque t -> f_t t + +module type TERM = sig + module Fun : sig + type t + val equal : t -> t -> bool + val hash : t -> int + val pp : t Fmt.printer + end + + module Term : sig + type t + val equal : t -> t -> bool + val hash : t -> int + val pp : t Fmt.printer + + type state + + val bool : state -> bool -> t + + (** View the term through the lens of the congruence closure *) + val cc_view : t -> (Fun.t, t, t Sequence.t) view + end +end + +module type TERM_LIT = sig + include TERM + + module Lit : sig + type t + val neg : t -> t + val equal : t -> t -> bool + val hash : t -> int + val pp : t Fmt.printer + + val sign : t -> bool + val term : t -> Term.t + end +end + +module type FULL = sig + include TERM_LIT + + module Proof : sig + type t + val pp : t Fmt.printer + + val default : t + (* TODO: to give more details + val cc_lemma : unit -> t + *) + end + + module Ty : sig + type t + + val equal : t -> t -> bool + val hash : t -> int + val pp : t Fmt.printer + end + + module Value : sig + type t + + val pp : t Fmt.printer + + val fresh : Term.t -> t + + val true_ : t + val false_ : t + end + + module Model : sig + type t + + val pp : t Fmt.printer + + val eval : t -> Term.t -> Value.t option + (** Evaluate the term in the current model *) + + val add : Term.t -> Value.t -> t -> t + end +end + +(* TODO: micro theory *) diff --git a/src/cc/Congruence_closure.ml b/src/cc/Congruence_closure.ml new file mode 100644 index 00000000..b71a728b --- /dev/null +++ b/src/cc/Congruence_closure.ml @@ -0,0 +1,939 @@ + +open CC_types + +module type ARG = Congruence_closure_intf.ARG +module type S = Congruence_closure_intf.S + +module Bits = CCBitField.Make() + +let field_is_pending = Bits.mk_field() +(** true iff the node is in the [cc.pending] queue *) + +let () = Bits.freeze() + +type payload = Congruence_closure_intf.payload = .. + +module Make(A: ARG) = struct + type term = A.Term.t + type term_state = A.Term.state + type lit = A.Lit.t + type fun_ = A.Fun.t + type proof = A.Proof.t + type value = A.Value.t + type model = A.Model.t + + (** Actions available to the theory *) + type sat_actions = (Msat.void, lit, Msat.void, proof) Msat.acts + + module T = A.Term + module Fun = A.Fun + + (** A node of the congruence closure. + An equivalence class is represented by its "root" element, + the representative. *) + type node = { + n_term: term; + mutable n_sig0: signature option; (* initial signature *) + mutable n_bits: Bits.t; (* bitfield for various properties *) + mutable n_parents: node Bag.t; (* parent terms of this node *) + mutable n_root: node; (* representative of congruence class (itself if a representative) *) + mutable n_next: node; (* pointer to next element of congruence class *) + mutable n_size: int; (* size of the class *) + mutable n_as_lit: lit option; + mutable n_expl: explanation_forest_link; (* the rooted forest for explanations *) + mutable n_payload: payload list; (* list of theory payloads *) + (* TODO: make a micro theory and move this inside *) + mutable n_tags: (node * explanation) Util.Int_map.t; + (* "distinct" tags (i.e. set of `(distinct t1…tn)` terms this belongs to *) + } + + and signature = (fun_, node, node list) view + + and explanation_forest_link = + | FL_none + | FL_some of { + next: node; + expl: explanation; + } + + (* atomic explanation in the congruence closure *) + and explanation = + | E_reduction (* by pure reduction, tautologically equal *) + | E_merges of (node * node) list (* caused by these merges *) + | E_lit of lit (* because of this literal *) + | E_lits of lit list (* because of this (true) conjunction *) + (* TODO: congruence case (cheaper than "merges") *) + + type repr = node + type conflict = lit list + + module N = struct + type t = node + + let[@inline] equal (n1:t) n2 = T.equal n1.n_term n2.n_term + let[@inline] hash n = T.hash n.n_term + let[@inline] term n = n.n_term + let[@inline] payload n = n.n_payload + let[@inline] pp out n = T.pp out n.n_term + let[@inline] as_lit n = n.n_as_lit + + let make (t:term) : t = + let rec n = { + n_term=t; + n_sig0= None; + n_bits=Bits.empty; + n_parents=Bag.empty; + n_as_lit=None; (* TODO: provide a method to do it *) + n_root=n; + n_expl=FL_none; + n_payload=[]; + n_next=n; + n_size=1; + n_tags=Util.Int_map.empty; + } in + n + + type nonrec payload = payload = .. + + let set_payload ?(can_erase=fun _->false) n e = + let rec aux = function + | [] -> [e] + | e' :: tail when can_erase e' -> e :: tail + | e' :: tail -> e' :: aux tail + in + n.n_payload <- aux n.n_payload + + let payload_find ~f:p n = + let[@unroll 2] rec aux = function + | [] -> None + | e1 :: tail -> + match p e1 with + | Some _ as res -> res + | None -> aux tail + in + aux n.n_payload + + let payload_pred ~f:p n = + begin match n.n_payload with + | [] -> false + | e :: _ when p e -> true + | _ :: e :: _ when p e -> true + | _ :: _ :: e :: _ when p e -> true + | l -> List.exists p l + end + + let[@inline] get_field f t = Bits.get f t.n_bits + let[@inline] set_field f b t = t.n_bits <- Bits.set f b t.n_bits + end + + module N_tbl = CCHashtbl.Make(N) + + module Expl = struct + type t = explanation + + let equal (a:t) b = + match a, b with + | E_merges l1, E_merges l2 -> + CCList.equal (CCPair.equal N.equal N.equal) l1 l2 + | E_reduction, E_reduction -> true + | E_lit l1, E_lit l2 -> A.Lit.equal l1 l2 + | E_lits l1, E_lits l2 -> CCList.equal A.Lit.equal l1 l2 + | E_merges _, _ | E_lit _, _ | E_lits _, _ | E_reduction, _ + -> false + + let hash (a:t) : int = + let module H = CCHash in + match a with + | E_lit lit -> H.combine2 10 (A.Lit.hash lit) + | E_lits l -> + H.combine2 20 (H.list A.Lit.hash l) + | E_merges l -> + H.combine2 30 (H.list (H.pair N.hash N.hash) l) + | E_reduction -> H.int 40 + + let pp out (e:explanation) = match e with + | E_reduction -> Fmt.string out "reduction" + | E_lit lit -> A.Lit.pp out lit + | E_lits l -> CCFormat.Dump.list A.Lit.pp out l + | E_merges l -> + Format.fprintf out "(@[merges@ %a@])" + Fmt.(seq ~sep:(return "@ ") @@ within "[" "]" @@ hvbox @@ + pair ~sep:(return " ~@ ") N.pp N.pp) + (Sequence.of_list l) + + let[@inline] mk_merges l : t = E_merges l + let[@inline] mk_lit l : t = E_lit l + let[@inline] mk_lits = function [x] -> mk_lit x | l -> E_lits l + let mk_reduction : t = E_reduction + end + + (** A signature is a shallow term shape where immediate subterms + are representative *) + module Signature = struct + type t = signature + let equal (s1:t) s2 : bool = + match s1, s2 with + | Bool b1, Bool b2 -> b1=b2 + | App_fun (f1,[]), App_fun (f2,[]) -> Fun.equal f1 f2 + | App_fun (f1,l1), App_fun (f2,l2) -> + Fun.equal f1 f2 && CCList.equal N.equal l1 l2 + | App_ho (f1,l1), App_ho (f2,l2) -> + N.equal f1 f2 && CCList.equal N.equal l1 l2 + | If (a1,b1,c1), If (a2,b2,c2) -> + N.equal a1 a2 && N.equal b1 b2 && N.equal c1 c2 + | Eq (a1,b1), Eq (a2,b2) -> + N.equal a1 a2 && N.equal b1 b2 + | Opaque u1, Opaque u2 -> N.equal u1 u2 + | Bool _, _ | App_fun _, _ | App_ho _, _ | If _, _ + | Eq _, _ | Opaque _, _ + -> false + + let hash (s:t) : int = + let module H = CCHash in + match s with + | Bool b -> H.combine2 10 (H.bool b) + | App_fun (f, l) -> H.combine3 20 (Fun.hash f) (H.list N.hash l) + | App_ho (f, l) -> H.combine3 30 (N.hash f) (H.list N.hash l) + | Eq (a,b) -> H.combine3 40 (N.hash a) (N.hash b) + | Opaque u -> H.combine2 50 (N.hash u) + | If (a,b,c) -> H.combine4 60 (N.hash a)(N.hash b)(N.hash c) + + let pp out = function + | Bool b -> Fmt.bool out b + | App_fun (f, []) -> Fun.pp out f + | App_fun (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" Fun.pp f (Util.pp_list N.pp) l + | App_ho (f, []) -> N.pp out f + | App_ho (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" N.pp f (Util.pp_list N.pp) l + | Opaque t -> N.pp out t + | Eq (a,b) -> Fmt.fprintf out "(@[=@ %a@ %a@])" N.pp a N.pp b + | If (a,b,c) -> Fmt.fprintf out "(@[ite@ %a@ %a@ %a@])" N.pp a N.pp b N.pp c + end + + module Sig_tbl = CCHashtbl.Make(Signature) + module T_tbl = CCHashtbl.Make(T) + + type combine_task = + | CT_merge of node * node * explanation + | CT_distinct of node list * int * explanation + + type t = { + tst: term_state; + tbl: node T_tbl.t; + (* internalization [term -> node] *) + signatures_tbl : node Sig_tbl.t; + (* map a signature to the corresponding node in some equivalence class. + A signature is a [term_cell] in which every immediate subterm + that participates in the congruence/evaluation relation + is normalized (i.e. is its own representative). + The critical property is that all members of an equivalence class + that have the same "shape" (including head symbol) + have the same signature *) + pending: node Vec.t; + combine: combine_task Vec.t; + undo: (unit -> unit) Backtrack_stack.t; + on_merge: (repr -> repr -> explanation -> unit) option; + mutable ps_lits: lit list; (* TODO: thread it around instead? *) + (* proof state *) + ps_queue: (node*node) Vec.t; + (* pairs to explain *) + true_ : node lazy_t; + false_ : node lazy_t; + } + (* TODO: an additional union-find to keep track, for each term, + of the terms they are known to be equal to, according + to the current explanation. That allows not to prove some equality + several times. + See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *) + + let[@inline] is_root_ (n:node) : bool = n.n_root == n + let[@inline] size_ (r:repr) = r.n_size + let[@inline] true_ cc = Lazy.force cc.true_ + let[@inline] false_ cc = Lazy.force cc.false_ + + let[@inline] on_backtrack cc f : unit = + Backtrack_stack.push_if_nonzero_level cc.undo f + + (* check if [t] is in the congruence closure. + Invariant: [in_cc t ∧ do_cc t => forall u subterm t, in_cc u] *) + let[@inline] mem (cc:t) (t:term): bool = T_tbl.mem cc.tbl t + + (* find representative, recursively *) + let[@unroll 1] rec find_rec (n:node) : repr = + if n==n.n_root then ( + n + ) else ( + (* TODO: path compression, assuming backtracking restores equiv classes + properly *) + let root = find_rec n.n_root in + root + ) + + (* traverse the equivalence class of [n] *) + let iter_class_ (n:node) : node Sequence.t = + fun yield -> + let rec aux u = + yield u; + if u.n_next != n then aux u.n_next + in + aux n + + (* non-recursive, inlinable function for [find] *) + let[@inline] find_ (n:node) : repr = + if n == n.n_root then n else find_rec n.n_root + + let[@inline] same_class (n1:node)(n2:node): bool = + N.equal (find_ n1) (find_ n2) + + let[@inline] find _ n = find_ n + + (* print full state *) + let pp_full out (cc:t) : unit = + let pp_next out n = + Fmt.fprintf out "@ :next %a" N.pp n.n_next in + let pp_root out n = + if is_root_ n then Fmt.string out " :is-root" else Fmt.fprintf out "@ :root %a" N.pp n.n_root in + let pp_expl out n = match n.n_expl with + | FL_none -> () + | FL_some e -> + Fmt.fprintf out " (@[:forest %a :expl %a@])" N.pp e.next Expl.pp e.expl + in + let pp_n out n = + Fmt.fprintf out "(@[%a%a%a%a@])" T.pp n.n_term pp_root n pp_next n pp_expl n + and pp_sig_e out (s,n) = + Fmt.fprintf out "(@[<1>%a@ ~~> %a%a@])" Signature.pp s N.pp n pp_root n + in + Fmt.fprintf out + "(@[@{cc.state@}@ (@[:nodes@ %a@])@ (@[:sig-tbl@ %a@])@])" + (Util.pp_seq ~sep:" " pp_n) (T_tbl.values cc.tbl) + (Util.pp_seq ~sep:" " pp_sig_e) (Sig_tbl.to_seq cc.signatures_tbl) + + (* compute up-to-date signature *) + let update_sig (s:signature) : Signature.t = + CC_types.map_view s + ~f_f:(fun x->x) + ~f_t:find_ + ~f_ts:(List.map find_) + + (* find whether the given (parent) term corresponds to some signature + in [signatures_] *) + let[@inline] find_signature cc (s:signature) : repr option = + Sig_tbl.get cc.signatures_tbl s + + let add_signature cc (s:signature) (n:node) : unit = + (* add, but only if not present already *) + match Sig_tbl.find cc.signatures_tbl s with + | exception Not_found -> + Log.debugf 15 + (fun k->k "(@[cc.add-sig@ %a@ ~~> %a@])" Signature.pp s N.pp n); + on_backtrack cc (fun () -> Sig_tbl.remove cc.signatures_tbl s); + Sig_tbl.add cc.signatures_tbl s n; + | r' -> + assert (same_class n r'); + () + + let push_pending cc t : unit = + if not @@ N.get_field field_is_pending t then ( + Log.debugf 5 (fun k->k "(@[cc.push_pending@ %a@])" N.pp t); + N.set_field field_is_pending true t; + Vec.push cc.pending t + ) + + let push_combine cc t u e : unit = + Log.debugf 5 + (fun k->k "(@[cc.push_combine@ %a ~@ %a@ :expl %a@])" + N.pp t N.pp u Expl.pp e); + Vec.push cc.combine @@ CT_merge (t,u,e) + + (* re-root the explanation tree of the equivalence class of [n] + so that it points to [n]. + postcondition: [n.n_expl = None] *) + let rec reroot_expl (cc:t) (n:node): unit = + let old_expl = n.n_expl in + begin match old_expl with + | FL_none -> () (* already root *) + | FL_some {next=u; expl=e_n_u} -> + reroot_expl cc u; + u.n_expl <- FL_some {next=n; expl=e_n_u}; + n.n_expl <- FL_none; + end + + let raise_conflict (cc:t) (acts:sat_actions) (e:conflict): _ = + (* clear tasks queue *) + Vec.iter (N.set_field field_is_pending false) cc.pending; + Vec.clear cc.pending; + Vec.clear cc.combine; + let c = List.rev_map A.Lit.neg e in + acts.Msat.acts_raise_conflict c A.Proof.default + + let[@inline] all_classes cc : repr Sequence.t = + T_tbl.values cc.tbl + |> Sequence.filter is_root_ + + (* TODO: use markers and lockstep iteration instead *) + (* distance from [t] to its root in the proof forest *) + let[@inline][@unroll 2] rec distance_to_root (n:node): int = match n.n_expl with + | FL_none -> 0 + | FL_some {next=t'; _} -> 1 + distance_to_root t' + + (* TODO: bool flag on nodes + stepwise progress + cleanup *) + (* find the closest common ancestor of [a] and [b] in the proof forest *) + let find_common_ancestor (a:node) (b:node) : node = + let d_a = distance_to_root a in + let d_b = distance_to_root b in + (* drop [n] nodes in the path from [t] to its root *) + let rec drop_ n t = + if n=0 then t + else match t.n_expl with + | FL_none -> assert false + | FL_some {next=t'; _} -> drop_ (n-1) t' + in + (* reduce to the problem where [a] and [b] have the same distance to root *) + let a, b = + if d_a > d_b then drop_ (d_a-d_b) a, b + else if d_a < d_b then a, drop_ (d_b-d_a) b + else a, b + in + (* traverse stepwise until a==b *) + let rec aux_same_dist a b = + if a==b then a + else match a.n_expl, b.n_expl with + | FL_none, _ | _, FL_none -> assert false + | FL_some {next=a'; _}, FL_some {next=b'; _} -> aux_same_dist a' b' + in + aux_same_dist a b + + let[@inline] ps_add_obligation (cc:t) a b = Vec.push cc.ps_queue (a,b) + let[@inline] ps_add_lit ps l = ps.ps_lits <- l :: ps.ps_lits + + (* TODO: remove *) + let ps_clear (cc:t) = + cc.ps_lits <- []; + Vec.clear cc.ps_queue; + () + + let decompose_explain cc (e:explanation): unit = + Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Expl.pp e); + begin match e with + | E_reduction -> () + | E_lit lit -> ps_add_lit cc lit + | E_lits l -> List.iter (ps_add_lit cc) l + | E_merges l -> + (* need to explain each merge in [l] *) + List.iter (fun (t,u) -> ps_add_obligation cc t u) l + end + + (* explain why [a = parent_a], where [a -> ... -> parent_a] in the + proof forest *) + let rec explain_along_path ps (a:node) (parent_a:node) : unit = + if a!=parent_a then ( + match a.n_expl with + | FL_none -> assert false + | FL_some {next=next_a; expl=e_a_b} -> + decompose_explain ps e_a_b; + (* now prove [next_a = parent_a] *) + explain_along_path ps next_a parent_a + ) + + (* find explanation *) + let explain_loop (cc : t) : lit list = + while not (Vec.is_empty cc.ps_queue) do + let a, b = Vec.pop cc.ps_queue in + Log.debugf 5 + (fun k->k "(@[cc.explain_loop.at@ %a@ =?= %a@])" N.pp a N.pp b); + assert (N.equal (find_ a) (find_ b)); + let c = find_common_ancestor a b in + explain_along_path cc a c; + explain_along_path cc b c; + done; + cc.ps_lits + + (* TODO: do not use ps_lits anymore *) + let explain_eq_n ?(init=[]) cc (n1:node) (n2:node) : lit list = + ps_clear cc; + cc.ps_lits <- init; + ps_add_obligation cc n1 n2; + explain_loop cc + + let explain_unfold ?(init=[]) cc (e:explanation) : lit list = + ps_clear cc; + cc.ps_lits <- init; + decompose_explain cc e; + explain_loop cc + + (* add [tag] to [n], indicating that [n] is distinct from all the other + nodes tagged with [tag] + precond: [n] is a representative *) + let add_tag_n cc (n:node) (tag:int) (expl:explanation) : unit = + assert (is_root_ n); + if not (Util.Int_map.mem tag n.n_tags) then ( + on_backtrack cc + (fun () -> n.n_tags <- Util.Int_map.remove tag n.n_tags); + n.n_tags <- Util.Int_map.add tag (n,expl) n.n_tags; + ) + + (* add a term *) + let [@inline] rec add_term_rec_ cc t : node = + try T_tbl.find cc.tbl t + with Not_found -> add_new_term_ cc t + + (* add [t] to [cc] when not present already *) + and add_new_term_ cc (t:term) : node = + assert (not @@ mem cc t); + Log.debugf 15 (fun k->k "(@[cc.add-term@ %a@])" T.pp t); + let n = N.make t in + (* register sub-terms, add [t] to their parent list, and return the + corresponding initial signature *) + let sig0 = compute_sig0 cc n in + n.n_sig0 <- sig0; + (* remove term when we backtrack *) + on_backtrack cc + (fun () -> + Log.debugf 15 (fun k->k "(@[cc.remove-term@ %a@])" T.pp t); + T_tbl.remove cc.tbl t); + (* add term to the table *) + T_tbl.add cc.tbl t n; + if CCOpt.is_some sig0 then ( + (* [n] might be merged with other equiv classes *) + push_pending cc n; + ); + n + + (* compute the initial signature of the given node *) + and compute_sig0 (self:t) (n:node) : Signature.t option = + (* add sub-term to [cc], and register [n] to its parents *) + let deref_sub (u:term) : node = + let sub = add_term_rec_ self u in + (* add [n] to [sub.root]'s parent list *) + begin + let sub = find_ sub in + let old_parents = sub.n_parents in + on_backtrack self (fun () -> sub.n_parents <- old_parents); + sub.n_parents <- Bag.cons n sub.n_parents; + end; + sub + in + let[@inline] return x = Some x in + match T.cc_view n.n_term with + | Bool _ | Opaque _ -> None + | Eq (a,b) -> + let a = deref_sub a in + let b = deref_sub b in + return @@ Eq (a,b) + | App_fun (f, args) -> + let args = args |> Sequence.map deref_sub |> Sequence.to_list in + if args<>[] then ( + return @@ App_fun (f, args) + ) else None + | App_ho (f, args) -> + let args = args |> Sequence.map deref_sub |> Sequence.to_list in + return @@ App_ho (deref_sub f, args) + | If (a,b,c) -> + return @@ If (deref_sub a, deref_sub b, deref_sub c) + + let[@inline] add_term cc t : node = add_term_rec_ cc t + let[@inline] add_term' cc t : unit = ignore (add_term_rec_ cc t : node) + + let set_as_lit cc (n:node) (lit:lit) : unit = + match n.n_as_lit with + | Some _ -> () + | None -> + Log.debugf 15 (fun k->k "(@[cc.set-as-lit@ %a@ %a@])" N.pp n A.Lit.pp lit); + on_backtrack cc (fun () -> n.n_as_lit <- None); + n.n_as_lit <- Some lit + + (* Checks if [ra] and [~into] have compatible normal forms and can + be merged w.r.t. the theories. + Side effect: also pushes sub-tasks *) + let notify_merge cc (ra:repr) ~into:(rb:repr) (e:explanation): unit = + assert (is_root_ rb); + match cc.on_merge with + | Some f -> f ra rb e + | None -> () + + let[@inline] n_is_bool (self:t) n : bool = + N.equal n (true_ self) || N.equal n (false_ self) + + (* main CC algo: add terms from [pending] to the signature table, + check for collisions *) + let rec update_tasks (cc:t) (acts:sat_actions) : unit = + while not (Vec.is_empty cc.pending && Vec.is_empty cc.combine) do + while not @@ Vec.is_empty cc.pending do + task_pending_ cc (Vec.pop cc.pending); + done; + while not @@ Vec.is_empty cc.combine do + task_combine_ cc acts (Vec.pop cc.combine); + done; + done + + and task_pending_ cc (n:node) : unit = + N.set_field field_is_pending false n; + (* check if some parent collided *) + begin match n.n_sig0 with + | None -> () (* no-op *) + | Some (Eq (a,b)) -> + (* if [a=b] is now true, merge [(a=b)] and [true] *) + if same_class a b then ( + let expl = Expl.mk_merges [(a,b)] in + push_combine cc n (true_ cc) expl + ) + | Some s0 -> + (* update the signature by using [find] on each sub-node *) + let s = update_sig s0 in + match find_signature cc s with + | None -> + (* add to the signature table [sig(n) --> n] *) + add_signature cc s n + | Some u when n == u -> () + | Some u -> + (* [t1] and [t2] must be applications of the same symbol to + arguments that are pairwise equal *) + assert (n != u); + let expl = + match n.n_sig0, u.n_sig0 with + | Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) -> + assert (Fun.equal f1 f2); + assert (List.length a1 = List.length a2); + (* TODO: just use "congruence" as explanation *) + Expl.mk_merges @@ List.combine a1 a2 + | Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) -> + assert (List.length a1 = List.length a2); + (* TODO: just use "congruence" as explanation *) + Expl.mk_merges @@ (f1,f2)::List.combine a1 a2 + | Some (If (a1,b1,c1)), Some (If (a2,b2,c2)) -> + Expl.mk_merges @@ [a1,a2; b1,b2; c1,c2] + | _ + -> assert false + in + push_combine cc n u expl + (* FIXME: when to actually evaluate? + eval_pending cc; + *) + end + + and[@inline] task_combine_ cc acts = function + | CT_merge (a,b,e_ab) -> task_merge_ cc acts a b e_ab + | CT_distinct (l,tag,e) -> task_distinct_ cc acts l tag e + + (* main CC algo: merge equivalence classes in [st.combine]. + @raise Exn_unsat if merge fails *) + and task_merge_ cc acts a b e_ab : unit = + let ra = find_ a in + let rb = find_ b in + if not @@ N.equal ra rb then ( + assert (is_root_ ra); + assert (is_root_ rb); + (* check we're not merging [true] and [false] *) + if (N.equal ra (true_ cc) && N.equal rb (false_ cc)) || + (N.equal rb (true_ cc) && N.equal ra (false_ cc)) then ( + Log.debugf 5 + (fun k->k "(@[cc.merge.true_false_conflict@ @[:r1 %a@]@ @[:r2 %a@]@ :e_ab %a@])" + N.pp ra N.pp rb Expl.pp e_ab); + let lits = explain_unfold cc e_ab in + let lits = explain_eq_n ~init:lits cc a ra in + let lits = explain_eq_n ~init:lits cc b rb in + raise_conflict cc acts lits + ); + (* We will merge [r_from] into [r_into]. + we try to ensure that [size ra <= size rb] in general, but always + keep values as representative *) + let r_from, r_into = + if n_is_bool cc ra then rb, ra + else if n_is_bool cc rb then ra, rb + else if size_ ra > size_ rb then rb, ra + else ra, rb + in + (* TODO: instead call micro theories, including "distinct" *) + (* update set of tags the new node cannot be equal to *) + let new_tags = + Util.Int_map.union + (fun _i (n1,e1) (n2,e2) -> + (* both maps contain same tag [_i]. conflict clause: + [e1 & e2 & e_ab] impossible *) + Log.debugf 5 + (fun k->k "(@[cc.merge.distinct_conflict@ :tag %d@ \ + @[:r1 %a@ :e1 %a@]@ @[:r2 %a@ :e2 %a@]@ :e_ab %a@])" + _i N.pp n1 Expl.pp e1 + N.pp n2 Expl.pp e2 Expl.pp e_ab); + let lits = explain_unfold cc e1 in + let lits = explain_unfold ~init:lits cc e2 in + let lits = explain_unfold ~init:lits cc e_ab in + let lits = explain_eq_n ~init:lits cc a n1 in + let lits = explain_eq_n ~init:lits cc b n2 in + raise_conflict cc acts lits) + ra.n_tags rb.n_tags + in + (* when merging terms with [true] or [false], possibly propagate them to SAT *) + let merge_bool r1 t1 r2 t2 = + if N.equal r1 (true_ cc) then ( + propagate_bools cc acts r2 t2 r1 t1 e_ab true + ) else if N.equal r1 (false_ cc) then ( + propagate_bools cc acts r2 t2 r1 t1 e_ab false + ) + in + merge_bool ra a rb b; + merge_bool rb b ra a; + (* perform [union r_from r_into] *) + Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into); + (* TODO: only iterate on parents of [rb] *) + (* TODO: [ra.parents <- ra.parent ++ rb.parents] *) + begin + (* for each node in [r_from]'s class: + - make it point to [r_into] + - push it into [st.pending] *) + iter_class_ r_from + (fun u -> + assert (u.n_root == r_from); + on_backtrack cc (fun () -> u.n_root <- r_from); + u.n_root <- r_into; + Bag.to_seq u.n_parents + (fun parent -> push_pending cc parent)); + (* now merge the classes *) + let r_into_old_tags = r_into.n_tags in + let r_into_old_next = r_into.n_next in + let r_from_old_next = r_from.n_next in + on_backtrack cc + (fun () -> + Log.debugf 15 + (fun k->k "(@[cc.undo_merge@ :from %a :into %a@])" + N.pp r_from N.pp r_into); + r_into.n_next <- r_into_old_next; + r_from.n_next <- r_from_old_next; + r_into.n_tags <- r_into_old_tags); + r_into.n_tags <- new_tags; + (* swap [into.next] and [from.next], merging the classes *) + r_into.n_next <- r_from_old_next; + r_from.n_next <- r_into_old_next; + end; + (* update explanations (a -> b), arbitrarily. + Note that here we merge the classes by adding a bridge between [a] + and [b], not their roots. *) + begin + reroot_expl cc a; + assert (a.n_expl = FL_none); + (* on backtracking, link may be inverted, but we delete the one + that bridges between [a] and [b] *) + on_backtrack cc + (fun () -> match a.n_expl, b.n_expl with + | FL_some e, _ when N.equal e.next b -> a.n_expl <- FL_none + | _, FL_some e when N.equal e.next a -> b.n_expl <- FL_none + | _ -> assert false); + a.n_expl <- FL_some {next=b; expl=e_ab}; + end; + (* notify listeners of the merge *) + notify_merge cc r_from ~into:r_into e_ab; + ) + + and task_distinct_ cc acts (l:node list) tag expl : unit = + let l = List.map (fun n -> n, find_ n) l in + let coll = + Sequence.diagonal_l l + |> Sequence.find_pred (fun ((_,r1),(_,r2)) -> N.equal r1 r2) + in + begin match coll with + | Some ((n1,_r1),(n2,_r2)) -> + (* two classes are already equal *) + Log.debugf 5 + (fun k->k "(@[cc.distinct.conflict@ %a = %a@ :expl %a@])" N.pp n1 N.pp + n2 Expl.pp expl); + let lits = explain_unfold cc expl in + raise_conflict cc acts lits + | None -> + (* put a tag on all equivalence classes, that will make their merge fail *) + List.iter (fun (_,n) -> add_tag_n cc n tag expl) l + end + + (* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1] + in the equiv class of [r1] that is a known literal back to the SAT solver + and which is not the one initially merged. + We can explain the propagation with [u1 = t1 =e= t2 = r2==bool] *) + and propagate_bools cc acts r1 t1 r2 t2 (e_12:explanation) sign : unit = + (* explanation for [t1 =e= t2 = r2] *) + let half_expl = lazy ( + let expl = explain_unfold cc e_12 in + explain_eq_n ~init:expl cc r2 t2 + ) in + iter_class_ r1 + (fun u1 -> + (* propagate if: + - [u1] is a proper literal + - [t2 != r2], because that can only happen + after an explicit merge (no way to obtain that by propagation) + *) + match N.as_lit u1 with + | Some lit when not (N.equal r2 t2) -> + let lit = if sign then lit else A.Lit.neg lit in (* apply sign *) + Log.debugf 5 (fun k->k "(@[cc.bool_propagate@ %a@])" A.Lit.pp lit); + (* complete explanation with the [u1=t1] chunk *) + let expl = explain_eq_n ~init:(Lazy.force half_expl) cc u1 t1 in + let reason = Msat.Consequence (expl, A.Proof.default) in + acts.Msat.acts_propagate lit reason + | _ -> ()) + + let check_invariants_ (cc:t) = + Log.debug 5 "(cc.check-invariants)"; + Log.debugf 15 (fun k-> k "%a" pp_full cc); + assert (T.equal (T.bool cc.tst true) (true_ cc).n_term); + assert (T.equal (T.bool cc.tst false) (false_ cc).n_term); + assert (not @@ same_class (true_ cc) (false_ cc)); + assert (Vec.is_empty cc.combine); + assert (Vec.is_empty cc.pending); + (* check that subterms are internalized *) + T_tbl.iter + (fun t n -> + assert (T.equal t n.n_term); + assert (not @@ N.get_field field_is_pending n); + assert (N.equal n.n_root n.n_next.n_root); + (* check proper signature. + note that some signatures in the sig table can be obsolete (they + were not removed) but there must be a valid, up-to-date signature for + each term *) + begin match CCOpt.map update_sig n.n_sig0 with + | None -> () + | Some s -> + Log.debugf 15 (fun k->k "(@[cc.check-sig@ %a@ :sig %a@])" T.pp t Signature.pp s); + (* add, but only if not present already *) + begin match Sig_tbl.find cc.signatures_tbl s with + | exception Not_found -> assert false + | repr_s -> assert (same_class n repr_s) + end + end; + ) + cc.tbl; + () + + let[@inline] check_invariants (cc:t) : unit = + if Util._CHECK_INVARIANTS then check_invariants_ cc + + let add_seq cc seq = + seq (fun t -> ignore @@ add_term_rec_ cc t); + () + + let[@inline] push_level (self:t) : unit = + Backtrack_stack.push_level self.undo + + let pop_levels (self:t) n : unit = + Vec.iter (N.set_field field_is_pending false) self.pending; + Vec.clear self.pending; + Vec.clear self.combine; + Log.debugf 15 + (fun k->k "(@[cc.pop-levels %d@ :n-lvls %d@])" n (Backtrack_stack.n_levels self.undo)); + Backtrack_stack.pop_levels self.undo n ~f:(fun f -> f()); + () + + (* assert that this boolean literal holds. + if a lit is [= a b], merge [a] and [b]; + if it's [distinct a1…an], make them distinct, etc. etc. *) + let assert_lit cc lit : unit = + let t = A.Lit.term lit in + Log.debugf 5 (fun k->k "(@[cc.assert_lit@ %a@])" A.Lit.pp lit); + let sign = A.Lit.sign lit in + begin match T.cc_view t with + | Eq (a,b) when sign -> + (* merge [a] and [b] *) + let a = add_term cc a in + let b = add_term cc b in + push_combine cc a b (Expl.mk_lit lit) + | _ -> + (* equate t and true/false *) + let rhs = if sign then true_ cc else false_ cc in + let n = add_term cc t in + (* TODO: ensure that this is O(1). + basically, just have [n] point to true/false and thus acquire + the corresponding value, so its superterms (like [ite]) can evaluate + properly *) + push_combine cc n rhs (Expl.mk_lit lit) + end + + let[@inline] assert_lits cc lits : unit = + Sequence.iter (assert_lit cc) lits + + let assert_eq cc t1 t2 (e:lit list) : unit = + let expl = Expl.mk_lits e in + let n1 = add_term cc t1 in + let n2 = add_term cc t2 in + push_combine cc n1 n2 expl + + let assert_distinct cc (l:term list) ~neq (lit:lit) : unit = + assert (match l with[] | [_] -> false | _ -> true); + assert false + (* FIXME + let tag = Term.id neq in + Log.debugf 5 + (fun k->k "(@[cc.assert_distinct@ (@[%a@])@ :tag %d@])" (Util.pp_list Term.pp) l tag); + let l = List.map (add cc) l in + Vec.push cc.combine @@ CT_distinct (l, tag, Expl.lit lit) + *) + + let create ?on_merge ?(size=`Big) (tst:term_state) : t = + let size = match size with `Small -> 128 | `Big -> 2048 in + let rec cc = { + tst; + tbl = T_tbl.create size; + signatures_tbl = Sig_tbl.create size; + on_merge; + pending=Vec.create(); + combine=Vec.create(); + ps_lits=[]; + undo=Backtrack_stack.create(); + ps_queue=Vec.create(); + true_; + false_; + } and true_ = lazy ( + add_term cc (T.bool tst true) + ) and false_ = lazy ( + add_term cc (T.bool tst false) + ) + in + ignore (Lazy.force true_ : node); + ignore (Lazy.force false_ : node); + cc + + let[@inline] find_t cc t : repr = + let n = T_tbl.find cc.tbl t in + find_ n + + let[@inline] check cc acts : unit = + Log.debug 5 "(cc.check)"; + update_tasks cc acts + + (* model: map each uninterpreted equiv class to some ID *) + let mk_model (cc:t) (m:A.Model.t) : A.Model.t = + let module Model = A.Model in + let module Value = A.Value in + Log.debugf 15 (fun k->k "(@[cc.mk-model@ %a@])" pp_full cc); + let t_tbl = N_tbl.create 32 in + (* populate [repr -> value] table *) + T_tbl.values cc.tbl + (fun r -> + if is_root_ r then ( + (* find a value in the class, if any *) + let v = + iter_class_ r + |> Sequence.find_map (fun n -> Model.eval m n.n_term) + in + let v = match v with + | Some v -> v + | None -> + if same_class r (true_ cc) then Value.true_ + else if same_class r (false_ cc) then Value.false_ + else Value.fresh r.n_term + in + N_tbl.add t_tbl r v + )); + (* now map every term to its representative's value *) + let pairs = + T_tbl.values cc.tbl + |> Sequence.map + (fun n -> + let r = find_ n in + let v = + try N_tbl.find t_tbl r + with Not_found -> + Error.errorf "didn't allocate a value for repr %a" N.pp r + in + n.n_term, v) + in + let m = Sequence.fold (fun m (t,v) -> Model.add t v m) m pairs in + Log.debugf 5 (fun k->k "(@[cc.model@ %a@])" Model.pp m); + m +end diff --git a/src/cc/Congruence_closure.mli b/src/cc/Congruence_closure.mli new file mode 100644 index 00000000..e9f38fbf --- /dev/null +++ b/src/cc/Congruence_closure.mli @@ -0,0 +1,14 @@ +(** {2 Congruence Closure} *) + +module type ARG = Congruence_closure_intf.ARG +module type S = Congruence_closure_intf.S + +type payload = Congruence_closure_intf.payload = .. + +module Make(A: ARG) + : S with type term = A.Term.t + and type lit = A.Lit.t + and type fun_ = A.Fun.t + and type term_state = A.Term.state + and type proof = A.Proof.t + and type model = A.Model.t diff --git a/src/cc/Congruence_closure_intf.ml b/src/cc/Congruence_closure_intf.ml new file mode 100644 index 00000000..7c94a4b8 --- /dev/null +++ b/src/cc/Congruence_closure_intf.ml @@ -0,0 +1,136 @@ + +module type ARG = CC_types.FULL + +(** Theory-extensible payloads in the equivalence classes *) +type payload = .. + +module type S = sig + type term_state + type term + type fun_ + type lit + type proof + type model + + (** Actions available to the theory *) + type sat_actions = (Msat.void, lit, Msat.void, proof) Msat.acts + + type t + (** Global state of the congruence closure *) + + + (** An equivalence class is a set of terms that are currently equal + in the partial model built by the solver. + The class is represented by a collection of nodes, one of which is + distinguished and is called the "representative". + + All information pertaining to the whole equivalence class is stored + in this representative's node. + + When two classes become equal (are "merged"), one of the two + representatives is picked as the representative of the new class. + The new class contains the union of the two old classes' nodes. + + We also allow theories to store additional information in the + representative. This information can be used when two classes are + merged, to detect conflicts and solve equations à la Shostak. + *) + module N : sig + type t + + val term : t -> term + val equal : t -> t -> bool + val hash : t -> int + val pp : t Fmt.printer + + type nonrec payload = payload = .. + + val payload_find: f:(payload -> 'a option) -> t -> 'a option + + val payload_pred: f:(payload -> bool) -> t -> bool + + val set_payload : ?can_erase:(payload -> bool) -> t -> payload -> unit + (** Add given payload + @param can_erase if provided, checks whether an existing value + is to be replaced instead of adding a new entry *) + end + + module Expl : sig + type t + val pp : t Fmt.printer + end + + type node = N.t + (** A node of the congruence closure *) + + type repr = N.t + (** Node that is currently a representative *) + + type explanation = Expl.t + + type conflict = lit list + + (* TODO micro theories as parameters *) + val create : + ?on_merge:(repr -> repr -> explanation -> unit) -> + ?size:[`Small | `Big] -> + term_state -> + t + (** Create a new congruence closure. + @param on_merge callback to be called on every merge + *) + + val find : t -> node -> repr + (** Current representative *) + + val add_term : t -> term -> node + (** Add the term to the congruence closure, if not present already. + Will be backtracked. *) + + val set_as_lit : t -> N.t -> lit -> unit + (** map the given node to a literal. *) + + val add_term' : t -> term -> unit + (** Same as {!add_term} but ignore the result *) + + val find_t : t -> term -> repr + (** Current representative of the term. + @raise Not_found if the term is not already {!add}-ed. *) + + val add_seq : t -> term Sequence.t -> unit + (** Add a sequence of terms to the congruence closure *) + + val all_classes : t -> repr Sequence.t + (** All current classes *) + + val assert_lit : t -> lit -> unit + (** Given a literal, assume it in the congruence closure and propagate + its consequences. Will be backtracked. *) + + val assert_lits : t -> lit Sequence.t -> unit + + val assert_eq : t -> term -> term -> lit list -> unit + (** merge the given terms with some explanations *) + + val assert_distinct : t -> term list -> neq:term -> lit -> unit + (** [assert_distinct l ~neq:u e] asserts all elements of [l] are distinct + because [lit] is true + precond: [u = distinct l] *) + + val check : t -> sat_actions -> unit + (** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc. + Will use the [sat_actions] to propagate literals, declare conflicts, etc. *) + + val push_level : t -> unit + + val pop_levels : t -> int -> unit + + val mk_model : t -> model -> model + (** Enrich a model by mapping terms to their representative's value, + if any. Otherwise map the representative to a fresh value *) + + (**/**) + val check_invariants : t -> unit + val pp_full : t Fmt.printer + (**/**) +end diff --git a/src/smt/Mini_cc.ml b/src/cc/Mini_cc.ml similarity index 55% rename from src/smt/Mini_cc.ml rename to src/cc/Mini_cc.ml index 80250408..ce0d7f56 100644 --- a/src/smt/Mini_cc.ml +++ b/src/cc/Mini_cc.ml @@ -1,23 +1,34 @@ - -module H = CCHash -type ('f, 't, 'ts) view = ('f, 't, 'ts) Mini_cc_intf.view = - | Bool of bool - | App of 'f * 'ts - | If of 't * 't * 't - -type res = Mini_cc_intf.res = +type res = | Sat | Unsat -module type ARG = Mini_cc_intf.ARG -module type S = Mini_cc_intf.S +module type TERM = CC_types.TERM + +module type S = sig + type term + type fun_ + type term_state + + type t + + val create : term_state -> t + + val add_lit : t -> term -> bool -> unit + val distinct : t -> term list -> unit + + val check : t -> res +end + + +module Make(A: TERM) = struct + open CC_types -module Make(A: ARG) = struct module Fun = A.Fun module T = A.Term type fun_ = A.Fun.t type term = T.t + type term_state = A.Term.state module T_tbl = CCHashtbl.Make(T) @@ -65,49 +76,77 @@ module Make(A: ARG) = struct let equal (s1:t) s2 : bool = match s1, s2 with | Bool b1, Bool b2 -> b1=b2 - | App (f1,[]), App (f2,[]) -> Fun.equal f1 f2 - | App (f1,l1), App (f2,l2) -> + | App_fun (f1,[]), App_fun (f2,[]) -> Fun.equal f1 f2 + | App_fun (f1,l1), App_fun (f2,l2) -> Fun.equal f1 f2 && CCList.equal Node.equal l1 l2 + | App_ho (f1,l1), App_ho (f2,l2) -> + Node.equal f1 f2 && CCList.equal Node.equal l1 l2 | If (a1,b1,c1), If (a2,b2,c2) -> Node.equal a1 a2 && Node.equal b1 b2 && Node.equal c1 c2 - | Bool _, _ | App _, _ | If _, _ + | Eq (a1,b1), Eq (a2,b2) -> + Node.equal a1 a2 && Node.equal b1 b2 + | Opaque u1, Opaque u2 -> Node.equal u1 u2 + | Bool _, _ | App_fun _, _ | App_ho _, _ | If _, _ + | Eq _, _ | Opaque _, _ -> false let hash (s:t) : int = + let module H = CCHash in match s with | Bool b -> H.combine2 10 (H.bool b) - | App (f, l) -> H.combine3 20 (Fun.hash f) (H.list Node.hash l) - | If (a,b,c) -> H.combine4 30 (Node.hash a)(Node.hash b)(Node.hash c) + | App_fun (f, l) -> H.combine3 20 (Fun.hash f) (H.list Node.hash l) + | App_ho (f, l) -> H.combine3 30 (Node.hash f) (H.list Node.hash l) + | Eq (a,b) -> H.combine3 40 (Node.hash a) (Node.hash b) + | Opaque u -> H.combine2 50 (Node.hash u) + | If (a,b,c) -> H.combine4 60 (Node.hash a)(Node.hash b)(Node.hash c) let pp out = function | Bool b -> Fmt.bool out b - | App (f, []) -> Fun.pp out f - | App (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" Fun.pp f (Util.pp_list Node.pp) l + | App_fun (f, []) -> Fun.pp out f + | App_fun (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" Fun.pp f (Util.pp_list Node.pp) l + | App_ho (f, []) -> Node.pp out f + | App_ho (f, l) -> Fmt.fprintf out "(@[%a@ %a@])" Node.pp f (Util.pp_list Node.pp) l + | Opaque t -> Node.pp out t + | Eq (a,b) -> Fmt.fprintf out "(@[=@ %a@ %a@])" Node.pp a Node.pp b | If (a,b,c) -> Fmt.fprintf out "(@[ite@ %a@ %a@ %a@])" Node.pp a Node.pp b Node.pp c end module Sig_tbl = CCHashtbl.Make(Signature) type t = { + mutable ok: bool; (* unsat? *) tbl: node T_tbl.t; sig_tbl: node Sig_tbl.t; combine: (node * node) Vec.t; pending: node Vec.t; (* refresh signature *) distinct: node list ref Vec.t; (* disjoint sets *) + true_: node; + false_: node; } - let create() : t = - { tbl= T_tbl.create 128; + let create tst : t = + let true_ = T.bool tst true in + let false_ = T.bool tst false in + let self = { + ok=true; + tbl= T_tbl.create 128; sig_tbl=Sig_tbl.create 128; combine=Vec.create(); pending=Vec.create(); distinct=Vec.create(); - } + true_=Node.make true_; + false_=Node.make false_; + } in + T_tbl.add self.tbl true_ self.true_; + T_tbl.add self.tbl false_ self.false_; + self let sub_ t k : unit = - match T.view t with - | Bool _ -> () - | App (_, args) -> args k + match T.cc_view t with + | Bool _ | Opaque _ -> () + | App_fun (_, args) -> args k + | App_ho (f, args) -> k f; args k + | Eq (a,b) -> k a; k b | If(a,b,c) -> k a; k b; k c let rec add_t (self:t) (t:term) : node = @@ -152,8 +191,37 @@ module Make(A: ARG) = struct if has_dups !r then raise_notrace E_unsat) self.distinct + let compute_sig (self:t) (n:node) : Signature.t option = + let[@inline] return x = Some x in + match T.cc_view n.n_t with + | Bool _ | Opaque _ -> None + | Eq (a,b) -> + let a = find_t_ self a in + let b = find_t_ self b in + return @@ Eq (a,b) + | App_fun (f, args) -> + let args = args |> Sequence.map (find_t_ self) |> Sequence.to_list in + if args<>[] then ( + return @@ App_fun (f, args) + ) else None + | App_ho (f, args) -> + let args = args |> Sequence.map (find_t_ self) |> Sequence.to_list in + return @@ App_ho (find_t_ self f, args) + | If (a,b,c) -> + return @@ If(find_t_ self a, find_t_ self b, find_t_ self c) + let update_sig_ (self:t) (n: node) : unit = - let aux s = + match compute_sig self n with + | None -> () + | Some (Eq (a,b)) -> + if Node.equal a b then ( + (* reduce to [true] *) + let n2 = self.true_ in + Log.debugf 5 + (fun k->k "(@[minicc.congruence-by-eq@ %a@ %a@])" Node.pp n Node.pp n2); + Vec.push self.combine (n,n2) + ) + | Some s -> Log.debugf 5 (fun k->k "(@[minicc.update-sig@ %a@])" Signature.pp s); match Sig_tbl.find self.sig_tbl s with | n2 when Node.equal n n2 -> () @@ -164,23 +232,28 @@ module Make(A: ARG) = struct Vec.push self.combine (n,n2) | exception Not_found -> Sig_tbl.add self.sig_tbl s n - in - match T.view n.n_t with - | Bool _ -> () - | App (f, args) -> - let args = args |> Sequence.map (find_t_ self) |> Sequence.to_list in - aux @@ App (f, args) - | If (a,b,c) -> aux @@ If(find_t_ self a, find_t_ self b, find_t_ self c) + + let[@inline] is_bool self n = Node.equal self.true_ n || Node.equal self.false_ n (* merge the two classes *) let merge_ self (n1,n2) : unit = let n1 = find_ n1 in let n2 = find_ n2 in if not @@ Node.equal n1 n2 then ( - (* merge into largest class *) - let n1, n2 = if Node.size n1 > Node.size n2 then n1, n2 else n2, n1 in + (* merge into largest class, or into a boolean *) + let n1, n2 = + if is_bool self n1 then n1, n2 + else if is_bool self n2 then n2, n1 + else if Node.size n1 > Node.size n2 then n1, n2 + else n2, n1 in Log.debugf 5 (fun k->k "(@[minicc.merge@ :into %a@ %a@])" Node.pp n1 Node.pp n2); + if is_bool self n1 && is_bool self n2 then ( + Log.debugf 5 (fun k->k "(minicc.conflict.merge-true-false)"); + self.ok <- false; + raise E_unsat + ); + List.iter (Vec.push self.pending) n2.n_parents; (* will change signature *) (* merge parent lists *) @@ -191,9 +264,13 @@ module Make(A: ARG) = struct Node.iter_cls n2 (fun n -> n.n_root <- n1); ) + let check_ok_ self = + if not self.ok then raise_notrace E_unsat + (* fixpoint of the congruence closure *) let fixpoint (self:t) : unit = while not (Vec.is_empty self.pending && Vec.is_empty self.combine) do + check_ok_ self; while not @@ Vec.is_empty self.pending do update_sig_ self @@ Vec.pop self.pending done; @@ -205,10 +282,17 @@ module Make(A: ARG) = struct (* API *) - let merge (self:t) t1 t2 : unit = - let n1 = add_t self t1 in - let n2 = add_t self t2 in - Vec.push self.combine (n1,n2) + let add_lit (self:t) (p:T.t) (sign:bool) : unit = + match T.cc_view p with + | Eq (t1,t2) when sign -> + let n1 = add_t self t1 in + let n2 = add_t self t2 in + Vec.push self.combine (n1,n2) + | _ -> + (* just merge with true/false *) + let n = add_t self p in + let n2 = if sign then self.true_ else self.false_ in + Vec.push self.combine (n,n2) let distinct (self:t) l = begin match l with @@ -220,6 +304,8 @@ module Make(A: ARG) = struct let check (self:t) : res = try fixpoint self; Sat - with E_unsat -> Unsat + with E_unsat -> + self.ok <- false; + Unsat end diff --git a/src/cc/Mini_cc.mli b/src/cc/Mini_cc.mli new file mode 100644 index 00000000..b460c74a --- /dev/null +++ b/src/cc/Mini_cc.mli @@ -0,0 +1,36 @@ + +(** {1 Mini congruence closure} + + This implementation is as simple as possible, and doesn't provide + backtracking, theories, or explanations. + It just decides the satisfiability of a set of (dis)equations. +*) + +type res = + | Sat + | Unsat + +module type TERM = CC_types.TERM + +module type S = sig + type term + type fun_ + type term_state + + type t + + val create : term_state -> t + + val add_lit : t -> term -> bool -> unit + (** [add_lit cc p sign] asserts that [p=sign] *) + + val distinct : t -> term list -> unit + (** [distinct cc l] asserts that all terms in [l] are distinct *) + + val check : t -> res +end + +module Make(A: TERM) + : S with type term = A.Term.t + and type fun_ = A.Fun.t + and type term_state = A.Term.state diff --git a/src/cc/Sidekick_cc.ml b/src/cc/Sidekick_cc.ml new file mode 100644 index 00000000..595d635e --- /dev/null +++ b/src/cc/Sidekick_cc.ml @@ -0,0 +1,23 @@ + +type ('f, 't, 'ts) view = ('f, 't, 'ts) CC_types.view = + | Bool of bool + | App_fun of 'f * 'ts + | App_ho of 't * 'ts + | If of 't * 't * 't + | Eq of 't * 't + | Opaque of 't (* do not enter *) + +type payload = Congruence_closure.payload = .. + +module CC_types = CC_types + +(** Parameter for the congruence closure *) +module type TERM_LIT = CC_types.TERM_LIT +module type FULL = CC_types.FULL +module type S = Congruence_closure.S + +module Mini_cc = Mini_cc +module Congruence_closure = Congruence_closure + +module Make = Congruence_closure.Make + diff --git a/src/cc/dune b/src/cc/dune new file mode 100644 index 00000000..9c089050 --- /dev/null +++ b/src/cc/dune @@ -0,0 +1,10 @@ + + +(library + (name Sidekick_cc) + (public_name sidekick.cc) + (libraries containers containers.data msat sequence sidekick.util) + (flags :standard -warn-error -a+8 + -color always -safe-string -short-paths -open Sidekick_util) + (ocamlopt_flags :standard -O3 -color always + -unbox-closures -unbox-closures-factor 20)) diff --git a/src/smt/CC.ml b/src/smt/CC.ml new file mode 100644 index 00000000..3cf15e59 --- /dev/null +++ b/src/smt/CC.ml @@ -0,0 +1,18 @@ + +module Arg = struct + module Fun = Cst + module Term = Term + module Lit = Lit + module Value = Value + module Ty = Ty + module Model = Model + module Proof = struct + type t = Solver_types.proof + let pp = Solver_types.pp_proof + let default = Solver_types.Proof_default + end +end + +include Sidekick_cc.Make(Arg) + +module Mini_cc = Sidekick_cc.Mini_cc.Make(Arg) diff --git a/src/smt/CC.mli b/src/smt/CC.mli new file mode 100644 index 00000000..f58d7efd --- /dev/null +++ b/src/smt/CC.mli @@ -0,0 +1,13 @@ + +include Sidekick_cc.S + with type term = Term.t + and type model = Model.t + and type lit = Lit.t + and type fun_ = Cst.t + and type term_state = Term.state + and type proof = Solver_types.proof + +module Mini_cc : Sidekick_cc.Mini_cc.S + with type term = Term.t + and type fun_ = Cst.t + and type term_state = Term.state diff --git a/src/smt/Congruence_closure.ml b/src/smt/Congruence_closure.ml deleted file mode 100644 index b9cd62f1..00000000 --- a/src/smt/Congruence_closure.ml +++ /dev/null @@ -1,763 +0,0 @@ - -open Solver_types - -module N = Eq_class - -type node = N.t -type repr = N.t -type conflict = Theory.conflict - -module T_arg = struct - module Fun = Cst - module Term = struct - include Term - let view = cc_view - end -end -module Mini_cc = Mini_cc.Make(T_arg) - -(** A signature is a shallow term shape where immediate subterms - are representative *) -module Signature = struct - type t = node Term.view - include Term_cell.Make_eq(N) -end - -module Sig_tbl = CCHashtbl.Make(Signature) - -type explanation_thunk = explanation lazy_t - -type combine_task = - | CT_merge of node * node * explanation_thunk - | CT_distinct of node list * int * explanation - -type t = { - tst: Term.state; - tbl: node Term.Tbl.t; - (* internalization [term -> node] *) - signatures_tbl : node Sig_tbl.t; - (* map a signature to the corresponding node in some equivalence class. - A signature is a [term_cell] in which every immediate subterm - that participates in the congruence/evaluation relation - is normalized (i.e. is its own representative). - The critical property is that all members of an equivalence class - that have the same "shape" (including head symbol) - have the same signature *) - pending: node Vec.t; - combine: combine_task Vec.t; - undo: (unit -> unit) Backtrack_stack.t; - on_merge: (repr -> repr -> explanation -> unit) option; - mutable ps_lits: Lit.Set.t; (* TODO: thread it around instead? *) - (* proof state *) - ps_queue: (node*node) Vec.t; - (* pairs to explain *) - true_ : node lazy_t; - false_ : node lazy_t; -} -(* TODO: an additional union-find to keep track, for each term, - of the terms they are known to be equal to, according - to the current explanation. That allows not to prove some equality - several times. - See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *) - -let[@inline] is_root_ (n:node) : bool = n.n_root == n -let[@inline] size_ (r:repr) = r.n_size -let[@inline] true_ cc = Lazy.force cc.true_ -let[@inline] false_ cc = Lazy.force cc.false_ - -let[@inline] on_backtrack cc f : unit = - Backtrack_stack.push_if_nonzero_level cc.undo f - -(* check if [t] is in the congruence closure. - Invariant: [in_cc t ∧ do_cc t => forall u subterm t, in_cc u] *) -let[@inline] mem (cc:t) (t:term): bool = Term.Tbl.mem cc.tbl t - -(* find representative, recursively *) -let rec find_rec cc (n:node) : repr = - if n==n.n_root then ( - n - ) else ( - (* TODO: path compression, assuming backtracking restores equiv classes - properly *) - let root = find_rec cc n.n_root in - root - ) - -(* traverse the equivalence class of [n] *) -let iter_class_ (n:node) : node Sequence.t = - fun yield -> - let rec aux u = - yield u; - if u.n_next != n then aux u.n_next - in - aux n - -(* get term that should be there *) -let[@inline] get_ cc (t:term) : node = - try Term.Tbl.find cc.tbl t - with Not_found -> - Log.debugf 1 (fun k->k "(@[cc.error@ :missing-term %a@])" Term.pp t); - assert false - -(* non-recursive, inlinable function for [find] *) -let[@inline] find st (n:node) : repr = - if n == n.n_root then n else find_rec st n - -let[@inline] find_tn cc (t:term) : repr = get_ cc t |> find cc - -let[@inline] same_class cc (n1:node)(n2:node): bool = - N.equal (find cc n1) (find cc n2) - -(* print full state *) -let pp_full out (cc:t) : unit = - let pp_next out n = - Fmt.fprintf out "@ :next %a" N.pp n.n_next in - let pp_root out n = - if is_root_ n then Fmt.string out " :is-root" else Fmt.fprintf out "@ :root %a" N.pp n.n_root in - let pp_expl out n = match n.n_expl with - | E_none -> () - | E_some e -> - Fmt.fprintf out " (@[:forest %a :expl %a@])" N.pp e.next Explanation.pp e.expl - in - let pp_n out n = - Fmt.fprintf out "(@[%a%a%a%a@])" Term.pp n.n_term pp_root n pp_next n pp_expl n - and pp_sig_e out (s,n) = - Fmt.fprintf out "(@[<1>%a@ <--> %a%a@])" Signature.pp s N.pp n pp_root n - in - Fmt.fprintf out - "(@[@{cc.state@}@ (@[:nodes@ %a@])@ (@[:sig@ %a@])@])" - (Util.pp_seq ~sep:" " pp_n) (Term.Tbl.values cc.tbl) - (Util.pp_seq ~sep:" " pp_sig_e) (Sig_tbl.to_seq cc.signatures_tbl) - -(* compute signature *) -let signature cc (t:term): Signature.t option = - let find = find_tn cc in - begin match Term.view t with - | App_cst (_, a) when IArray.is_empty a -> None - | App_cst (c, _) when not @@ Cst.do_cc c -> None (* no CC *) - | App_cst (f, a) -> Some (App_cst (f, IArray.map find a)) (* FIXME: relevance? *) - | Bool _ | If _ - -> None (* no congruence for these *) - end - -(* find whether the given (parent) term corresponds to some signature - in [signatures_] *) -let find_by_signature cc (t:term) : repr option = - match signature cc t with - | None -> None - | Some s -> Sig_tbl.get cc.signatures_tbl s - -let add_signature cc (n:node): unit = - match signature cc n.n_term with - | None -> () - | Some s -> - (* add, but only if not present already *) - begin match Sig_tbl.find cc.signatures_tbl s with - | exception Not_found -> - Log.debugf 15 - (fun k->k "(@[cc.add_sig@ %a@ <--> %a@])" Signature.pp s N.pp n); - on_backtrack cc (fun () -> Sig_tbl.remove cc.signatures_tbl s); - Sig_tbl.add cc.signatures_tbl s n; - | r' -> - assert (same_class cc n r'); - end - -let push_pending cc t : unit = - if not @@ N.get_field N.field_is_pending t then ( - Log.debugf 5 (fun k->k "(@[cc.push_pending@ %a@])" N.pp t); - N.set_field N.field_is_pending true t; - Vec.push cc.pending t - ) - -let push_combine cc t u e : unit = - Log.debugf 5 - (fun k->k "(@[cc.push_combine@ :t1 %a@ :t2 %a@ :expl %a@])" - N.pp t N.pp u Explanation.pp (Lazy.force e)); - Vec.push cc.combine @@ CT_merge (t,u,e) - -(* re-root the explanation tree of the equivalence class of [n] - so that it points to [n]. - postcondition: [n.n_expl = None] *) -let rec reroot_expl (cc:t) (n:node): unit = - let old_expl = n.n_expl in - begin match old_expl with - | E_none -> () (* already root *) - | E_some {next=u; expl=e_n_u} -> - reroot_expl cc u; - u.n_expl <- E_some {next=n; expl=e_n_u}; - n.n_expl <- E_none; - end - -let raise_conflict (cc:t) (acts:sat_actions) (e:conflict): _ = - (* clear tasks queue *) - Vec.iter (N.set_field N.field_is_pending false) cc.pending; - Vec.clear cc.pending; - Vec.clear cc.combine; - let c = List.map Lit.neg e in - acts.Msat.acts_raise_conflict c Proof_default - -let[@inline] all_classes cc : repr Sequence.t = - Term.Tbl.values cc.tbl - |> Sequence.filter is_root_ - -(* TODO: use markers and lockstep iteration instead *) -(* distance from [t] to its root in the proof forest *) -let[@inline][@unroll 2] rec distance_to_root (n:node): int = match n.n_expl with - | E_none -> 0 - | E_some {next=t'; _} -> 1 + distance_to_root t' - -(* TODO: bool flag on nodes + stepwise progress + cleanup *) -(* find the closest common ancestor of [a] and [b] in the proof forest *) -let find_common_ancestor (a:node) (b:node) : node = - let d_a = distance_to_root a in - let d_b = distance_to_root b in - (* drop [n] nodes in the path from [t] to its root *) - let rec drop_ n t = - if n=0 then t - else match t.n_expl with - | E_none -> assert false - | E_some {next=t'; _} -> drop_ (n-1) t' - in - (* reduce to the problem where [a] and [b] have the same distance to root *) - let a, b = - if d_a > d_b then drop_ (d_a-d_b) a, b - else if d_a < d_b then a, drop_ (d_b-d_a) b - else a, b - in - (* traverse stepwise until a==b *) - let rec aux_same_dist a b = - if a==b then a - else match a.n_expl, b.n_expl with - | E_none, _ | _, E_none -> assert false - | E_some {next=a'; _}, E_some {next=b'; _} -> aux_same_dist a' b' - in - aux_same_dist a b - -let[@inline] ps_add_obligation (cc:t) a b = Vec.push cc.ps_queue (a,b) -let[@inline] ps_add_lit ps l = ps.ps_lits <- Lit.Set.add l ps.ps_lits - -let ps_clear (cc:t) = - cc.ps_lits <- Lit.Set.empty; - Vec.clear cc.ps_queue; - () - -let decompose_explain cc (e:explanation): unit = - Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Explanation.pp e); - begin match e with - | E_reduction -> () - | E_lit lit -> ps_add_lit cc lit - | E_lits l -> List.iter (ps_add_lit cc) l - | E_merges l -> - (* need to explain each merge in [l] *) - IArray.iter (fun (t,u) -> ps_add_obligation cc t u) l - end - -(* explain why [a = parent_a], where [a -> ... -> parent_a] in the - proof forest *) -let rec explain_along_path ps (a:node) (parent_a:node) : unit = - if a!=parent_a then ( - match a.n_expl with - | E_none -> assert false - | E_some {next=next_a; expl=e_a_b} -> - decompose_explain ps e_a_b; - (* now prove [next_a = parent_a] *) - explain_along_path ps next_a parent_a - ) - -(* find explanation *) -let explain_loop (cc : t) : Lit.Set.t = - while not (Vec.is_empty cc.ps_queue) do - let a, b = Vec.pop cc.ps_queue in - Log.debugf 5 - (fun k->k "(@[cc.explain_loop.at@ %a@ =?= %a@])" N.pp a N.pp b); - assert (N.equal (find cc a) (find cc b)); - let c = find_common_ancestor a b in - explain_along_path cc a c; - explain_along_path cc b c; - done; - cc.ps_lits - -(* TODO: do not use ps_lits anymore *) -let explain_eq_n ?(init=Lit.Set.empty) cc (n1:node) (n2:node) : Lit.Set.t = - ps_clear cc; - cc.ps_lits <- init; - ps_add_obligation cc n1 n2; - explain_loop cc - -let explain_unfold ?(init=Lit.Set.empty) cc (e:explanation) : Lit.Set.t = - ps_clear cc; - cc.ps_lits <- init; - decompose_explain cc e; - explain_loop cc - -(* add [tag] to [n], indicating that [n] is distinct from all the other - nodes tagged with [tag] - precond: [n] is a representative *) -let add_tag_n cc (n:node) (tag:int) (expl:explanation) : unit = - assert (is_root_ n); - if not (Util.Int_map.mem tag n.n_tags) then ( - on_backtrack cc - (fun () -> n.n_tags <- Util.Int_map.remove tag n.n_tags); - n.n_tags <- Util.Int_map.add tag (n,expl) n.n_tags; - ) - -(* TODO: payload for set of tags *) -(* TODO: payload for mapping an equiv class to a set of literals, for bool prop *) - -let relevant_subterms (t:Term.t) : Term.t Sequence.t = - fun yield -> - match t.term_view with - | App_cst (c, a) when Cst.do_cc c -> IArray.iter yield a - | Bool _ | App_cst _ -> () - | If (a,b,c) -> - (* TODO: relevancy? only [a] needs be decided for now *) - yield a; - yield b; - yield c - -(* Checks if [ra] and [~into] have compatible normal forms and can - be merged w.r.t. the theories. - Side effect: also pushes sub-tasks *) -let notify_merge cc (ra:repr) ~into:(rb:repr) (e:explanation): unit = - assert (is_root_ rb); - match cc.on_merge with - | Some f -> f ra rb e - | None -> () - -(* main CC algo: add terms from [pending] to the signature table, - check for collisions *) -let rec update_tasks (cc:t) (acts:sat_actions) : unit = - while not (Vec.is_empty cc.pending && Vec.is_empty cc.combine) do - Vec.iter (task_pending_ cc) cc.pending; - Vec.clear cc.pending; - Vec.iter (task_combine_ cc acts) cc.combine; - Vec.clear cc.combine; - done - -and task_pending_ cc n = - N.set_field N.field_is_pending false n; - (* check if some parent collided *) - begin match find_by_signature cc n.n_term with - | None -> - (* add to the signature table [sig(n) --> n] *) - add_signature cc n - | Some u when n == u -> () - | Some u -> - (* [t1] and [t2] must be applications of the same symbol to - arguments that are pairwise equal *) - assert (n != u); - let expl = lazy ( - match n.n_term.term_view, u.n_term.term_view with - | App_cst (f1, a1), App_cst (f2, a2) -> - assert (Cst.equal f1 f2); - assert (IArray.length a1 = IArray.length a2); - (* TODO: just use "congruence" as explanation *) - Explanation.mk_merges @@ IArray.map2 (fun u1 u2 -> add_term_rec_ cc u1, add_term_rec_ cc u2) a1 a2 - | If _, _ | App_cst _, _ | Bool _, _ - -> assert false - ) in - push_combine cc n u expl - end; - (* TODO: evaluate [(= t u) := true] when [find t==find u] *) - (* FIXME: when to actually evaluate? - eval_pending cc; - *) - () - -and[@inline] task_combine_ cc acts = function - | CT_merge (a,b,e_ab) -> task_merge_ cc acts a b e_ab - | CT_distinct (l,tag,e) -> task_distinct_ cc acts l tag e - -(* main CC algo: merge equivalence classes in [st.combine]. - @raise Exn_unsat if merge fails *) -and task_merge_ cc acts a b e_ab : unit = - let ra = find cc a in - let rb = find cc b in - if not @@ N.equal ra rb then ( - assert (is_root_ ra); - assert (is_root_ rb); - let lazy e_ab = e_ab in - (* We will merge [r_from] into [r_into]. - we try to ensure that [size ra <= size rb] in general, but always - keep values as representative *) - let r_from, r_into = - if Term.is_value ra.n_term then rb, ra - else if Term.is_value rb.n_term then ra, rb - else if size_ ra > size_ rb then rb, ra - else ra, rb - in - (* check we're not merging [true] and [false] *) - if (N.equal ra (true_ cc) && N.equal rb (false_ cc)) || - (N.equal rb (true_ cc) && N.equal ra (false_ cc)) then ( - Log.debugf 5 - (fun k->k "(@[cc.merge.true_false_conflict@ @[:r1 %a@]@ @[:r2 %a@]@ :e_ab %a@])" - N.pp ra N.pp rb Explanation.pp e_ab); - let lits = explain_unfold cc e_ab in - let lits = explain_eq_n ~init:lits cc a ra in - let lits = explain_eq_n ~init:lits cc b rb in - raise_conflict cc acts @@ Lit.Set.elements lits - ); - (* TODO: isntead call micro theories, including "distinct" *) - (* update set of tags the new node cannot be equal to *) - let new_tags = - Util.Int_map.union - (fun _i (n1,e1) (n2,e2) -> - (* both maps contain same tag [_i]. conflict clause: - [e1 & e2 & e_ab] impossible *) - Log.debugf 5 - (fun k->k "(@[cc.merge.distinct_conflict@ :tag %d@ \ - @[:r1 %a@ :e1 %a@]@ @[:r2 %a@ :e2 %a@]@ :e_ab %a@])" - _i N.pp n1 Explanation.pp e1 - N.pp n2 Explanation.pp e2 Explanation.pp e_ab); - let lits = explain_unfold cc e1 in - let lits = explain_unfold ~init:lits cc e2 in - let lits = explain_unfold ~init:lits cc e_ab in - let lits = explain_eq_n ~init:lits cc a n1 in - let lits = explain_eq_n ~init:lits cc b n2 in - raise_conflict cc acts @@ Lit.Set.elements lits) - ra.n_tags rb.n_tags - in - (* when merging terms with [true] or [false], possibly propagate them to SAT *) - let merge_bool r1 t1 r2 t2 = - if N.equal r1 (true_ cc) then ( - propagate_bools cc acts r2 t2 r1 t1 e_ab true - ) else if N.equal r1 (false_ cc) then ( - propagate_bools cc acts r2 t2 r1 t1 e_ab false - ) - in - merge_bool ra a rb b; - merge_bool rb b ra a; - (* perform [union r_from r_into] *) - Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into); - (* TODO: only iterate on parents of [rb] *) - (* TODO: [ra.parents <- ra.parent ++ rb.parents] *) - begin - (* for each node in [r_from]'s class: - - make it point to [r_into] - - push it into [st.pending] *) - iter_class_ r_from - (fun u -> - assert (u.n_root == r_from); - on_backtrack cc (fun () -> u.n_root <- r_from); - u.n_root <- r_into; - Bag.to_seq u.n_parents - (fun parent -> push_pending cc parent)); - (* now merge the classes *) - let r_into_old_tags = r_into.n_tags in - let r_into_old_next = r_into.n_next in - let r_from_old_next = r_from.n_next in - on_backtrack cc - (fun () -> - Log.debugf 15 - (fun k->k "(@[cc.undo_merge@ :from %a :into %a@])" - Term.pp r_from.n_term Term.pp r_into.n_term); - r_into.n_next <- r_into_old_next; - r_from.n_next <- r_from_old_next; - r_into.n_tags <- r_into_old_tags); - r_into.n_tags <- new_tags; - (* swap [into.next] and [from.next], merging the classes *) - r_into.n_next <- r_from_old_next; - r_from.n_next <- r_into_old_next; - end; - (* update explanations (a -> b), arbitrarily. - Note that here we merge the classes by adding a bridge between [a] - and [b], not their roots. *) - begin - reroot_expl cc a; - assert (a.n_expl = E_none); - (* on backtracking, link may be inverted, but we delete the one - that bridges between [a] and [b] *) - on_backtrack cc - (fun () -> match a.n_expl, b.n_expl with - | E_some e, _ when N.equal e.next b -> a.n_expl <- E_none - | _, E_some e when N.equal e.next a -> b.n_expl <- E_none - | _ -> assert false); - a.n_expl <- E_some {next=b; expl=e_ab}; - end; - (* notify listeners of the merge *) - notify_merge cc r_from ~into:r_into e_ab; - ) - -and task_distinct_ cc acts (l:node list) tag expl : unit = - let l = List.map (fun n -> n, find cc n) l in - let coll = - Sequence.diagonal_l l - |> Sequence.find_pred (fun ((_,r1),(_,r2)) -> N.equal r1 r2) - in - begin match coll with - | Some ((n1,_r1),(n2,_r2)) -> - (* two classes are already equal *) - Log.debugf 5 - (fun k->k "(@[cc.distinct.conflict@ %a = %a@ :expl %a@])" N.pp n1 N.pp - n2 Explanation.pp expl); - let lits = explain_unfold cc expl in - raise_conflict cc acts (Lit.Set.to_list lits) - | None -> - (* put a tag on all equivalence classes, that will make their merge fail *) - List.iter (fun (_,n) -> add_tag_n cc n tag expl) l - end - -(* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1] - in the equiv class of [r1] that is a known literal back to the SAT solver - and which is not the one initially merged. - We can explain the propagation with [u1 = t1 =e= t2 = r2==bool] *) -and propagate_bools cc acts r1 t1 r2 t2 (e_12:explanation) sign : unit = - (* explanation for [t1 =e= t2 = r2] *) - let half_expl = lazy ( - let expl = explain_unfold cc e_12 in - explain_eq_n ~init:expl cc r2 t2 - ) in - iter_class_ r1 - (fun u1 -> - (* propagate if: - - [u1] is a proper literal - - [t2 != r2], because that can only happen - after an explicit merge (no way to obtain that by propagation) - *) - if N.get_field N.field_is_literal u1 && not (N.equal r2 t2) then ( - let lit = Lit.atom ~sign u1.n_term in - Log.debugf 5 (fun k->k "(@[cc.bool_propagate@ %a@])" Lit.pp lit); - (* complete explanation with the [u1=t1] chunk *) - let expl = explain_eq_n ~init:(Lazy.force half_expl) cc u1 t1 in - let reason = Msat.Consequence (Lit.Set.to_list expl, Proof_default) in - acts.Msat.acts_propagate lit reason - )) - -(* add [t] to [cc] when not present already *) -and add_new_term_ cc (t:term) : node = - assert (not @@ mem cc t); - Log.debugf 15 (fun k->k "(@[cc.add-term@ %a@])" Term.pp t); - let n = N.make t in - (* how to add a subterm *) - let add_to_parents_of_sub_node (sub:node) : unit = - let sub = find cc sub in (* update the repr! *) - let old_parents = sub.n_parents in - on_backtrack cc (fun () -> sub.n_parents <- old_parents); - sub.n_parents <- Bag.cons n sub.n_parents; - in - (* add sub-term to [cc], and register [n] to its parents *) - let add_sub_t (u:term) : unit = - let n_u = add_term_rec_ cc u in - add_to_parents_of_sub_node n_u - in - (* register sub-terms, add [t] to their parent list *) - relevant_subterms t add_sub_t; - (* remove term when we backtrack *) - on_backtrack cc - (fun () -> - Log.debugf 15 (fun k->k "(@[cc.remove-term@ %a@])" Term.pp t); - Term.Tbl.remove cc.tbl t); - (* add term to the table *) - Term.Tbl.add cc.tbl t n; - (* [n] might be merged with other equiv classes *) - push_pending cc n; - n - -(* add a term *) -and[@inline] add_term_rec_ cc t : node = - try Term.Tbl.find cc.tbl t - with Not_found -> add_new_term_ cc t - -let check_invariants_ (cc:t) = - Log.debug 5 "(cc.check-invariants)"; - Log.debugf 15 (fun k-> k "%a" pp_full cc); - assert (Term.equal (Term.true_ cc.tst) (true_ cc).n_term); - assert (Term.equal (Term.false_ cc.tst) (false_ cc).n_term); - assert (not @@ same_class cc (true_ cc) (false_ cc)); - assert (Vec.is_empty cc.combine); - assert (Vec.is_empty cc.pending); - (* check that subterms are internalized *) - Term.Tbl.iter - (fun t n -> - assert (Term.equal t n.n_term); - assert (not @@ N.get_field N.field_is_pending n); - relevant_subterms t - (fun u -> assert (Term.Tbl.mem cc.tbl u)); - assert (N.equal n.n_root n.n_next.n_root); - (* check proper signature. - note that some signatures in the sig table can be obsolete (they - were not removed) but there must be a valid, up-to-date signature for - each term *) - begin match signature cc t with - | None -> () - | Some s -> - Log.debugf 15 (fun k->k "(@[cc.check-sig@ %a@ :sig %a@])" Term.pp t Signature.pp s); - (* add, but only if not present already *) - begin match Sig_tbl.find cc.signatures_tbl s with - | exception Not_found -> assert false - | repr_s -> assert (same_class cc n repr_s) - end - end; - ) - cc.tbl; - () - -let[@inline] check_invariants (cc:t) : unit = - if Util._CHECK_INVARIANTS then check_invariants_ cc - -let[@inline] add cc t : node = add_term_rec_ cc t - -let add_seq cc seq = - seq (fun t -> ignore @@ add_term_rec_ cc t); - () - -let[@inline] push_level (self:t) : unit = - Backtrack_stack.push_level self.undo - -let pop_levels (self:t) n : unit = - Vec.iter (N.set_field N.field_is_pending false) self.pending; - Vec.clear self.pending; - Vec.clear self.combine; - Log.debugf 15 - (fun k->k "(@[cc.pop-levels %d@ :n-lvls %d@])" n (Backtrack_stack.n_levels self.undo)); - Backtrack_stack.pop_levels self.undo n ~f:(fun f -> f()); - () - -(* TODO: if a lit is [= a b], merge [a] and [b]; - if it's [distinct a1…an], make them distinct, etc. etc. *) -(* assert that this boolean literal holds *) -let assert_lit cc lit : unit = - let t = Lit.view lit in - assert (Ty.is_prop t.term_ty); - Log.debugf 5 (fun k->k "(@[cc.assert_lit@ %a@])" Lit.pp lit); - let sign = Lit.sign lit in - (* equate t and true/false *) - let rhs = if sign then true_ cc else false_ cc in - let n = add_term_rec_ cc t in - (* TODO: ensure that this is O(1). - basically, just have [n] point to true/false and thus acquire - the corresponding value, so its superterms (like [ite]) can evaluate - properly *) - push_combine cc n rhs (Lazy.from_val @@ E_lit lit) - -let[@inline] assert_lits cc lits : unit = - Sequence.iter (assert_lit cc) lits - -let assert_eq cc (t:term) (u:term) e : unit = - let n1 = add_term_rec_ cc t in - let n2 = add_term_rec_ cc u in - if not (same_class cc n1 n2) then ( - let e = Lazy.from_val @@ Explanation.E_lits e in - push_combine cc n1 n2 e; - ) - -let assert_distinct cc (l:term list) ~neq (lit:Lit.t) : unit = - assert (match l with[] | [_] -> false | _ -> true); - let tag = Term.id neq in - Log.debugf 5 - (fun k->k "(@[cc.assert_distinct@ (@[%a@])@ :tag %d@])" (Util.pp_list Term.pp) l tag); - let l = List.map (add cc) l in - Vec.push cc.combine @@ CT_distinct (l, tag, Explanation.lit lit) - -let create ?on_merge ?(size=`Big) (tst:Term.state) : t = - let size = match size with `Small -> 128 | `Big -> 2048 in - let rec cc = { - tst; - tbl = Term.Tbl.create size; - signatures_tbl = Sig_tbl.create size; - on_merge; - pending=Vec.create(); - combine=Vec.create(); - ps_lits=Lit.Set.empty; - undo=Backtrack_stack.create(); - ps_queue=Vec.create(); - true_; - false_; - } and true_ = lazy ( - add_term_rec_ cc (Term.true_ tst) - ) and false_ = lazy ( - add_term_rec_ cc (Term.false_ tst) - ) - in - ignore (Lazy.force true_ : node); - ignore (Lazy.force false_ : node); - cc - -let[@inline] find_t cc t : repr = - let n = Term.Tbl.find cc.tbl t in - find cc n - -let[@inline] check cc acts : unit = - Log.debug 5 "(cc.check)"; - update_tasks cc acts - -(* model: map each uninterpreted equiv class to some ID *) -let mk_model (cc:t) (m:Model.t) : Model.t = - Log.debugf 15 (fun k->k "(@[cc.mk_model@ %a@])" pp_full cc); - (* populate [repr -> value] table *) - let t_tbl = N.Tbl.create 32 in - (* type -> default value *) - let ty_tbl = Ty.Tbl.create 8 in - Term.Tbl.values cc.tbl - (fun r -> - if is_root_ r then ( - let t = r.n_term in - let v = match Model.eval m t with - | Some v -> v - | None -> - if same_class cc r (true_ cc) then Value.true_ - else if same_class cc r (false_ cc) then Value.false_ - else ( - Value.mk_elt - (ID.makef "v_%d" @@ Term.id t) - (Term.ty r.n_term) - ) - in - if not @@ Ty.Tbl.mem ty_tbl (Term.ty t) then ( - Ty.Tbl.add ty_tbl (Term.ty t) v; (* also give a value to this type *) - ); - N.Tbl.add t_tbl r v - )); - (* now map every uninterpreted term to its representative's value, and - create function tables *) - let m, funs = - Term.Tbl.to_seq cc.tbl - |> Sequence.fold - (fun (m,funs) (t,r) -> - let r = find cc r in (* get representative *) - match Term.view t with - | _ when Model.mem t m -> m, funs - | App_cst (c, args) -> - if Model.mem t m then m, funs - else if Cst.is_undefined c && IArray.length args > 0 then ( - (* update signature of [c] *) - let ty = Term.ty t in - let v = N.Tbl.find t_tbl r in - let args = - args - |> IArray.map (fun t -> N.Tbl.find t_tbl @@ find_tn cc t) - |> IArray.to_list - in - let ty, l = Cst.Map.get_or c funs ~default:(ty,[]) in - m, Cst.Map.add c (ty, (args,v)::l) funs - ) else ( - let v = N.Tbl.find t_tbl r in - Model.add t v m, funs - ) - | _ -> - let v = N.Tbl.find t_tbl r in - Model.add t v m, funs) - (m,Cst.Map.empty) - in - (* get or make a default value for this type *) - let rec get_ty_default (ty:Ty.t) : Value.t = - match Ty.view ty with - | Ty_prop -> Value.true_ - | Ty_atomic { def = Ty_uninterpreted _;_} -> - (* domain element *) - Ty.Tbl.get_or_add ty_tbl ~k:ty - ~f:(fun ty -> Value.mk_elt (ID.makef "ty_%d" @@ Ty.id ty) ty) - | Ty_atomic { def = Ty_def d; args; _} -> - (* ask the theory for a default value *) - Ty.Tbl.get_or_add ty_tbl ~k:ty - ~f:(fun _ty -> - let vals = List.map get_ty_default args in - d.default_val vals) - in - let funs = - Cst.Map.map - (fun (ty,l) -> - Model.Fun_interpretation.make ~default:(get_ty_default ty) l) - funs - in - Model.add_funs funs m diff --git a/src/smt/Congruence_closure.mli b/src/smt/Congruence_closure.mli deleted file mode 100644 index 22215e37..00000000 --- a/src/smt/Congruence_closure.mli +++ /dev/null @@ -1,73 +0,0 @@ -(** {2 Congruence Closure} *) - -open Solver_types - -type t -(** Global state of the congruence closure *) - -type node = Eq_class.t -(** Node in the congruence closure *) - -type repr = Eq_class.t -(** Node that is currently a representative *) - -type conflict = Theory.conflict - -val create : - ?on_merge:(repr -> repr -> explanation -> unit) -> - ?size:[`Small | `Big] -> - Term.state -> - t -(** Create a new congruence closure. - @param acts the actions available to the congruence closure -*) - -val find : t -> node -> repr -(** Current representative *) - -val add : t -> term -> node -(** Add the term to the congruence closure, if not present already. - Will be backtracked. *) - -val find_t : t -> term -> repr -(** Current representative of the term. - @raise Not_found if the term is not already {!add}-ed. *) - -val add_seq : t -> term Sequence.t -> unit -(** Add a sequence of terms to the congruence closure *) - -val all_classes : t -> repr Sequence.t -(** All current classes *) - -val assert_lit : t -> Lit.t -> unit -(** Given a literal, assume it in the congruence closure and propagate - its consequences. Will be backtracked. *) - -val assert_lits : t -> Lit.t Sequence.t -> unit - -val assert_eq : t -> term -> term -> Lit.t list -> unit - -val assert_distinct : t -> term list -> neq:term -> Lit.t -> unit -(** [assert_distinct l ~expl:u e] asserts all elements of [l] are distinct - with explanation [e] - precond: [u = distinct l] *) - -val check : t -> sat_actions -> unit -(** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc. - Will use the [sat_actions] to propagate literals, declare conflicts, etc. *) - -val push_level : t -> unit - -val pop_levels : t -> int -> unit - -val mk_model : t -> Model.t -> Model.t -(** Enrich a model by mapping terms to their representative's value, - if any. Otherwise map the representative to a fresh value *) - -(**/**) -val check_invariants : t -> unit -val pp_full : t Fmt.printer -(**/**) - -module T_arg : Mini_cc_intf.ARG with type Fun.t = cst and type Term.t = Term.t -module Mini_cc : module type of Mini_cc.Make(T_arg) diff --git a/src/smt/Eq_class.ml b/src/smt/Eq_class.ml deleted file mode 100644 index bdcb9583..00000000 --- a/src/smt/Eq_class.ml +++ /dev/null @@ -1,66 +0,0 @@ - -open Solver_types - -type t = equiv_class -type payload = equiv_class_payload = .. - -let field_is_active = Node_bits.mk_field() -let field_is_pending = Node_bits.mk_field() -let field_is_literal = Node_bits.mk_field() -let () = Node_bits.freeze() - -let[@inline] equal (n1:t) n2 = n1==n2 -let[@inline] hash n = Term.hash n.n_term -let[@inline] term n = n.n_term -let[@inline] payload n = n.n_payload -let[@inline] pp out n = Term.pp out n.n_term - -let make (t:term) : t = - let rec n = { - n_term=t; - n_bits=Node_bits.empty; - n_parents=Bag.empty; - n_root=n; - n_expl=E_none; - n_payload=[]; - n_next=n; - n_size=1; - n_tags=Util.Int_map.empty; - } in - n - -let set_payload ?(can_erase=fun _->false) n e = - let rec aux = function - | [] -> [e] - | e' :: tail when can_erase e' -> e :: tail - | e' :: tail -> e' :: aux tail - in - n.n_payload <- aux n.n_payload - -let payload_find ~f:p n = - let[@unroll 2] rec aux = function - | [] -> None - | e1 :: tail -> - match p e1 with - | Some _ as res -> res - | None -> aux tail - in - aux n.n_payload - -let payload_pred ~f:p n = - begin match n.n_payload with - | [] -> false - | e :: _ when p e -> true - | _ :: e :: _ when p e -> true - | _ :: _ :: e :: _ when p e -> true - | l -> List.exists p l - end - -let[@inline] get_field f t = Node_bits.get f t.n_bits -let[@inline] set_field f b t = t.n_bits <- Node_bits.set f b t.n_bits - -module Tbl = CCHashtbl.Make(struct - type t = equiv_class - let equal = equal - let hash = hash - end) diff --git a/src/smt/Eq_class.mli b/src/smt/Eq_class.mli deleted file mode 100644 index a2f03aa6..00000000 --- a/src/smt/Eq_class.mli +++ /dev/null @@ -1,61 +0,0 @@ - -open Solver_types - -(** {1 Equivalence Classes} *) - -(** An equivalence class is a set of terms that are currently equal - in the partial model built by the solver. - The class is represented by a collection of nodes, one of which is - distinguished and is called the "representative". - - All information pertaining to the whole equivalence class is stored - in this representative's node. - - When two classes become equal (are "merged"), one of the two - representatives is picked as the representative of the new class. - The new class contains the union of the two old classes' nodes. - - We also allow theories to store additional information in the - representative. This information can be used when two classes are - merged, to detect conflicts and solve equations à la Shostak. -*) - -type t = equiv_class -type payload = equiv_class_payload = .. - -val field_is_active : Node_bits.field -(** The term is needed for evaluation. We must try to evaluate it - or to find a value for it using the theory *) - -val field_is_pending : Node_bits.field -(** true iff the node is in the [cc.pending] queue *) - -val field_is_literal : Node_bits.field -(** This term is a boolean literal, subject to propagations *) - -(** {2 basics} *) - -val term : t -> term -val equal : t -> t -> bool -val hash : t -> int -val pp : t Fmt.printer -val payload : t -> payload list - -(** {2 Helpers} *) - -val make : term -> t -(** Make a new equivalence class whose representative is the given term *) - -val payload_find: f:(payload -> 'a option) -> t -> 'a option - -val payload_pred: f:(payload -> bool) -> t -> bool - -val set_payload : ?can_erase:(payload -> bool) -> t -> payload -> unit -(** Add given payload - @param can_erase if provided, checks whether an existing value - is to be replaced instead of adding a new entry *) - -val get_field : Node_bits.field -> t -> bool -val set_field : Node_bits.field -> bool -> t -> unit - -module Tbl : CCHashtbl.S with type key = t diff --git a/src/smt/Explanation.ml b/src/smt/Explanation.ml deleted file mode 100644 index b1e1ee06..00000000 --- a/src/smt/Explanation.ml +++ /dev/null @@ -1,26 +0,0 @@ - -open Solver_types - -type t = explanation = - | E_reduction (* by pure reduction, tautologically equal *) - | E_merges of (equiv_class * equiv_class) IArray.t (* caused by these merges *) - | E_lit of lit (* because of this literal *) - | E_lits of lit list (* because of this (true) conjunction *) - -let compare = cmp_exp -let equal a b = cmp_exp a b = 0 - -let pp = pp_explanation - -let mk_merges l : t = E_merges l -let mk_lit l : t = E_lit l -let mk_lits = function [x] -> mk_lit x | l -> E_lits l -let mk_reduction : t = E_reduction - -let[@inline] lit l : t = E_lit l - -module Set = CCSet.Make(struct - type t = explanation - let compare = compare - end) - diff --git a/src/smt/Lit.ml b/src/smt/Lit.ml index 2554ea7c..c31e3dcf 100644 --- a/src/smt/Lit.ml +++ b/src/smt/Lit.ml @@ -8,7 +8,7 @@ type t = lit = { let[@inline] neg l = {l with lit_sign=not l.lit_sign} let[@inline] sign t = t.lit_sign -let[@inline] view (t:t): term = t.lit_term +let[@inline] term (t:t): term = t.lit_term let[@inline] abs t: t = {t with lit_sign=true} diff --git a/src/smt/Lit.mli b/src/smt/Lit.mli index 073dbf05..2111ce1a 100644 --- a/src/smt/Lit.mli +++ b/src/smt/Lit.mli @@ -10,7 +10,7 @@ type t = lit = { val neg : t -> t val abs : t -> t val sign : t -> bool -val view : t -> term +val term : t -> term val as_atom : t -> term * bool val atom : ?sign:bool -> term -> t val hash : t -> int diff --git a/src/smt/Mini_cc.mli b/src/smt/Mini_cc.mli deleted file mode 100644 index 69359b30..00000000 --- a/src/smt/Mini_cc.mli +++ /dev/null @@ -1,18 +0,0 @@ - -(** {1 Mini congruence closure} *) - -type ('f, 't, 'ts) view = ('f, 't, 'ts) Mini_cc_intf.view = - | Bool of bool - | App of 'f * 'ts - | If of 't * 't * 't - -type res = Mini_cc_intf.res = - | Sat - | Unsat - -module type ARG = Mini_cc_intf.ARG -module type S = Mini_cc_intf.S - -module Make(A: ARG) - : S with type term = A.Term.t - and type fun_ = A.Fun.t diff --git a/src/smt/Mini_cc_intf.ml b/src/smt/Mini_cc_intf.ml deleted file mode 100644 index 52dff3c7..00000000 --- a/src/smt/Mini_cc_intf.ml +++ /dev/null @@ -1,47 +0,0 @@ - -type ('f, 't, 'ts) view = - | Bool of bool - | App of 'f * 'ts - | If of 't * 't * 't - -(* TODO: also HO app, Eq, Distinct cases? - -> then API that just adds boolean terms and does the right thing in case of - Eq/Distinct *) - -type res = - | Sat - | Unsat - -module type ARG = sig - module Fun : sig - type t - val equal : t -> t -> bool - val hash : t -> int - val pp : t Fmt.printer - end - - module Term : sig - type t - val equal : t -> t -> bool - val hash : t -> int - val pp : t Fmt.printer - - (** View the term through the lens of the congruence closure *) - val view : t -> (Fun.t, t, t Sequence.t) view - end -end - -module type S = sig - type term - type fun_ - - type t - - val create : unit -> t - - val merge : t -> term -> term -> unit - val distinct : t -> term list -> unit - - val check : t -> res -end - diff --git a/src/smt/Model.ml b/src/smt/Model.ml index a64d24a1..fa0641a0 100644 --- a/src/smt/Model.ml +++ b/src/smt/Model.ml @@ -58,6 +58,24 @@ let empty : t = { funs=Cst.Map.empty; } +(* FIXME: ues this to allocate a default value for each sort + (* get or make a default value for this type *) + let rec get_ty_default (ty:Ty.t) : Value.t = + match Ty.view ty with + | Ty_prop -> Value.true_ + | Ty_atomic { def = Ty_uninterpreted _;_} -> + (* domain element *) + Ty_tbl.get_or_add ty_tbl ~k:ty + ~f:(fun ty -> Value.mk_elt (ID.makef "ty_%d" @@ Ty.id ty) ty) + | Ty_atomic { def = Ty_def d; args; _} -> + (* ask the theory for a default value *) + Ty_tbl.get_or_add ty_tbl ~k:ty + ~f:(fun _ty -> + let vals = List.map get_ty_default args in + d.default_val vals) + in + *) + let[@inline] mem t m = Term.Map.mem t m.values let[@inline] find t m = Term.Map.get t m.values @@ -102,9 +120,9 @@ let add_funs fs m : t = merge {values=Term.Map.empty; funs=fs} m let pp out {values; funs} = let module FI = Fun_interpretation in - let pp_tv out (t,v) = Fmt.fprintf out "(@[%a@ %a@])" Term.pp t Value.pp v in + let pp_tv out (t,v) = Fmt.fprintf out "(@[%a@ := %a@])" Term.pp t Value.pp v in let pp_fun_entry out (vals,ret) = - Format.fprintf out "(@[%a@ %a@])" (Fmt.Dump.list Value.pp) vals Value.pp ret + Format.fprintf out "(@[%a@ := %a@])" (Fmt.Dump.list Value.pp) vals Value.pp ret in let pp_fun out (c, fi: Cst.t * FI.t) = Format.fprintf out "(@[%a :default %a@ %a@])" @@ -127,6 +145,10 @@ let eval (m:t) (t:Term.t) : Value.t option = | V_bool false -> aux c | v -> Error.errorf "@[Model: wrong value@ for boolean %a@ %a@]" Term.pp a Value.pp v end + | Eq(a,b) -> + let a = aux a in + let b = aux b in + if Value.equal a b then Value.true_ else Value.false_ | App_cst (c, args) -> begin try Term.Map.find t m.values with Not_found -> diff --git a/src/smt/Model.mli b/src/smt/Model.mli index c6ac4c04..8dfe6da2 100644 --- a/src/smt/Model.mli +++ b/src/smt/Model.mli @@ -37,10 +37,6 @@ val empty : t val add : Term.t -> Value.t -> t -> t -val add_fun : Cst.t -> Fun_interpretation.t -> t -> t - -val add_funs : Fun_interpretation.t Cst.Map.t -> t -> t - val mem : Term.t -> t -> bool val find : Term.t -> t -> Value.t option diff --git a/src/smt/Sidekick_smt.ml b/src/smt/Sidekick_smt.ml index efb513cc..e61eee7f 100644 --- a/src/smt/Sidekick_smt.ml +++ b/src/smt/Sidekick_smt.ml @@ -20,7 +20,6 @@ module Solver = Solver module Solver_types = Solver_types (**/**) -module Bag = Bag module Vec = Msat.Vec module Log = Msat.Log (**/**) diff --git a/src/smt/Solver.ml b/src/smt/Solver.ml index 1f7a0e25..d3cc98d6 100644 --- a/src/smt/Solver.ml +++ b/src/smt/Solver.ml @@ -208,11 +208,9 @@ let assume (self:t) (c:Lit.t IArray.t) : unit = let c = IArray.to_array_map (Sat_solver.make_atom sat) c in Sat_solver.add_clause_a sat c Proof_default -let[@inline] assume_eq self t u expl : unit = - Congruence_closure.assert_eq (cc self) t u [expl] - +(* TODO: remove? use a special constant + micro theory instead? *) let[@inline] assume_distinct self l ~neq lit : unit = - Congruence_closure.assert_distinct (cc self) l lit ~neq + CC.assert_distinct (cc self) l lit ~neq let check_model (_s:t) : unit = Log.debug 1 "(smt.solver.check-model)"; diff --git a/src/smt/Solver.mli b/src/smt/Solver.mli index 4162183f..c66cdccd 100644 --- a/src/smt/Solver.mli +++ b/src/smt/Solver.mli @@ -47,7 +47,7 @@ val create : val solver : t -> Sat_solver.t val th_combine : t -> Theory_combine.t val add_theory : t -> Theory.t -> unit -val cc : t -> Congruence_closure.t +val cc : t -> CC.t val stats : t -> Stat.t val tst : t -> Term.state @@ -56,7 +56,6 @@ val mk_atom_t : t -> ?sign:bool -> Term.t -> Atom.t val assume : t -> Lit.t IArray.t -> unit -val assume_eq : t -> Term.t -> Term.t -> Lit.t -> unit val assume_distinct : t -> Term.t list -> neq:Term.t -> Lit.t -> unit val solve : diff --git a/src/smt/Solver_types.ml b/src/smt/Solver_types.ml index 3594755d..f88b317f 100644 --- a/src/smt/Solver_types.ml +++ b/src/smt/Solver_types.ml @@ -3,7 +3,6 @@ module Vec = Msat.Vec module Log = Msat.Log module Fmt = CCFormat -module Node_bits = CCBitField.Make(struct end) (* for objects that are expanded on demand only *) type 'a lazily_expanded = @@ -21,43 +20,9 @@ type term = { and 'a term_view = | Bool of bool | App_cst of cst * 'a IArray.t (* full, first-order application *) + | Eq of 'a * 'a | If of 'a * 'a * 'a -(** A node of the congruence closure. - An equivalence class is represented by its "root" element, - the representative. - - If there is a normal form in the congruence class, then the - representative is a normal form *) -and equiv_class = { - n_term: term; - mutable n_bits: Node_bits.t; (* bitfield for various properties *) - mutable n_parents: equiv_class Bag.t; (* parent terms of this node *) - mutable n_root: equiv_class; (* representative of congruence class (itself if a representative) *) - mutable n_next: equiv_class; (* pointer to next element of congruence class *) - mutable n_size: int; (* size of the class *) - mutable n_expl: explanation_forest_link; (* the rooted forest for explanations *) - mutable n_payload: equiv_class_payload list; (* list of theory payloads *) - mutable n_tags: (equiv_class * explanation) Util.Int_map.t; (* "distinct" tags (i.e. set of `(distinct t1…tn)` terms this belongs to *) -} - -(** Theory-extensible payloads *) -and equiv_class_payload = .. - -and explanation_forest_link = - | E_none - | E_some of { - next: equiv_class; - expl: explanation; - } - -(* atomic explanation in the congruence closure *) -and explanation = - | E_reduction (* by pure reduction, tautologically equal *) - | E_merges of (equiv_class * equiv_class) IArray.t (* caused by these merges *) - | E_lit of lit (* because of this literal *) - | E_lits of lit list (* because of this (true) conjunction *) - (* boolean literal *) and lit = { lit_term: term; @@ -157,23 +122,6 @@ let hash_lit a = let sign = a.lit_sign in Hash.combine3 2 (Hash.bool sign) (term_hash_ a.lit_term) -let cmp_cc_node a b = term_cmp_ a.n_term b.n_term - -let cmp_exp a b = - let toint = function - | E_merges _ -> 0 | E_lit _ -> 1 - | E_reduction -> 2 | E_lits _ -> 3 - in - begin match a, b with - | E_merges l1, E_merges l2 -> - IArray.compare (CCOrd.pair cmp_cc_node cmp_cc_node) l1 l2 - | E_reduction, E_reduction -> 0 - | E_lit l1, E_lit l2 -> cmp_lit l1 l2 - | E_lits l1, E_lits l2 -> CCList.compare cmp_lit l1 l2 - | E_merges _, _ | E_lit _, _ | E_lits _, _ | E_reduction, _ - -> CCInt.compare (toint a)(toint b) - end - let pp_cst out a = ID.pp out a.cst_id let id_of_cst a = a.cst_id @@ -215,6 +163,7 @@ let pp_term_view_gen ~pp_id ~pp_t out = function pp_id out (id_of_cst c) | App_cst (f,l) -> Fmt.fprintf out "(@[<1>%a@ %a@])" pp_id (id_of_cst f) (Util.pp_iarray pp_t) l + | Eq (a,b) -> Fmt.fprintf out "(@[=@ %a@ %a@])" pp_t a pp_t b | If (a, b, c) -> Fmt.fprintf out "(@[if %a@ %a@ %a@])" pp_t a pp_t b pp_t c @@ -233,14 +182,5 @@ let pp_lit out l = if l.lit_sign then pp_term out l.lit_term else Format.fprintf out "(@[@<1>¬@ %a@])" pp_term l.lit_term -let pp_cc_node out n = pp_term out n.n_term - -let pp_explanation out (e:explanation) = match e with - | E_reduction -> Fmt.string out "reduction" - | E_lit lit -> pp_lit out lit - | E_lits l -> CCFormat.Dump.list pp_lit out l - | E_merges l -> - Format.fprintf out "(@[merges@ %a@])" - Fmt.(seq ~sep:(return "@ ") @@ within "[" "]" @@ hvbox @@ - pair ~sep:(return "@ <-> ") pp_cc_node pp_cc_node) - (IArray.to_seq l) +let pp_proof out = function + | Proof_default -> Fmt.fprintf out "" diff --git a/src/smt/Term.ml b/src/smt/Term.ml index a4ebead7..d0223da2 100644 --- a/src/smt/Term.ml +++ b/src/smt/Term.ml @@ -10,6 +10,7 @@ type t = term = { type 'a view = 'a term_view = | Bool of bool | App_cst of cst * 'a IArray.t + | Eq of 'a * 'a | If of 'a * 'a * 'a let[@inline] id t = t.term_id @@ -47,6 +48,7 @@ let[@inline] make st (c:t term_view) : t = let[@inline] true_ st = Lazy.force st.true_ let[@inline] false_ st = Lazy.force st.false_ +let bool st b = if b then true_ st else false_ st let create ?(size=1024) () : state = let rec st ={ @@ -66,9 +68,9 @@ let app_cst st f a = let cell = Term_cell.app_cst f a in make st cell -let const st c = app_cst st c IArray.empty - -let if_ st a b c = make st (Term_cell.if_ a b c) +let[@inline] const st c = app_cst st c IArray.empty +let[@inline] if_ st a b c = make st (Term_cell.if_ a b c) +let[@inline] eq st a b = make st (Term_cell.eq a b) (* "eager" and, evaluating [a] first *) let and_eager st a b = if_ st a b (false_ st) @@ -87,10 +89,12 @@ let[@inline] is_const t = match view t with | _ -> false let cc_view (t:t) = - let module C = Mini_cc in + let module C = Sidekick_cc in match view t with | Bool b -> C.Bool b - | App_cst (f,args) -> C.App (f, IArray.to_seq args) + | App_cst (f,_) when not (Cst.do_cc f) -> C.Opaque t (* skip *) + | App_cst (f,args) -> C.App_fun (f, IArray.to_seq args) + | Eq (a,b) -> C.Eq (a, b) | If (a,b,c) -> C.If (a,b,c) module As_key = struct @@ -109,6 +113,7 @@ let to_seq t yield = match view t with | Bool _ -> () | App_cst (_,a) -> IArray.iter aux a + | Eq (a,b) -> aux a; aux b | If (a,b,c) -> aux a; aux b; aux c in aux t @@ -121,3 +126,14 @@ let as_cst_undef (t:term): (cst * Ty.Fun.t) option = let pp = Solver_types.pp_term + +(* TODO + module T_arg = struct + module Fun = Cst + module Term = struct + include Term + let view = cc_view + end + end + module Mini_cc = Mini_cc.Make(T_arg) + *) diff --git a/src/smt/Term.mli b/src/smt/Term.mli index 38dc9a98..5e47042b 100644 --- a/src/smt/Term.mli +++ b/src/smt/Term.mli @@ -10,6 +10,7 @@ type t = term = { type 'a view = 'a term_view = | Bool of bool | App_cst of cst * 'a IArray.t + | Eq of 'a * 'a | If of 'a * 'a * 'a val id : t -> int @@ -26,8 +27,10 @@ val create : ?size:int -> unit -> state val make : state -> t view -> t val true_ : state -> t val false_ : state -> t +val bool : state -> bool -> t val const : state -> cst -> t val app_cst : state -> cst -> t IArray.t -> t +val eq : state -> t -> t -> t val if_: state -> t -> t -> t -> t val and_eager : state -> t -> t -> t (* evaluate left argument first *) @@ -49,7 +52,7 @@ val is_true : t -> bool val is_false : t -> bool val is_const : t -> bool -val cc_view : t -> (cst,t,t Sequence.t) Mini_cc.view +val cc_view : t -> (cst,t,t Sequence.t) Sidekick_cc.view (* return [Some] iff the term is an undefined constant *) val as_cst_undef : t -> (cst * Ty.Fun.t) option diff --git a/src/smt/Term_cell.ml b/src/smt/Term_cell.ml index 90fd9575..b873e3c5 100644 --- a/src/smt/Term_cell.ml +++ b/src/smt/Term_cell.ml @@ -6,6 +6,7 @@ open Solver_types type 'a view = 'a Solver_types.term_view = | Bool of bool | App_cst of cst * 'a IArray.t + | Eq of 'a * 'a | If of 'a * 'a * 'a type t = term view @@ -25,6 +26,7 @@ module Make_eq(A : ARG) = struct | Bool b -> Hash.bool b | App_cst (f,l) -> Hash.combine3 4 (Cst.hash f) (Hash.iarray sub_hash l) + | Eq (a,b) -> Hash.combine3 12 (sub_hash a) (sub_hash b) | If (a,b,c) -> Hash.combine4 7 (sub_hash a) (sub_hash b) (sub_hash c) (* equality that relies on physical equality of subterms *) @@ -32,9 +34,10 @@ module Make_eq(A : ARG) = struct | Bool b1, Bool b2 -> CCBool.equal b1 b2 | App_cst (f1, a1), App_cst (f2, a2) -> Cst.equal f1 f2 && IArray.equal sub_eq a1 a2 + | Eq(a1,b1), Eq(a2,b2) -> sub_eq a1 a2 && sub_eq b1 b2 | If (a1,b1,c1), If (a2,b2,c2) -> sub_eq a1 a2 && sub_eq b1 b2 && sub_eq c1 c2 - | Bool _, _ | App_cst _, _ | If _, _ + | Bool _, _ | App_cst _, _ | If _, _ | Eq _, _ -> false let pp = Solver_types.pp_term_view_gen ~pp_id:ID.pp_name ~pp_t:A.pp @@ -53,17 +56,25 @@ let false_ = Bool false let is_value = function | Bool _ -> true | App_cst ({cst_view=Cst_def r;_}, _) -> r.is_value - | If _ | App_cst _ -> false + | If _ | App_cst _ | Eq _ -> false let app_cst f a = App_cst (f, a) let const c = App_cst (c, IArray.empty) +let eq a b = + if term_equal_ a b then ( + Bool true + ) else ( + (* canonize *) + let a,b = if a.term_id > b.term_id then b, a else a, b in + Eq (a,b) + ) let if_ a b c = assert (Ty.equal b.term_ty c.term_ty); If (a,b,c) let ty (t:t): Ty.t = match t with - | Bool _ -> Ty.prop + | Bool _ | Eq _ -> Ty.prop | App_cst (f, args) -> begin match Cst.view f with | Cst_undef fty -> diff --git a/src/smt/Term_cell.mli b/src/smt/Term_cell.mli index 35f31f99..cce393b4 100644 --- a/src/smt/Term_cell.mli +++ b/src/smt/Term_cell.mli @@ -4,6 +4,7 @@ open Solver_types type 'a view = 'a Solver_types.term_view = | Bool of bool | App_cst of cst * 'a IArray.t + | Eq of 'a * 'a | If of 'a * 'a * 'a type t = term view @@ -15,6 +16,7 @@ val true_ : t val false_ : t val const : cst -> t val app_cst : cst -> term IArray.t -> t +val eq : term -> term -> t val if_ : term -> term -> term -> t val is_value : t -> bool diff --git a/src/smt/Theory.ml b/src/smt/Theory.ml index f165bc7e..d877d4d9 100644 --- a/src/smt/Theory.ml +++ b/src/smt/Theory.ml @@ -18,6 +18,9 @@ end Its negation will become a conflict clause *) type conflict = Lit.t list +module CC_eq_class = CC.N +module CC_expl = CC.Expl + (** Actions available to a theory during its lifetime *) module type ACTIONS = sig val raise_conflict: conflict -> 'a @@ -41,12 +44,15 @@ module type ACTIONS = sig (** Add toplevel clause to the SAT solver. This clause will not be backtracked. *) - val find: Term.t -> Eq_class.t - (** Find representative of this term *) + val cc_add_term: Term.t -> CC_eq_class.t + (** add/get term to the congruence closure *) - val all_classes: Eq_class.t Sequence.t + val cc_find: CC_eq_class.t -> CC_eq_class.t + (** Find representative of this in the congruence closure *) + + val cc_all_classes: CC_eq_class.t Sequence.t (** All current equivalence classes - (caution: linear in the number of terms existing in the solver) *) + (caution: linear in the number of terms existing in the congruence closure) *) end type actions = (module ACTIONS) @@ -60,7 +66,7 @@ module type S = sig val create : Term.state -> t (** Instantiate the theory's state *) - val on_merge: t -> actions -> Eq_class.t -> Eq_class.t -> Explanation.t -> unit + val on_merge: t -> actions -> CC_eq_class.t -> CC_eq_class.t -> CC_expl.t -> unit (** Called when two classes are merged *) val partial_check : t -> actions -> Lit.t Sequence.t -> unit @@ -70,7 +76,7 @@ module type S = sig (** Final check, must be complete (i.e. must raise a conflict if the set of literals is not satisfiable) *) - val mk_model : t -> Lit.t Sequence.t -> Model.t + val mk_model : t -> Lit.t Sequence.t -> Model.t -> Model.t (** Make a model for this theory's terms *) val push_level : t -> unit @@ -91,7 +97,7 @@ let make ?(check_invariants=fun _ -> ()) ?(on_merge=fun _ _ _ _ _ -> ()) ?(partial_check=fun _ _ _ -> ()) - ?(mk_model=fun _ _ -> Model.empty) + ?(mk_model=fun _ _ m -> m) ?(push_level=fun _ -> ()) ?(pop_levels=fun _ _ -> ()) ~name diff --git a/src/smt/Theory_combine.ml b/src/smt/Theory_combine.ml index faff92db..a89db7ef 100644 --- a/src/smt/Theory_combine.ml +++ b/src/smt/Theory_combine.ml @@ -3,7 +3,6 @@ (** Combine the congruence closure with a number of plugins *) -module C_clos = Congruence_closure open Solver_types module Proof = struct @@ -12,6 +11,8 @@ module Proof = struct end module Formula = Lit +module Eq_class = CC.N +module Expl = CC.Expl type formula = Lit.t type proof = Proof.t @@ -24,11 +25,11 @@ type theory_state = type t = { tst: Term.state; (** state for managing terms *) - cc: C_clos.t lazy_t; + cc: CC.t lazy_t; (** congruence closure *) mutable theories : theory_state list; (** Set of theories *) - new_merges: (Eq_class.t * Eq_class.t * explanation) Vec.t; + new_merges: (Eq_class.t * Eq_class.t * Expl.t) Vec.t; } let[@inline] cc (t:t) = Lazy.force t.cc @@ -41,24 +42,28 @@ let[@inline] theories (self:t) : theory_state Sequence.t = (* handle a literal assumed by the SAT solver *) let assert_lits_ ~final (self:t) acts (lits:Lit.t Sequence.t) : unit = Msat.Log.debugf 2 - (fun k->k "(@[<1>@{th_combine.assume_lits@}@ @[%a@]@])" (Fmt.seq Lit.pp) lits); + (fun k->k "(@[@{th_combine.assume_lits@}@ %a@])" + (Util.pp_seq ~sep:";" Lit.pp) lits); (* transmit to CC *) Vec.clear self.new_merges; let cc = cc self in - C_clos.assert_lits cc lits; + if not final then ( + CC.assert_lits cc lits; + ); (* transmit to theories. *) - C_clos.check cc acts; + CC.check cc acts; let module A = struct let[@inline] raise_conflict c : 'a = acts.Msat.acts_raise_conflict c Proof_default - let[@inline] propagate_eq t u expl : unit = C_clos.assert_eq cc t u expl - let propagate_distinct ts ~neq expl = C_clos.assert_distinct cc ts ~neq expl + let[@inline] propagate_eq t u expl : unit = CC.assert_eq cc t u expl + let propagate_distinct ts ~neq expl = CC.assert_distinct cc ts ~neq expl let[@inline] propagate p cs : unit = acts.Msat.acts_propagate p (Msat.Consequence (cs, Proof_default)) let[@inline] add_local_axiom lits : unit = acts.Msat.acts_add_clause ~keep:false lits Proof_default let[@inline] add_persistent_axiom lits : unit = acts.Msat.acts_add_clause ~keep:true lits Proof_default - let[@inline] find t = C_clos.find_t cc t - let all_classes = C_clos.all_classes cc + let[@inline] cc_add_term t = CC.add_term cc t + let[@inline] cc_find t = CC.find cc t + let cc_all_classes = CC.all_classes cc end in let acts = (module A : Theory.ACTIONS) in theories self @@ -83,10 +88,10 @@ let check_ ~final (self:t) (acts:_ Msat.acts) = assert_lits_ ~final self acts iter let add_formula (self:t) (lit:Lit.t) = - let t = Lit.view lit in + let t = Lit.term lit in let lazy cc = self.cc in - let n = C_clos.add cc t in - Eq_class.set_field Eq_class.field_is_literal true n; + let n = CC.add_term cc t in + CC.set_as_lit cc n (Lit.abs lit); () (* propagation from the bool solver *) @@ -98,21 +103,21 @@ let[@inline] final_check (self:t) (acts:_ Msat.acts) : unit = check_ ~final:true self acts let push_level (self:t) : unit = - C_clos.push_level (cc self); + CC.push_level (cc self); theories self (fun (Th_state ((module Th), st)) -> Th.push_level st) let pop_levels (self:t) n : unit = - C_clos.pop_levels (cc self) n; + CC.pop_levels (cc self) n; theories self (fun (Th_state ((module Th), st)) -> Th.pop_levels st n) let mk_model (self:t) lits : Model.t = let m = Sequence.fold - (fun m (Th_state ((module Th),st)) -> Model.merge m (Th.mk_model st lits)) + (fun m (Th_state ((module Th),st)) -> Th.mk_model st lits m) Model.empty (theories self) in (* now complete model using CC *) - Congruence_closure.mk_model (cc self) m + CC.mk_model (cc self) m (** {2 Interface to Congruence Closure} *) @@ -131,16 +136,16 @@ let create () : t = cc = lazy ( (* lazily tie the knot *) let on_merge = on_merge_from_cc self in - C_clos.create ~on_merge ~size:`Big self.tst; + CC.create ~on_merge ~size:`Big self.tst; ); theories = []; } in - ignore (Lazy.force @@ self.cc : C_clos.t); + ignore (Lazy.force @@ self.cc : CC.t); self let check_invariants (self:t) = if Util._CHECK_INVARIANTS then ( - Congruence_closure.check_invariants (cc self); + CC.check_invariants (cc self); ) let add_theory (self:t) (th:Theory.t) : unit = diff --git a/src/smt/Theory_combine.mli b/src/smt/Theory_combine.mli index f8490ebc..8d723954 100644 --- a/src/smt/Theory_combine.mli +++ b/src/smt/Theory_combine.mli @@ -13,7 +13,7 @@ include Msat.Solver_intf.PLUGIN_CDCL_T val create : unit -> t -val cc : t -> Congruence_closure.t +val cc : t -> CC.t val tst : t -> Term.state type theory_state = diff --git a/src/smt/Value.ml b/src/smt/Value.ml index 3f2d9181..9057db36 100644 --- a/src/smt/Value.ml +++ b/src/smt/Value.ml @@ -19,3 +19,5 @@ let equal = eq_value let hash = hash_value let pp = pp_value +let fresh (t:term) : t = + mk_elt (ID.makef "v_%d" t.term_id) t.term_ty diff --git a/src/smt/Value.mli b/src/smt/Value.mli index c44fe86b..5bfadde6 100644 --- a/src/smt/Value.mli +++ b/src/smt/Value.mli @@ -15,6 +15,8 @@ val is_bool : t -> bool val is_true : t -> bool val is_false : t -> bool +val fresh : Term.t -> t + include Intf.EQ with type t := t include Intf.HASH with type t := t include Intf.PRINT with type t := t diff --git a/src/smt/dune b/src/smt/dune index d3ded68d..2f7f98ec 100644 --- a/src/smt/dune +++ b/src/smt/dune @@ -2,7 +2,8 @@ (library (name Sidekick_smt) (public_name sidekick.smt) - (libraries containers containers.data sequence sidekick.util msat zarith) + (libraries containers containers.data sequence + sidekick.util sidekick.cc msat zarith) (flags :standard -warn-error -a+8 -color always -safe-string -short-paths -open Sidekick_util) (ocamlopt_flags :standard -O3 -color always diff --git a/src/th-bool/Sidekick_th_bool.ml b/src/th-bool/Sidekick_th_bool.ml index 09dcec2b..07bc90ad 100644 --- a/src/th-bool/Sidekick_th_bool.ml +++ b/src/th-bool/Sidekick_th_bool.ml @@ -15,7 +15,6 @@ let id_not = ID.make "not" let id_and = ID.make "and" let id_or = ID.make "or" let id_imply = ID.make "=>" -let id_eq = ID.make "=" let id_distinct = ID.make "distinct" type 'a view = @@ -32,8 +31,6 @@ exception Not_a_th_term let view_id cst_id args = if ID.equal cst_id id_not && IArray.length args=1 then ( B_not (IArray.get args 0) - ) else if ID.equal cst_id id_eq && IArray.length args=2 then ( - B_eq (IArray.get args 0, IArray.get args 1) ) else if ID.equal cst_id id_and then ( B_and args ) else if ID.equal cst_id id_or then ( @@ -45,13 +42,14 @@ let view_id cst_id args = ) else if ID.equal cst_id id_distinct then ( B_distinct args ) else ( - raise Not_a_th_term + raise_notrace Not_a_th_term ) let view (t:Term.t) : term view = match Term.view t with + | Eq (a,b) -> B_eq (a,b) | App_cst ({cst_id; _}, args) -> - (try view_id cst_id args with Not_a_th_term -> B_atom t) + begin try view_id cst_id args with Not_a_th_term -> B_atom t end | _ -> B_atom t @@ -59,9 +57,6 @@ module C = struct let get_ty _ _ = Ty.prop - (* no congruence closure, except for `=` *) - let relevant id _ _ = ID.equal id_eq id - let abs ~self _a = match Term.view self with | App_cst ({cst_id;_}, args) when ID.equal cst_id id_not && IArray.length args=1 -> @@ -89,6 +84,9 @@ module C = struct | B_not _ | B_and _ | B_or _ | B_imply _ -> Error.errorf "non boolean value in boolean connective" + (* no congruence closure for boolean terms *) + let relevant _id _ _ = false + let mk_cst ?(do_cc=false) id : Cst.t = {cst_id=id; cst_view=Cst_def { @@ -98,7 +96,6 @@ module C = struct let and_ = mk_cst id_and let or_ = mk_cst id_or let imply = mk_cst id_imply - let eq = mk_cst ~do_cc:true id_eq let distinct = mk_cst id_distinct end @@ -134,13 +131,7 @@ let or_l st l = let and_ st a b = and_l st [a;b] let or_ st a b = or_l st [a;b] -let eq st a b = - if Term.equal a b then ( - Term.true_ st - ) else ( - let a,b = if Term.id a > Term.id b then b, a else a, b in - Term.app_cst st C.eq (IArray.doubleton a b) - ) +let eq = Term.eq let not_ st a = match as_id id_not a, Term.view a with @@ -164,7 +155,7 @@ let distinct st = function module Lit = struct include Lit let eq tst a b = Lit.atom ~sign:true (eq tst a b) - let neq tst a b = Lit.atom ~sign:false (neq tst a b) + let neq tst a b = neg @@ eq tst a b end type t = { @@ -175,14 +166,8 @@ let tseitin (_self:t) (acts:Theory.actions) (lit:Lit.t) (lit_t:term) (v:term vie let (module A) = acts in Log.debugf 5 (fun k->k "(@[th_bool.tseitin@ %a@])" Lit.pp lit); match v with - | B_atom _ -> () | B_not _ -> assert false (* normalized *) - | B_eq (t,u) -> - if Lit.sign lit then ( - A.propagate_eq t u [lit] - ) else ( - A.propagate_distinct [t;u] ~neq:lit_t lit - ) + | B_atom _ | B_eq _ -> () (* CC will manage *) | B_distinct l -> let l = IArray.to_list l in if Lit.sign lit then ( @@ -197,7 +182,7 @@ let tseitin (_self:t) (acts:Theory.actions) (lit:Lit.t) (lit_t:term) (v:term vie IArray.iter (fun sub -> let sublit = Lit.atom sub in - A.propagate sublit [lit]) + A.add_local_axiom [Lit.neg lit; sublit]) subs ) else ( (* propagate [¬lit => ∨_i ¬ subs_i] *) @@ -216,7 +201,7 @@ let tseitin (_self:t) (acts:Theory.actions) (lit:Lit.t) (lit_t:term) (v:term vie IArray.iter (fun sub -> let sublit = Lit.atom ~sign:false sub in - A.propagate sublit [lit]) + A.add_local_axiom [Lit.neg lit; sublit]) subs ) | B_imply (guard,concl) -> @@ -239,7 +224,7 @@ let tseitin (_self:t) (acts:Theory.actions) (lit:Lit.t) (lit_t:term) (v:term vie let partial_check (self:t) acts (lits:Lit.t Sequence.t) = lits (fun lit -> - let t = Lit.view lit in + let t = Lit.term lit in match view t with | B_atom _ -> () | v -> tseitin self acts lit t v) diff --git a/src/smt/Bag.ml b/src/util/Bag.ml similarity index 100% rename from src/smt/Bag.ml rename to src/util/Bag.ml diff --git a/src/smt/Bag.mli b/src/util/Bag.mli similarity index 100% rename from src/smt/Bag.mli rename to src/util/Bag.mli diff --git a/src/util/Sidekick_util.ml b/src/util/Sidekick_util.ml index 809afae7..76ae82b9 100644 --- a/src/util/Sidekick_util.ml +++ b/src/util/Sidekick_util.ml @@ -9,3 +9,4 @@ module Backtrack_stack = Backtrack_stack module Error = Error module IArray = IArray module Intf = Intf +module Bag = Bag