From 342dba453342c7cdac865f27795bd45d0d24c4b2 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Tue, 26 Feb 2019 22:46:40 -0600 Subject: [PATCH] wip: new micro-theories in CC --- src/cc/Congruence_closure.ml | 351 ++++++++++++++++++------------ src/cc/Congruence_closure.mli | 5 +- src/cc/Congruence_closure_intf.ml | 125 ++++++----- src/cc/Sidekick_cc.ml | 1 + src/smt/Model.ml | 5 +- src/smt/Sidekick_smt.ml | 3 + src/smt/Solver.ml | 3 +- src/smt/Solver.mli | 2 + src/smt/Theory.ml | 19 +- src/smt/Theory_combine.ml | 6 +- src/smtlib/Process.ml | 3 +- src/smtlib/dune | 8 +- src/th-bool/Bool_intf.ml | 2 - src/th-bool/Bool_term.ml | 25 +-- src/th-bool/Bool_term.mli | 2 - src/th-bool/Sidekick_th_bool.ml | 2 - src/th-bool/Th_dyn_tseitin.ml | 28 +-- src/th-bool/Th_dyn_tseitin.mli | 6 - src/th-bool/dune | 5 +- 19 files changed, 307 insertions(+), 294 deletions(-) diff --git a/src/cc/Congruence_closure.ml b/src/cc/Congruence_closure.ml index 5af965fd..40020bff 100644 --- a/src/cc/Congruence_closure.ml +++ b/src/cc/Congruence_closure.ml @@ -3,41 +3,62 @@ open CC_types module type ARG = Congruence_closure_intf.ARG module type S = Congruence_closure_intf.S -module type THEORY_DATA = Congruence_closure_intf.THEORY_DATA + module type THEORY_KEY = Congruence_closure_intf.THEORY_KEY -type ('t, 'a) theory_data = ('t,'a) Congruence_closure_intf.theory_data - -module type KEY_IMPL = sig - include THEORY_DATA - exception Store of t - val id : int -end - (** Custom keys for theory data. This imitates the classic tricks for heterogeneous maps https://blog.janestreet.com/a-universal-type/ - *) + + It needs to form a commutative monoid where values are persistent so + they can be restored during backtracking. +*) module Key = struct - type ('term, 'a) t = (module KEY_IMPL with type term = 'term and type t = 'a) + module type KEY_IMPL = sig + type term + type lit + type t + val id : int + val name : string + val pp : t Fmt.printer + val equal : t -> t -> bool + val merge : t -> t -> t + exception Store of t + end + + type ('term,'lit,'a) t = + (module KEY_IMPL with type term = 'term and type lit = 'lit and type t = 'a) let n_ = ref 0 - let create (type term)(type d) (th:(term,d) theory_data) : (term,d) t = - let (module TH) = th in + let create (type term)(type lit)(type d) + ?(pp=fun out _ -> Fmt.string out "") + ~name ~eq ~merge () : (term,lit,d) t = let module K = struct - include TH - exception Store of d + type nonrec term = term + type nonrec lit = lit + type t = d let id = !n_ + let name = name + let pp = pp + let merge = merge + let equal = eq + exception Store of d end in incr n_; (module K) - let id (module K : KEY_IMPL) = K.id + let[@inline] id + : type term lit a. (term,lit,a) t -> int + = fun (module K) -> K.id - let equal - : type a b term. (term,a) t -> (term,b) t -> bool + let[@inline] equal + : type term lit a b. (term,lit,a) t -> (term,lit,b) t -> bool = fun (module K1) (module K2) -> K1.id = K2.id + + let pp + : type term lit a. (term,lit,a) t Fmt.printer + = fun out (module K) -> Fmt.string out K.name end module Bits = CCBitField.Make() @@ -67,12 +88,12 @@ module Make(A: ARG) = struct module T = A.Term module Fun = A.Fun module Key = Key - + module IM = Map.Make(CCInt) (** Map for theory data associated with representatives *) module K_map = struct - type pair = Pair : (term, 'a) Key.t * exn -> pair - module IM = Map.Make(CCInt) + type 'a key = (term,lit,'a) Key.t + type pair = Pair : 'a key * exn -> pair type t = pair IM.t @@ -80,20 +101,18 @@ module Make(A: ARG) = struct let[@inline] mem k t = IM.mem (Key.id k) t - let is_empty = IM.is_empty - - let find (type a) (k : (term,a) Key.t) (self:t) : a option = + let find (type a) (k : a key) (self:t) : a option = let (module K) = k in match IM.find K.id self with | Pair (_, K.Store v) -> Some v | _ -> None | exception Not_found -> None - let add (type a) (k : (term,a) Key.t) (v:a) (self:t) : t = + let add (type a) (k : a key) (v:a) (self:t) : t = let (module K) = k in IM.add K.id (Pair (k, K.Store v)) self - let remove (type a) (k: (term,a) Key.t) self : t = + let remove (type a) (k: a key) self : t = let (module K) = k in IM.remove K.id self @@ -102,22 +121,23 @@ module Make(A: ARG) = struct (fun p1 p2 -> let Pair ((module K1), v1) = p1 in let Pair ((module K2), v2) = p2 in - K1.id = K2.id && + assert (K1.id = K2.id); match v1, v2 with K1.Store v1, K1.Store v2 -> K1.equal v1 v2 | _ -> false) m1 m2 - let merge (m1:t) (m2:t) : t = + let merge ~f_both (m1:t) (m2:t) : t = IM.merge (fun _ p1 p2 -> match p1, p2 with | None, None -> None | Some v, None | None, Some v -> Some v - | Some (Pair ((module K1) as key1, v1)), Some (Pair (_, v2)) -> - match v1, v2 with + | Some (Pair ((module K1) as key1, pair1)), Some (Pair (_, pair2)) -> + match pair1, pair2 with | K1.Store v1, K1.Store v2 -> - (* merge content *) - Some (Pair (key1, K1.Store (K1.merge v1 v2))) + f_both K1.id pair1 pair2; (* callback for checking compat *) + let v12 = K1.merge v1 v2 in (* merge content *) + Some (Pair (key1, K1.Store v12)) | _ -> assert false ) m1 m2 @@ -137,9 +157,6 @@ module Make(A: ARG) = struct mutable n_as_lit: lit option; (* TODO: put into payload? and only in root? *) mutable n_expl: explanation_forest_link; (* the rooted forest for explanations *) mutable n_th_data: K_map.t; (* theory data *) - (* TODO: make a micro theory and move this inside *) - mutable n_tags: (node * explanation) Util.Int_map.t; - (* "distinct" tags (i.e. set of `(distinct t1…tn)` terms this belongs to *) } and signature = (fun_, node, node list) view @@ -151,15 +168,13 @@ module Make(A: ARG) = struct expl: explanation; } - (* TODO: make this recursive (the list case) *) (* atomic explanation in the congruence closure *) and explanation = | E_reduction (* by pure reduction, tautologically equal *) - | E_merge of node * node - | E_merges of (node * node) list (* caused by these merges *) - | E_congruence of node * node (* caused by normal congruence *) | E_lit of lit (* because of this literal *) - | E_lits of lit list (* because of this (true) conjunction *) + | E_merge of node * node + | E_list of explanation list + | E_congruence of node * node (* caused by normal congruence *) type repr = node type conflict = lit list @@ -185,7 +200,6 @@ module Make(A: ARG) = struct n_next=n; n_size=1; n_th_data=K_map.empty; - n_tags=Util.Int_map.empty; } in n @@ -217,30 +231,24 @@ module Make(A: ARG) = struct module Expl = struct type t = explanation - let pp out (e:explanation) = match e with + let rec pp out (e:explanation) = match e with | E_reduction -> Fmt.string out "reduction" | E_lit lit -> A.Lit.pp out lit | E_congruence (n1,n2) -> Fmt.fprintf out "(@[congruence@ %a@ %a@])" N.pp n1 N.pp n2 - | E_lits l -> CCFormat.Dump.list A.Lit.pp out l | E_merge (a,b) -> Fmt.fprintf out "(@[merge@ %a@ %a@])" N.pp a N.pp b - | E_merges l -> - Format.fprintf out "(@[merges@ %a@])" - Fmt.(seq ~sep:(return "@ ") @@ within "[" "]" @@ hvbox @@ - pair ~sep:(return " ~@ ") N.pp N.pp) - (Sequence.of_list l) + | E_list l -> + Format.fprintf out "(@[and@ %a@])" + Fmt.(list ~sep:(return "@ ") @@ within "[" "]" @@ hvbox @@ pp) l let mk_reduction : t = E_reduction let[@inline] mk_congruence n1 n2 : t = E_congruence (n1,n2) let[@inline] mk_merge a b : t = E_merge (a,b) - let[@inline] mk_merges = function - | [] -> mk_reduction - | [(a,b)] -> mk_merge a b - | l -> E_merges l let[@inline] mk_lit l : t = E_lit l - let[@inline] mk_lits = function + let mk_list l = + match l with | [] -> mk_reduction - | [x] -> mk_lit x - | l -> E_lits l + | [x] -> x + | l -> E_list l end (** A signature is a shallow term shape where immediate subterms @@ -290,7 +298,15 @@ module Make(A: ARG) = struct type combine_task = | CT_merge of node * node * explanation - | CT_distinct of node list * int * explanation + + module type THEORY = sig + type cc + type data + val key_id : int + val key : (term,lit,data) Key.t + val on_merge : cc -> N.t -> data -> N.t -> data -> Expl.t -> unit + val on_new_term: cc -> term -> data option + end type t = { tst: term_state; @@ -307,8 +323,7 @@ module Make(A: ARG) = struct pending: node Vec.t; combine: combine_task Vec.t; undo: (unit -> unit) Backtrack_stack.t; - mutable on_merge: (t -> repr -> repr -> explanation -> unit) list; - mutable on_new_term: (t -> repr -> term -> unit) list; + mutable theories: theory IM.t; mutable ps_lits: lit list; (* TODO: thread it around instead? *) (* proof state *) ps_queue: (node*node) Vec.t; @@ -322,6 +337,10 @@ module Make(A: ARG) = struct several times. See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *) + and theory = (module THEORY with type cc = t) + + type cc = t + let[@inline] size_ (r:repr) = r.n_size let[@inline] true_ cc = Lazy.force cc.true_ let[@inline] false_ cc = Lazy.force cc.false_ @@ -333,8 +352,10 @@ module Make(A: ARG) = struct (* check if [t] is in the congruence closure. Invariant: [in_cc t ∧ do_cc t => forall u subterm t, in_cc u] *) let[@inline] mem (cc:t) (t:term): bool = T_tbl.mem cc.tbl t + (* FIXME let on_merge cc f = cc.on_merge <- f :: cc.on_merge let on_new_term cc f = cc.on_new_term <- f :: cc.on_new_term + *) (* find representative, recursively *) let[@unroll 2] rec find_rec (n:node) : repr = @@ -378,28 +399,6 @@ module Make(A: ARG) = struct (Util.pp_seq ~sep:" " pp_n) (T_tbl.values cc.tbl) (Util.pp_seq ~sep:" " pp_sig_e) (Sig_tbl.to_seq cc.signatures_tbl) - let th_data_get (_:t) (n:node) (key: _ Key.t) : _ option = - let n = find_ n in - K_map.find key n.n_th_data - - (* update data for [n] *) - let th_data_add (type a) (self:t) (n:node) (key: (term,a) Key.t) (v:a) : unit = - let n = find_ n in - let map = n.n_th_data in - let old_v = K_map.find key map in - let v', is_diff = match old_v with - | None -> v, true - | Some old_v -> - let (module K) = key in - let v' = K.merge old_v v in - v', K.equal v v' - in - if is_diff then ( - on_backtrack self (fun () -> n.n_th_data <- map); - ); - n.n_th_data <- K_map.add key v' map; - () - (* compute up-to-date signature *) let update_sig (s:signature) : Signature.t = CC_types.map_view s @@ -506,34 +505,30 @@ module Make(A: ARG) = struct (* TODO: turn this into a fold? *) (* decompose explanation [e] of why [n1 = n2] *) - let decompose_explain cc (e:explanation) : unit = + let rec decompose_explain cc (e:explanation) : unit = Log.debugf 5 (fun k->k "(@[cc.decompose_expl@ %a@])" Expl.pp e); - begin match e with - | E_reduction -> () - | E_congruence (n1, n2) -> - begin match n1.n_sig0, n2.n_sig0 with - | Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) -> - assert (Fun.equal f1 f2); - assert (List.length a1 = List.length a2); - List.iter2 (ps_add_obligation cc) a1 a2; - | Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) -> - assert (List.length a1 = List.length a2); - ps_add_obligation cc f1 f2; - List.iter2 (ps_add_obligation cc) a1 a2; - | Some (If (a1,b1,c1)), Some (If (a2,b2,c2)) -> - ps_add_obligation cc a1 a2; - ps_add_obligation cc b1 b2; - ps_add_obligation cc c1 c2; - | _ -> - assert false - end - | E_lit lit -> ps_add_lit cc lit - | E_lits l -> List.iter (ps_add_lit cc) l - | E_merge (a,b) -> ps_add_obligation cc a b - | E_merges l -> - (* need to explain each merge in [l] *) - List.iter (fun (t,u) -> ps_add_obligation cc t u) l - end + match e with + | E_reduction -> () + | E_congruence (n1, n2) -> + begin match n1.n_sig0, n2.n_sig0 with + | Some (App_fun (f1, a1)), Some (App_fun (f2, a2)) -> + assert (Fun.equal f1 f2); + assert (List.length a1 = List.length a2); + List.iter2 (ps_add_obligation cc) a1 a2; + | Some (App_ho (f1, a1)), Some (App_ho (f2, a2)) -> + assert (List.length a1 = List.length a2); + ps_add_obligation cc f1 f2; + List.iter2 (ps_add_obligation cc) a1 a2; + | Some (If (a1,b1,c1)), Some (If (a2,b2,c2)) -> + ps_add_obligation cc a1 a2; + ps_add_obligation cc b1 b2; + ps_add_obligation cc c1 c2; + | _ -> + assert false + end + | E_lit lit -> ps_add_lit cc lit + | E_merge (a,b) -> ps_add_obligation cc a b + | E_list l -> List.iter (decompose_explain cc) l (* explain why [a = parent_a], where [a -> ... -> parent_a] in the proof forest *) @@ -575,6 +570,7 @@ module Make(A: ARG) = struct decompose_explain cc e; explain_loop cc + (* FIXME remove (* add [tag] to [n], indicating that [n] is distinct from all the other nodes tagged with [tag] precond: [n] is a representative *) @@ -585,6 +581,7 @@ module Make(A: ARG) = struct (fun () -> n.n_tags <- Util.Int_map.remove tag n.n_tags); n.n_tags <- Util.Int_map.add tag (n,expl) n.n_tags; ) + *) (* add a term *) let [@inline] rec add_term_rec_ cc t : node = @@ -611,7 +608,16 @@ module Make(A: ARG) = struct (* [n] might be merged with other equiv classes *) push_pending cc n; ); - List.iter (fun f -> f cc n t) cc.on_new_term; + (* initial theory data *) + let th_map = + IM.fold + (fun _ (module Th: THEORY with type cc=cc) th_map -> + match Th.on_new_term cc t with + | None -> th_map + | Some v -> K_map.add Th.key v th_map) + cc.theories K_map.empty + in + n.n_th_data <- th_map; n (* compute the initial signature of the given node *) @@ -701,7 +707,6 @@ module Make(A: ARG) = struct (* TODO: remove, once we have moved distinct to a theory *) and[@inline] task_combine_ cc acts = function | CT_merge (a,b,e_ab) -> task_merge_ cc acts a b e_ab - | CT_distinct (l,tag,e) -> task_distinct_ cc acts l tag e (* main CC algo: merge equivalence classes in [st.combine]. @raise Exn_unsat if merge fails *) @@ -731,26 +736,6 @@ module Make(A: ARG) = struct else if size_ ra > size_ rb then rb, ra else ra, rb in - (* TODO: instead call micro theories, including "distinct" *) - (* update set of tags the new node cannot be equal to *) - let new_tags = - Util.Int_map.union - (fun _i (n1,e1) (n2,e2) -> - (* both maps contain same tag [_i]. conflict clause: - [e1 & e2 & e_ab] impossible *) - Log.debugf 5 - (fun k->k "(@[cc.merge.distinct_conflict@ :tag %d@ \ - @[:r1 %a@ :e1 %a@]@ @[:r2 %a@ :e2 %a@]@ :e_ab %a@])" - _i N.pp n1 Expl.pp e1 - N.pp n2 Expl.pp e2 Expl.pp e_ab); - let lits = explain_unfold cc e1 in - let lits = explain_unfold ~init:lits cc e2 in - let lits = explain_unfold ~init:lits cc e_ab in - let lits = explain_eq_n ~init:lits cc a n1 in - let lits = explain_eq_n ~init:lits cc b n2 in - raise_conflict cc acts lits) - ra.n_tags rb.n_tags - in (* when merging terms with [true] or [false], possibly propagate them to SAT *) let merge_bool r1 t1 r2 t2 = if N.equal r1 (true_ cc) then ( @@ -763,6 +748,35 @@ module Make(A: ARG) = struct merge_bool rb b ra a; (* perform [union r_from r_into] *) Log.debugf 15 (fun k->k "(@[cc.merge@ :from %a@ :into %a@])" N.pp r_from N.pp r_into); + (* call micro theories *) + begin + let th_into = r_into.n_th_data in + let th_from = r_from.n_th_data in + (* merge the two maps; if a key occurs in both, looks for theories with + this particular key *) + let th = + K_map.merge th_into th_from + ~f_both:(fun id pair_into pair_from -> + match IM.find id cc.theories with + | (module Th : THEORY with type cc=t) -> + (* casting magic *) + let (module K) = Th.key in + begin match pair_into, pair_from with + | K.Store v_into, K.Store v_from -> + Log.debugf 15 + (fun k->k "(@[cc.merge.th-on-merge@ :th %s@])" K.name); + (* FIXME: explanation is a=ra, e_ab, b=rb *) + Th.on_merge cc r_into v_into r_from v_from e_ab + | _ -> assert false + end + | exception Not_found -> ()) + in + (* restore old data, if it changed *) + if not @@ K_map.equal th th_into then ( + on_backtrack cc (fun () -> r_into.n_th_data <- th_into); + ); + r_into.n_th_data <- th; + end; begin (* parents might have a different signature, check for collisions *) N.iter_parents r_from @@ -773,7 +787,6 @@ module Make(A: ARG) = struct assert (u.n_root == r_from); u.n_root <- r_into); (* now merge the classes *) - let r_into_old_tags = r_into.n_tags in let r_into_old_next = r_into.n_next in let r_from_old_next = r_from.n_next in let r_into_old_parents = r_into.n_parents in @@ -786,11 +799,9 @@ module Make(A: ARG) = struct N.pp r_from N.pp r_into); r_into.n_next <- r_into_old_next; r_from.n_next <- r_from_old_next; - r_into.n_tags <- r_into_old_tags; r_into.n_parents <- r_into_old_parents; N.iter_class_ r_from (fun u -> u.n_root <- r_from); ); - r_into.n_tags <- new_tags; (* swap [into.next] and [from.next], merging the classes *) r_into.n_next <- r_from_old_next; r_from.n_next <- r_into_old_next; @@ -810,10 +821,9 @@ module Make(A: ARG) = struct | _ -> assert false); a.n_expl <- FL_some {next=b; expl=e_ab}; end; - (* notify listeners of the merge *) - List.iter (fun f -> f cc r_into r_from e_ab) cc.on_merge ) + (* FIXME: remove and task_distinct_ cc acts (l:node list) tag expl : unit = let l = List.map (fun n -> n, find_ n) l in let coll = @@ -832,6 +842,7 @@ module Make(A: ARG) = struct (* put a tag on all equivalence classes, that will make their merge fail *) List.iter (fun (_,n) -> add_tag_n cc n tag expl) l end + *) (* we are merging [r1] with [r2==Bool(sign)], so propagate each term [u1] in the equiv class of [r1] that is a known literal back to the SAT solver @@ -864,6 +875,61 @@ module Make(A: ARG) = struct acts.Msat.acts_propagate lit reason | _ -> ()) + module Theory = struct + type cc = t + type t = theory + type 'a key = (term,lit,'a) Key.t + + (* raise a conflict *) + let raise_conflict cc _n1 _n2 expl = + Log.debugf 5 + (fun k->k "(@[cc.theory.raise-conflict@ :n1 %a@ :n2 %a@ :expl %a@])" + N.pp _n1 N.pp _n2 Expl.pp expl); + merge_classes cc (true_ cc) (false_ cc) expl + + let merge cc n1 n2 expl = + Log.debugf 5 + (fun k->k "(@[cc.theory.merge@ :n1 %a@ :n2 %a@ :expl %a@])" N.pp n1 N.pp n2 Expl.pp expl); + merge_classes cc n1 n2 expl + + let add_term = add_term + + let get_data _cc n key = + assert (N.is_root n); + K_map.find key n.n_th_data + + (* FIXME: call micro theory here? in case of merge *) + (* update data for [n] *) + let add_data (type a) (self:cc) (n:node) (key: a key) (v:a) : unit = + let n = find_ n in + let map = n.n_th_data in + let old_v = K_map.find key map in + let v', is_diff = match old_v with + | None -> v, true + | Some old_v -> + let (module K) = key in + let v' = K.merge old_v v in + v', K.equal v v' + in + if is_diff then ( + on_backtrack self (fun () -> n.n_th_data <- map); + ); + n.n_th_data <- K_map.add key v' map; + () + + let make (type a) ~(key:a key) ~on_merge ~on_new_term () : t = + let module Th = struct + type nonrec cc = cc + type data = a + let key = key + let key_id = Key.id key + let on_merge = on_merge + let on_new_term = on_new_term + end in + (module Th : THEORY with type cc=cc) + + end + let check_invariants_ (cc:t) = Log.debug 5 "(cc.check-invariants)"; Log.debugf 15 (fun k-> k "%a" pp_full cc); @@ -943,11 +1009,12 @@ module Make(A: ARG) = struct Sequence.iter (assert_lit cc) lits let assert_eq cc t1 t2 (e:lit list) : unit = - let expl = Expl.mk_lits e in + let expl = Expl.mk_list @@ List.rev_map Expl.mk_lit e in let n1 = add_term cc t1 in let n2 = add_term cc t2 in merge_classes cc n1 n2 expl + (* FIXME: remove (* generative tag used to annotate classes that can't be merged *) let distinct_tag_ = ref 0 @@ -958,14 +1025,23 @@ module Make(A: ARG) = struct (fun k->k "(@[cc.assert_distinct@ (@[%a@])@ :tag %d@])" (Util.pp_list T.pp) l tag); let l = List.map (add_term cc) l in Vec.push cc.combine @@ CT_distinct (l, tag, Expl.mk_lit lit) + *) - let create ?(on_merge=[]) ?(on_new_term=[]) ?(size=`Big) (tst:term_state) : t = + let add_th (self:t) (th:theory) : unit = + let (module Th) = th in + if IM.mem Th.key_id self.theories then ( + Error.errorf "attempt to add two theories with key %a" Key.pp Th.key + ); + Log.debugf 3 (fun k->k "(@[@{cc.add-theory@} %a@])" Key.pp Th.key); + self.theories <- IM.add Th.key_id th self.theories + + let create ?th:(theories=[]) ?(size=`Big) (tst:term_state) : t = let size = match size with `Small -> 128 | `Big -> 2048 in let rec cc = { tst; tbl = T_tbl.create size; signatures_tbl = Sig_tbl.create size; - on_merge; on_new_term; + theories=IM.empty; pending=Vec.create(); combine=Vec.create(); ps_lits=[]; @@ -981,6 +1057,7 @@ module Make(A: ARG) = struct in ignore (Lazy.force true_ : node); ignore (Lazy.force false_ : node); + List.iter (add_th cc) theories; (* now add theories *) cc let[@inline] find_t cc t : repr = diff --git a/src/cc/Congruence_closure.mli b/src/cc/Congruence_closure.mli index 31982030..cc26fa17 100644 --- a/src/cc/Congruence_closure.mli +++ b/src/cc/Congruence_closure.mli @@ -2,11 +2,8 @@ module type ARG = Congruence_closure_intf.ARG module type S = Congruence_closure_intf.S -module type THEORY_DATA = Congruence_closure_intf.THEORY_DATA + module type THEORY_KEY = Congruence_closure_intf.THEORY_KEY - -type ('t, 'a) theory_data = ('t,'a) Congruence_closure_intf.theory_data - module Key : THEORY_KEY module Make(A: ARG) diff --git a/src/cc/Congruence_closure_intf.ml b/src/cc/Congruence_closure_intf.ml index 76c74fc9..f9105d75 100644 --- a/src/cc/Congruence_closure_intf.ml +++ b/src/cc/Congruence_closure_intf.ml @@ -1,36 +1,35 @@ module type ARG = CC_types.FULL -(** Data stored by a theory for its own terms. - - It needs to form a commutative monoid where values can be unmerged upon - backtracking. -*) -module type THEORY_DATA = sig - type term - type t - - val empty : t - - val equal : t -> t -> bool - (** Equality. This is used to optimize backtracking info. *) - - val merge : t -> t -> t - (** [merge d1 d2] is called when merging classes with data [d1] and [d2] - respectively. The theory should already have checked that the merge - is compatible, and this produces the combined data for terms in the - merged class. *) -end - -type ('t, 'a) theory_data = (module THEORY_DATA with type term = 't and type t = 'a) - module type THEORY_KEY = sig - type ('t, 'a) t - (** An access key for theories that use terms ['t] and which have - per-class data ['a] *) + type ('term,'lit,'a) t + (** An access key for theories which have per-class data ['a] *) - val create : ('t, 'a) theory_data -> ('t, 'a) t - (** Generative creation of keys for the given theory data. *) + val create : + ?pp:'a Fmt.printer -> + name:string -> + eq:('a -> 'a -> bool) -> + merge:('a -> 'a -> 'a) -> + unit -> + ('term,'lit,'a) t + (** Generative creation of keys for the given theory data. + + @param eq : Equality. This is used to optimize backtracking info. + + @param merge : + [merge d1 d2] is called when merging classes with data [d1] and [d2] + respectively. The theory should already have checked that the merge + is compatible, and this produces the combined data for terms in the + merged class. + @param name name of the theory which owns this data + @param pp a printer for the data + *) + + val equal : ('t,'lit,_) t -> ('t,'lit,_) t -> bool + (** Checks if two keys are equal (generatively) *) + + val pp : _ t Fmt.printer + (** Prints the name of the key. *) end module type S = sig @@ -87,12 +86,9 @@ module type S = sig type t val pp : t Fmt.printer - val mk_reduction : t - val mk_congruence : N.t -> N.t -> t val mk_merge : N.t -> N.t -> t - val mk_merges : (N.t * N.t) list -> t val mk_lit : lit -> t - val mk_lits : lit list -> t + val mk_list : t list -> t end type node = N.t @@ -119,34 +115,52 @@ module type S = sig (** Actions available to the theory *) type sat_actions = (Msat.void, lit, Msat.void, proof) Msat.acts + module Theory : sig + type cc = t + type t + + type 'a key = (term,lit,'a) Key.t + + val raise_conflict : cc -> Expl.t -> unit + (** Raise a conflict with the given explanation + it must be a theory tautology that [expl ==> absurd]. + To be used in theories. *) + + val merge : cc -> N.t -> N.t -> Expl.t -> unit + (** Merge these two nodes given this explanation. + It must be a theory tautology that [expl ==> n1 = n2]. + To be used in theories. *) + + val add_term : cc -> term -> N.t + (** Add/retrieve node for this term. + To be used in theories *) + + val get_data : cc -> N.t -> 'a key -> 'a option + (** Get data information for this particular representative *) + + val add_data : cc -> N.t -> 'a key -> 'a -> unit + (** Add data to this particular representative. Will be backtracked. *) + + val make : + key:'a key -> + on_merge:(cc -> N.t -> 'a -> N.t -> 'a -> Expl.t -> unit) -> + on_new_term:(cc -> term -> 'a option) -> + unit -> + t + (** Build a micro theory. It can use the callbacks above. *) + end + val create : - ?on_merge:(t -> repr -> repr -> explanation -> unit) list -> - ?on_new_term:(t -> repr -> term -> unit) list -> + ?th:Theory.t list -> ?size:[`Small | `Big] -> term_state -> t (** Create a new congruence closure. *) - val on_merge : t -> (t -> repr -> repr -> explanation -> unit) -> unit - (** Add a callback, to be called whenever two classes are merged *) - - val on_new_term : t -> (t -> repr -> term -> unit) -> unit - (** Add a callback, to be called whenever a node is added *) - - val merge_classes : t -> node -> node -> explanation -> unit - (** Merge the two given nodes with given explanation. - It must be a theory tautology that [expl ==> n1 = n2] *) - - val th_data_get : t -> N.t -> (term, 'a) Key.t -> 'a option - (** Get data information for this particular representative *) - - val th_data_add : t -> N.t -> (term, 'a) Key.t -> 'a -> unit - (** Add the given data to this node (or rather, to its representative). - This will be backtracked. *) - - (* TODO: merge true/false? - val raise_conflict : CC.t -> CC.N.t -> CC.N.t -> Expl.t -> 'a - *) + val add_th : t -> Theory.t -> unit + (** Add a (micro) theory to the congruence closure. + @raise Error.Error if there is already a theory with + the same key. *) val set_as_lit : t -> N.t -> lit -> unit (** map the given node to a literal. *) @@ -173,11 +187,12 @@ module type S = sig val assert_eq : t -> term -> term -> lit list -> unit (** merge the given terms with some explanations *) - (* TODO: remove and move into its own library as a micro theory *) + (* TODO: remove and move into its own library as a micro theory val assert_distinct : t -> term list -> neq:term -> lit -> unit (** [assert_distinct l ~neq:u e] asserts all elements of [l] are distinct because [lit] is true precond: [u = distinct l] *) + *) val check : t -> sat_actions -> unit (** Perform all pending operations done via {!assert_eq}, {!assert_lit}, etc. diff --git a/src/cc/Sidekick_cc.ml b/src/cc/Sidekick_cc.ml index 2f40100f..6389618f 100644 --- a/src/cc/Sidekick_cc.ml +++ b/src/cc/Sidekick_cc.ml @@ -16,6 +16,7 @@ module type S = Congruence_closure.S module Mini_cc = Mini_cc module Congruence_closure = Congruence_closure +module Key = Congruence_closure.Key module Make = Congruence_closure.Make diff --git a/src/smt/Model.ml b/src/smt/Model.ml index fa0641a0..fa6281e7 100644 --- a/src/smt/Model.ml +++ b/src/smt/Model.ml @@ -150,8 +150,8 @@ let eval (m:t) (t:Term.t) : Value.t option = let b = aux b in if Value.equal a b then Value.true_ else Value.false_ | App_cst (c, args) -> - begin try Term.Map.find t m.values - with Not_found -> + try Term.Map.find t m.values + with Not_found -> match Cst.view c with | Cst_def udef -> (* use builtin interpretation function *) @@ -168,7 +168,6 @@ let eval (m:t) (t:Term.t) : Value.t option = | exception Not_found -> raise No_value (* no particular interpretation *) end - end in try Some (aux t) with No_value -> None diff --git a/src/smt/Sidekick_smt.ml b/src/smt/Sidekick_smt.ml index 9b7d9ddd..cdd6cc37 100644 --- a/src/smt/Sidekick_smt.ml +++ b/src/smt/Sidekick_smt.ml @@ -13,9 +13,12 @@ module Lit = Lit module Theory_combine = Theory_combine module Theory = Theory module Solver = Solver +module CC = CC module Solver_types = Solver_types +type theory = Theory.t + (**/**) module Vec = Msat.Vec module Log = Msat.Log diff --git a/src/smt/Solver.ml b/src/smt/Solver.ml index 0553642c..4ce0bbe7 100644 --- a/src/smt/Solver.ml +++ b/src/smt/Solver.ml @@ -175,9 +175,10 @@ let assume (self:t) (c:Lit.t IArray.t) : unit = let c = IArray.to_array_map (Sat_solver.make_atom sat) c in Sat_solver.add_clause_a sat c Proof_default -(* TODO: remove? use a special constant + micro theory instead? *) +(* TODO: remove? use a special constant + micro theory instead? let[@inline] assume_distinct self l ~neq lit : unit = CC.assert_distinct (cc self) l lit ~neq + *) let check_model (_s:t) : unit = Log.debug 1 "(smt.solver.check-model)"; diff --git a/src/smt/Solver.mli b/src/smt/Solver.mli index b2c3dac9..83adee9a 100644 --- a/src/smt/Solver.mli +++ b/src/smt/Solver.mli @@ -55,7 +55,9 @@ val mk_atom_t : t -> ?sign:bool -> Term.t -> Atom.t val assume : t -> Lit.t IArray.t -> unit +(* TODO: use the theory instead val assume_distinct : t -> Term.t list -> neq:Term.t -> Lit.t -> unit + *) val solve : ?on_exit:(unit -> unit) list -> diff --git a/src/smt/Theory.ml b/src/smt/Theory.ml index 4c440366..149529e5 100644 --- a/src/smt/Theory.ml +++ b/src/smt/Theory.ml @@ -23,16 +23,11 @@ module CC_expl = CC.Expl (** Actions available to a theory during its lifetime *) module type ACTIONS = sig + val cc : CC.t + val raise_conflict: conflict -> 'a (** Give a conflict clause to the solver *) - val propagate_eq: Term.t -> Term.t -> Lit.t list -> unit - (** Propagate an equality [t = u] because [e]. - TODO: use [CC.Expl] instead, with lit/merge constructors *) - - val propagate_distinct: Term.t list -> neq:Term.t -> Lit.t -> unit - (** Propagate a [distinct l] because [e] (where [e = neq] *) - val propagate: Lit.t -> (unit -> Lit.t list) -> unit (** Propagate a boolean using a unit clause. [expl => lit] must be a theory lemma, that is, a T-tautology *) @@ -48,16 +43,6 @@ module type ACTIONS = sig val add_persistent_axiom: Lit.t list -> unit (** Add toplevel clause to the SAT solver. This clause will not be backtracked. *) - - val cc_add_term: Term.t -> CC_eq_class.t - (** add/get term to the congruence closure *) - - val cc_find: CC_eq_class.t -> CC_eq_class.t - (** Find representative of this in the congruence closure *) - - val cc_all_classes: CC_eq_class.t Sequence.t - (** All current equivalence classes - (caution: linear in the number of terms existing in the congruence closure) *) end type actions = (module ACTIONS) diff --git a/src/smt/Theory_combine.ml b/src/smt/Theory_combine.ml index a9e306ec..42d08f44 100644 --- a/src/smt/Theory_combine.ml +++ b/src/smt/Theory_combine.ml @@ -51,9 +51,8 @@ let assert_lits_ ~final (self:t) acts (lits:Lit.t Sequence.t) : unit = (* transmit to theories. *) CC.check cc acts; let module A = struct + let cc = cc let[@inline] raise_conflict c : 'a = acts.Msat.acts_raise_conflict c Proof_default - let[@inline] propagate_eq t u expl : unit = CC.assert_eq cc t u expl - let propagate_distinct ts ~neq expl = CC.assert_distinct cc ts ~neq expl let[@inline] propagate p cs : unit = acts.Msat.acts_propagate p (Msat.Consequence (fun () -> cs(), Proof_default)) let[@inline] propagate_l p cs : unit = propagate p (fun()->cs) @@ -61,9 +60,6 @@ let assert_lits_ ~final (self:t) acts (lits:Lit.t Sequence.t) : unit = acts.Msat.acts_add_clause ~keep:false lits Proof_default let[@inline] add_persistent_axiom lits : unit = acts.Msat.acts_add_clause ~keep:true lits Proof_default - let[@inline] cc_add_term t = CC.add_term cc t - let[@inline] cc_find t = CC.find cc t - let cc_all_classes = CC.all_classes cc end in let acts = (module A : Theory.ACTIONS) in theories self diff --git a/src/smtlib/Process.ml b/src/smtlib/Process.ml index 3d6bf3d0..1f95c73b 100644 --- a/src/smtlib/Process.ml +++ b/src/smtlib/Process.ml @@ -8,6 +8,7 @@ type 'a or_error = ('a, string) CCResult.t module E = CCResult module A = Ast module Form = Sidekick_th_bool.Bool_term +module Distinct = Sidekick_th_distinct module Fmt = CCFormat module Dot = Msat_backend.Dot.Make(Solver.Sat_solver)(Msat_backend.Dot.Default(Solver.Sat_solver)) @@ -137,7 +138,7 @@ module Conv = struct in Form.and_l tst (curry_eq l) | A.Op (A.Distinct, l) -> - Form.distinct_l tst @@ List.map (aux subst) l + Distinct.distinct_l tst @@ List.map (aux subst) l | A.Not f -> Form.not_ tst (aux subst f) | A.Bool true -> Term.true_ tst | A.Bool false -> Term.false_ tst diff --git a/src/smtlib/dune b/src/smtlib/dune index d581cf95..7f1b23d5 100644 --- a/src/smtlib/dune +++ b/src/smtlib/dune @@ -3,12 +3,8 @@ (name sidekick_smtlib) (public_name sidekick.smtlib) (libraries containers zarith msat sidekick.smt sidekick.util - sidekick.smt.th-bool msat.backend) - (flags :standard -w +a-4-42-44-48-50-58-32-60@8 - -safe-string -color always -open Sidekick_util) - (ocamlopt_flags :standard -O3 -color always -bin-annot - -unbox-closures -unbox-closures-factor 20) - ) + sidekick.smt.th-bool sidekick.smt.th-distinct msat.backend) + (flags :standard -open Sidekick_util)) (menhir (modules Parser)) diff --git a/src/th-bool/Bool_intf.ml b/src/th-bool/Bool_intf.ml index 7ca04cec..7bef0b7a 100644 --- a/src/th-bool/Bool_intf.ml +++ b/src/th-bool/Bool_intf.ml @@ -3,11 +3,9 @@ type 'a view = | B_not of 'a - | B_eq of 'a * 'a | B_and of 'a IArray.t | B_or of 'a IArray.t | B_imply of 'a IArray.t * 'a - | B_distinct of 'a IArray.t | B_atom of 'a (** {2 Interface for a representation of boolean terms} *) diff --git a/src/th-bool/Bool_term.ml b/src/th-bool/Bool_term.ml index 6a405911..e9a617e2 100644 --- a/src/th-bool/Bool_term.ml +++ b/src/th-bool/Bool_term.ml @@ -18,7 +18,6 @@ let id_not = ID.make "not" let id_and = ID.make "and" let id_or = ID.make "or" let id_imply = ID.make "=>" -let id_distinct = ID.make "distinct" let equal = T.equal let hash = T.hash @@ -34,17 +33,14 @@ let view_id cst_id args = (* conclusion is stored first *) let len = IArray.length args in B_imply (IArray.sub args 1 (len-1), IArray.get args 0) - ) else if ID.equal cst_id id_distinct then ( - B_distinct args ) else ( raise_notrace Not_a_th_term ) let view_as_bool (t:T.t) : T.t view = match T.view t with - | Eq (a,b) -> B_eq (a,b) | App_cst ({cst_id; _}, args) -> - begin try view_id cst_id args with Not_a_th_term -> B_atom t end + (try view_id cst_id args with Not_a_th_term -> B_atom t) | _ -> B_atom t module C = struct @@ -69,14 +65,7 @@ module C = struct | B_imply (_, V_bool true) -> Value.true_ | B_imply (a,_) when IArray.exists Value.is_false a -> Value.true_ | B_imply (a,b) when IArray.for_all Value.is_bool a && Value.is_bool b -> Value.false_ - | B_eq (a,b) -> Value.bool @@ Value.equal a b | B_atom v -> v - | B_distinct a -> - if - Sequence.diagonal (IArray.to_seq a) - |> Sequence.for_all (fun (x,y) -> not @@ Value.equal x y) - then Value.true_ - else Value.false_ | B_not _ | B_and _ | B_or _ | B_imply _ -> Error.errorf "non boolean value in boolean connective" @@ -92,7 +81,6 @@ module C = struct let and_ = mk_cst id_and let or_ = mk_cst id_or let imply = mk_cst id_imply - let distinct = mk_cst id_distinct end let as_id id (t:T.t) : T.t IArray.t option = @@ -152,20 +140,9 @@ let imply_l st xs y = match xs with let imply st a b = imply_a st (IArray.singleton a) b -let distinct st a = - if IArray.length a <= 1 - then T.true_ st - else T.app_cst st C.distinct a - -let distinct_l st = function - | [] | [_] -> T.true_ st - | xs -> distinct st (IArray.of_list xs) - let make st = function | B_atom t -> t - | B_eq (a,b) -> T.eq st a b | B_and l -> and_a st l | B_or l -> or_a st l | B_imply (a,b) -> imply_a st a b | B_not t -> not_ st t - | B_distinct l -> distinct st l diff --git a/src/th-bool/Bool_term.mli b/src/th-bool/Bool_term.mli index 7b34210e..31eceb83 100644 --- a/src/th-bool/Bool_term.mli +++ b/src/th-bool/Bool_term.mli @@ -15,8 +15,6 @@ val imply_a : state -> term IArray.t -> term -> term val imply_l : state -> term list -> term -> term val eq : state -> term -> term -> term val neq : state -> term -> term -> term -val distinct : state -> term IArray.t -> term -val distinct_l : state -> term list -> term val and_a : state -> term IArray.t -> term val and_l : state -> term list -> term val or_a : state -> term IArray.t -> term diff --git a/src/th-bool/Sidekick_th_bool.ml b/src/th-bool/Sidekick_th_bool.ml index 66542221..f2607f47 100644 --- a/src/th-bool/Sidekick_th_bool.ml +++ b/src/th-bool/Sidekick_th_bool.ml @@ -9,11 +9,9 @@ module Th_dyn_tseitin = Th_dyn_tseitin type 'a view = 'a Intf.view = | B_not of 'a - | B_eq of 'a * 'a | B_and of 'a IArray.t | B_or of 'a IArray.t | B_imply of 'a IArray.t * 'a - | B_distinct of 'a IArray.t | B_atom of 'a module type BOOL_TERM = Intf.BOOL_TERM diff --git a/src/th-bool/Th_dyn_tseitin.ml b/src/th-bool/Th_dyn_tseitin.ml index 504a19f8..9cbc9f2e 100644 --- a/src/th-bool/Th_dyn_tseitin.ml +++ b/src/th-bool/Th_dyn_tseitin.ml @@ -15,14 +15,7 @@ module Make(Term : ARG) = struct type term = Term.t module T_tbl = CCHashtbl.Make(Term) - - module Lit = struct - include Sidekick_smt.Lit - let eq tst a b = atom tst ~sign:true @@ Term.make tst (B_eq (a,b)) - let neq tst a b = neg @@ eq tst a b - end - - let pp_c out c = Fmt.fprintf out "(@[%a@])" (Util.pp_list Lit.pp) c + module Lit = Sidekick_smt.Lit type t = { tst: Term.state; @@ -39,22 +32,7 @@ module Make(Term : ARG) = struct in match v with | B_not _ -> assert false (* normalized *) - | B_atom _ | B_eq _ -> () (* CC will manage *) - | B_distinct l -> - let l = IArray.to_list l in - if Lit.sign lit then ( - A.propagate_distinct l ~neq:lit_t lit - ) else if final && not @@ expanded () then ( - (* add clause [distinct t1…tn ∨ ∨_{i,j>i} t_i=j] *) - let c = - Sequence.diagonal_l l - |> Sequence.map (fun (t,u) -> Lit.eq self.tst t u) - |> Sequence.to_rev_list - in - let c = Lit.neg lit :: c in - Log.debugf 5 (fun k->k "(@[tseitin.distinct.case-split@ %a@])" pp_c c); - add_axiom c - ) + | B_atom _ -> () (* CC will manage *) | B_and subs -> if Lit.sign lit then ( (* propagate [lit => subs_i] *) @@ -105,7 +83,7 @@ module Make(Term : ARG) = struct (fun lit -> let t = Lit.term lit in match Term.view_as_bool t with - | B_atom _ | B_eq _ -> () + | B_atom _ -> () | v -> tseitin ~final self acts lit t v) let partial_check (self:t) acts (lits:Lit.t Sequence.t) = diff --git a/src/th-bool/Th_dyn_tseitin.mli b/src/th-bool/Th_dyn_tseitin.mli index 351f534f..be131f8c 100644 --- a/src/th-bool/Th_dyn_tseitin.mli +++ b/src/th-bool/Th_dyn_tseitin.mli @@ -12,11 +12,5 @@ module type ARG = Bool_intf.BOOL_TERM module Make(Term : ARG) : sig type term = Term.t - module Lit : sig - type t = Sidekick_smt.Lit.t - val eq : Term.state -> term -> term -> t - val neq : Term.state -> term -> term -> t - end - val th : Sidekick_smt.Theory.t end diff --git a/src/th-bool/dune b/src/th-bool/dune index 93d2e26e..248a759c 100644 --- a/src/th-bool/dune +++ b/src/th-bool/dune @@ -2,8 +2,5 @@ (name Sidekick_th_bool) (public_name sidekick.smt.th-bool) (libraries containers sidekick.smt) - (flags :standard -w +a-4-44-48-58-60@8 - -color always -safe-string -short-paths -open Sidekick_util) - (ocamlopt_flags :standard -O3 -color always - -unbox-closures -unbox-closures-factor 20)) + (flags :standard -open Sidekick_util))