diff --git a/src/sat/Internal.ml b/src/sat/Internal.ml index a4a65f22..dd29d576 100644 --- a/src/sat/Internal.ml +++ b/src/sat/Internal.ml @@ -1345,23 +1345,33 @@ module Make (Th : Theory_intf.S) = struct (Util.pp_list Atom.debug) l ) - let current_slice st head = Theory_intf.Slice_acts { - slice_iter = slice_iter st head (Vec.size st.trail); - } + let current_slice st head : formula Theory_intf.slice_actions = + let module A = struct + type form = formula + let slice_iter = slice_iter st head (Vec.size st.trail) + end in + (module A) (* full slice, for [if_sat] final check *) - let full_slice st = Theory_intf.Slice_acts { - slice_iter = slice_iter st 0 (Vec.size st.trail); - } + let full_slice st : formula Theory_intf.slice_actions = + let module A = struct + type form = formula + let slice_iter = slice_iter st 0 (Vec.size st.trail) + end in + (module A) let act_at_level_0 st () = at_level_0 st - let actions st = Theory_intf.Actions { - push_persistent = act_push_persistent st; - push_local = act_push_local st; - on_backtrack = on_backtrack st; - propagate = act_propagate st; - } + let actions st: (formula,lemma) Theory_intf.actions = + let module A = struct + type nonrec formula = formula + type proof = lemma + let push_persistent = act_push_persistent st + let push_local = act_push_local st + let on_backtrack = on_backtrack st + let propagate = act_propagate st + end in + (module A) let create ?(size=`Big) () : t = let size_map, size_vars, size_trail, size_lvl = match size with diff --git a/src/sat/Sidekick_sat.ml b/src/sat/Sidekick_sat.ml index 028350c7..fdfea3fa 100644 --- a/src/sat/Sidekick_sat.ml +++ b/src/sat/Sidekick_sat.ml @@ -34,16 +34,8 @@ type 'clause export = 'clause Solver_intf.export = { clauses : 'clause Vec.t; } -type ('form, 'proof) actions = ('form,'proof) Theory_intf.actions = Actions of { - push_persistent : 'form IArray.t -> 'proof -> unit; - push_local : 'form IArray.t -> 'proof -> unit; - on_backtrack: (unit -> unit) -> unit; - propagate : 'form -> 'form list -> 'proof -> unit; -} - -type ('form, 'proof) slice_actions = ('form, 'proof) Theory_intf.slice_actions = Slice_acts of { - slice_iter : ('form -> unit) -> unit; -} +type ('form, 'proof) actions = ('form,'proof) Theory_intf.actions +type 'form slice_actions = 'form Theory_intf.slice_actions module Make = Solver.Make diff --git a/src/sat/Theory_intf.ml b/src/sat/Theory_intf.ml index 16e9c073..814de2d2 100644 --- a/src/sat/Theory_intf.ml +++ b/src/sat/Theory_intf.ml @@ -39,30 +39,40 @@ type ('formula, 'proof) res = theory tautology (with its proof), for which every literal is false under the current assumptions. *) -(** Actions given to the theory during initialization, to be used - at any time *) -type ('form, 'proof) actions = Actions of { - push_persistent : 'form IArray.t -> 'proof -> unit; +module type ACTIONS = sig + type formula + type proof + + val push_persistent : formula IArray.t -> proof -> unit (** Allows to add a persistent clause to the solver. *) - push_local : 'form IArray.t -> 'proof -> unit; + val push_local : formula IArray.t -> proof -> unit (** Allows to add a local clause to the solver. The clause will be removed after backtracking. *) - on_backtrack: (unit -> unit) -> unit; + val on_backtrack: (unit -> unit) -> unit (** [on_backtrack f] calls [f] when the main solver backtracks *) - propagate : 'form -> 'form list -> 'proof -> unit; + val propagate : formula -> formula list -> proof -> unit (** [propagate lit causes proof] informs the solver to propagate [lit], with the reason that the clause [causes => lit] is a theory tautology. It is faster than pushing the associated clause but the clause will not be remembered by the sat solver, i.e it will not be used by the solver to do boolean propagation. *) -} +end -type ('form, 'proof) slice_actions = Slice_acts of { - slice_iter : ('form -> unit) -> unit; +(** Actions given to the theory during initialization, to be used + at any time *) +type ('form, 'proof) actions = + (module ACTIONS with type formula = 'form and type proof = 'proof) + +module type SLICE_ACTIONS = sig + type form + + val slice_iter : (form -> unit) -> unit (** iterate on the slice of the trail *) -} +end + +type 'form slice_actions = (module SLICE_ACTIONS with type form = 'form) (** The type for a slice. Slices are some kind of view of the current propagation queue. They allow to look at the propagated literals, and to add new clauses to the solver. *) @@ -110,11 +120,11 @@ module type S = sig val create : (formula, proof) actions -> t (** Create a new instance of the theory *) - val assume : t -> (formula, proof) slice_actions -> (formula, proof) res + val assume : t -> formula slice_actions -> (formula, proof) res (** Assume the formulas in the slice, possibly pushing new formulas to be propagated, and returns the result of the new assumptions. *) - val if_sat : t -> (formula, proof) slice_actions -> (formula, proof) res + val if_sat : t -> formula slice_actions -> (formula, proof) res (** Called at the end of the search in case a model has been found. If no new clause is pushed, then 'sat' is returned, else search is resumed. *) end diff --git a/src/smt/Congruence_closure.ml b/src/smt/Congruence_closure.ml index ff13ce81..c9234d9e 100644 --- a/src/smt/Congruence_closure.ml +++ b/src/smt/Congruence_closure.ml @@ -19,19 +19,21 @@ module Sig_tbl = CCHashtbl.Make(Signature) type merge_op = node * node * explanation (* a merge operation to perform *) -type actions = { - on_backtrack:(unit -> unit) -> unit; +module type ACTIONS = sig + val on_backtrack: (unit -> unit) -> unit (** Register a callback to be invoked upon backtracking below the current level *) - on_merge:repr -> repr -> explanation -> unit; + val on_merge: repr -> repr -> explanation -> unit (** Call this when two classes are merged *) - raise_conflict: 'a. conflict -> 'a; + val raise_conflict: conflict -> 'a (** Report a conflict *) - propagate: Lit.t -> Lit.t list -> unit; + val propagate: Lit.t -> Lit.t list -> unit (** Propagate a literal *) -} +end + +type actions = (module ACTIONS) type t = { tst: Term.state; @@ -63,7 +65,9 @@ type t = { several times. See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *) -let[@inline] on_backtrack cc f : unit = cc.acts.on_backtrack f +let[@inline] on_backtrack cc f : unit = + let (module A) = cc.acts in + A.on_backtrack f let[@inline] is_root_ (n:node) : bool = n.n_root == n @@ -192,7 +196,8 @@ let rec reroot_expl (cc:t) (n:node): unit = end let[@inline] raise_conflict (cc:t) (e:conflict): _ = - cc.acts.raise_conflict e + let (module A) = cc.acts in + A.raise_conflict e let[@inline] all_classes cc : repr Sequence.t = Term.Tbl.values cc.tbl @@ -482,7 +487,8 @@ and update_combine cc = Side effect: also pushes sub-tasks *) and notify_merge cc (ra:repr) ~into:(rb:repr) (e:explanation): unit = assert (is_root_ rb); - cc.acts.on_merge ra rb e + let (module A) = cc.acts in + A.on_merge ra rb e (* FIXME: callback? diff --git a/src/smt/Congruence_closure.mli b/src/smt/Congruence_closure.mli index cac49888..f8cb8ed3 100644 --- a/src/smt/Congruence_closure.mli +++ b/src/smt/Congruence_closure.mli @@ -13,19 +13,21 @@ type repr = Equiv_class.t type conflict = Theory.conflict -type actions = { - on_backtrack:(unit -> unit) -> unit; +module type ACTIONS = sig + val on_backtrack: (unit -> unit) -> unit (** Register a callback to be invoked upon backtracking below the current level *) - on_merge:repr -> repr -> explanation -> unit; + val on_merge: repr -> repr -> explanation -> unit (** Call this when two classes are merged *) - raise_conflict: 'a. conflict -> 'a; + val raise_conflict: conflict -> 'a (** Report a conflict *) - propagate: Lit.t -> Lit.t list -> unit; + val propagate: Lit.t -> Lit.t list -> unit (** Propagate a literal *) -} +end + +type actions = (module ACTIONS) val create : ?size:int -> diff --git a/src/smt/Theory.ml b/src/smt/Theory.ml index 90620f58..fd4bce05 100644 --- a/src/smt/Theory.ml +++ b/src/smt/Theory.ml @@ -14,58 +14,65 @@ end = struct ) end -(** Runtime state of a theory, with all the operations it provides. - ['a] is the internal state *) -type state = State : { - mutable st: 'a; - on_merge: 'a -> Equiv_class.t -> Equiv_class.t -> Explanation.t -> unit; +module type STATE = sig + type t + + val state : t + + val on_merge: t -> Equiv_class.t -> Equiv_class.t -> Explanation.t -> unit (** Called when two classes are merged *) - on_assert: 'a -> Lit.t -> unit; + val on_assert: t -> Lit.t -> unit (** Called when a literal becomes true *) - final_check: 'a -> Lit.t Sequence.t -> unit; + val final_check: t -> Lit.t Sequence.t -> unit (** Final check, must be complete (i.e. must raise a conflict if the set of literals is not satisfiable) *) -} -> state +end + + +(** Runtime state of a theory, with all the operations it provides. *) +type state = (module STATE) (** Unsatisfiable conjunction. Its negation will become a conflict clause *) type conflict = Lit.t list (** Actions available to a theory during its lifetime *) -type actions = { - on_backtrack: (unit -> unit) -> unit; +module type ACTIONS = sig + val on_backtrack: (unit -> unit) -> unit (** Register an action to do when we backtrack *) - raise_conflict: 'a. conflict -> 'a; + val raise_conflict: conflict -> 'a (** Give a conflict clause to the solver *) - propagate_eq: Term.t -> Term.t -> Lit.t list -> unit; + val propagate_eq: Term.t -> Term.t -> Lit.t list -> unit (** Propagate an equality [t = u] because [e] *) - propagate_distinct: Term.t list -> neq:Term.t -> Lit.t -> unit; + val propagate_distinct: Term.t list -> neq:Term.t -> Lit.t -> unit (** Propagate a [distinct l] because [e] (where [e = neq] *) - propagate: Lit.t -> Lit.t list -> unit; + val propagate: Lit.t -> Lit.t list -> unit (** Propagate a boolean using a unit clause. [expl => lit] must be a theory lemma, that is, a T-tautology *) - add_local_axiom: Lit.t IArray.t -> unit; + val add_local_axiom: Lit.t IArray.t -> unit (** Add local clause to the SAT solver. This clause will be removed when the solver backtracks. *) - add_persistent_axiom: Lit.t IArray.t -> unit; + val add_persistent_axiom: Lit.t IArray.t -> unit (** Add toplevel clause to the SAT solver. This clause will not be backtracked. *) - find: Term.t -> Equiv_class.t; + val find: Term.t -> Equiv_class.t (** Find representative of this term *) - all_classes: Equiv_class.t Sequence.t; + val all_classes: Equiv_class.t Sequence.t (** All current equivalence classes (caution: linear in the number of terms existing in the solver) *) -} +end + +type actions = (module ACTIONS) type t = { name: string; @@ -75,9 +82,17 @@ type t = { let make ~name ~make () : t = {name;make} let make_st + (type st) ?(on_merge=fun _ _ _ _ -> ()) ?(on_assert=fun _ _ -> ()) ~final_check ~st () : state = - State { st; on_merge; on_assert; final_check } + let module A = struct + type nonrec t = st + let state = st + let on_merge = on_merge + let on_assert = on_assert + let final_check = final_check + end in + (module A : STATE) diff --git a/src/smt/Theory_combine.ml b/src/smt/Theory_combine.ml index 64bcc862..872b2605 100644 --- a/src/smt/Theory_combine.ml +++ b/src/smt/Theory_combine.ml @@ -53,7 +53,7 @@ let assume_lit (self:t) (lit:Lit.t) : unit = | Lit_atom _ -> (* transmit to theories. *) Congruence_closure.assert_lit (cc self) lit; - theories self (fun (Theory.State th) -> th.on_assert th.st lit); + theories self (fun (module Th) -> Th.on_assert Th.state lit); end (* return result to the SAT solver *) @@ -84,13 +84,13 @@ let with_conflict_catch self f = cdcl_return_res self (* propagation from the bool solver *) -let assume_real (self:t) (slice:_ Sat_solver.slice_actions) = +let assume_real (self:t) (slice:Lit.t Sat_solver.slice_actions) = (* TODO if Config.progress then print_progress(); *) - let Sat_solver.Slice_acts slice = slice in - Log.debugf 5 (fun k->k "(th_combine.assume :len %d)" (Sequence.length slice.slice_iter)); + let (module SA) = slice in + Log.debugf 5 (fun k->k "(th_combine.assume :len %d)" (Sequence.length @@ SA.slice_iter)); with_conflict_catch self (fun () -> - slice.slice_iter (assume_lit self); + SA.slice_iter (assume_lit self); (* now check satisfiability *) check self ) @@ -104,28 +104,28 @@ let assume (self:t) (slice:_ Sat_solver.slice_actions) = cdcl_return_res self (* perform final check of the model *) -let if_sat (self:t) (slice:_) : _ Sat_solver.res = +let if_sat (self:t) (slice:Lit.t Sat_solver.slice_actions) : _ Sat_solver.res = Congruence_closure.final_check (cc self); (* all formulas in the SAT solver's trail *) let forms = - let Sat_solver.Slice_acts r = slice in - r.slice_iter + let (module SA) = slice in + SA.slice_iter in (* final check for each theory *) with_conflict_catch self (fun () -> theories self - (fun (Theory.State th) -> th.final_check th.st forms)) + (fun (module Th) -> Th.final_check Th.state forms)) (** {2 Various helpers} *) (* forward propagations from CC or theories directly to the SMT core *) let act_propagate (self:t) f guard : unit = - let Sat_solver.Actions r = self.cdcl_acts in + let (module A) = self.cdcl_acts in Sat_solver.Log.debugf 2 (fun k->k "(@[@{propagate@}@ %a@ :guard %a@])" Lit.pp f (Util.pp_list Lit.pp) guard); - r.propagate f guard Proof.default + A.propagate f guard Proof.default (** {2 Interface to Congruence Closure} *) @@ -134,16 +134,17 @@ let act_raise_conflict e = raise (Exn_conflict e) (* when CC decided to merge [r1] and [r2], notify theories *) let on_merge_from_cc (self:t) r1 r2 e : unit = theories self - (fun (Theory.State th) -> th.on_merge th.st r1 r2 e) + (fun (module Th) -> Th.on_merge Th.state r1 r2 e) let mk_cc_actions (self:t) : Congruence_closure.actions = - let Sat_solver.Actions r = self.cdcl_acts in - { Congruence_closure. - on_backtrack = r.on_backtrack; - on_merge = on_merge_from_cc self; - raise_conflict = act_raise_conflict; - propagate = act_propagate self; - } + let (module A) = self.cdcl_acts in + let module R = struct + let on_backtrack = A.on_backtrack + let on_merge = on_merge_from_cc self + let raise_conflict = act_raise_conflict + let propagate = act_propagate self + end in + (module R) (** {2 Main} *) @@ -180,29 +181,30 @@ let act_find self t = let act_add_local_axiom self c : unit = Sat_solver.Log.debugf 5 (fun k->k "(@[<2>th_combine.push_local_lemma@ %a@])" Theory.Clause.pp c); - let Sat_solver.Actions r = self.cdcl_acts in - r.push_local c Proof.default + let (module A) = self.cdcl_acts in + A.push_local c Proof.default (* push one clause into [M], in the current level (not a lemma but an axiom) *) let act_add_persistent_axiom self c : unit = Sat_solver.Log.debugf 5 (fun k->k "(@[<2>th_combine.push_persistent_lemma@ %a@])" Theory.Clause.pp c); - let Sat_solver.Actions r = self.cdcl_acts in - r.push_persistent c Proof.default + let (module A) = self.cdcl_acts in + A.push_persistent c Proof.default let mk_theory_actions (self:t) : Theory.actions = - let Sat_solver.Actions r = self.cdcl_acts in - { Theory. - on_backtrack = r.on_backtrack; - raise_conflict = act_raise_conflict; - propagate = act_propagate self; - all_classes = act_all_classes self; - propagate_eq = act_propagate_eq self; - propagate_distinct = act_propagate_distinct self; - add_local_axiom = act_add_local_axiom self; - add_persistent_axiom = act_add_persistent_axiom self; - find = act_find self; - } + let (module A) = self.cdcl_acts in + let module R = struct + let on_backtrack = A.on_backtrack + let raise_conflict = act_raise_conflict + let propagate = act_propagate self + let all_classes = act_all_classes self + let propagate_eq = act_propagate_eq self + let propagate_distinct = act_propagate_distinct self + let add_local_axiom = act_add_local_axiom self + let add_persistent_axiom = act_add_persistent_axiom self + let find = act_find self + end + in (module R) let add_theory (self:t) (th:Theory.t) : unit = Sat_solver.Log.debugf 2 diff --git a/src/smt/th_bool/Sidekick_th_bool.ml b/src/smt/th_bool/Sidekick_th_bool.ml index 915fdc0f..2948e6f8 100644 --- a/src/smt/th_bool/Sidekick_th_bool.ml +++ b/src/smt/th_bool/Sidekick_th_bool.ml @@ -251,17 +251,18 @@ type t = { let tseitin (self:t) (lit:Lit.t) (lit_t:term) (b:term builtin) : unit = Log.debugf 5 (fun k->k "(@[th_bool.tseitin@ %a@])" Lit.pp lit); + let (module A) = self.acts in match b with | B_not _ -> assert false (* normalized *) | B_eq (t,u) -> if Lit.sign lit then ( - self.acts.Theory.propagate_eq t u [lit] + A.propagate_eq t u [lit] ) else ( - self.acts.Theory.propagate_distinct [t;u] ~neq:lit_t lit + A.propagate_distinct [t;u] ~neq:lit_t lit ) | B_distinct l -> if Lit.sign lit then ( - self.acts.Theory.propagate_distinct l ~neq:lit_t lit + A.propagate_distinct l ~neq:lit_t lit ) else ( (* TODO: propagate pairwise equalities? *) Error.errorf "cannot process negative distinct lit %a" Lit.pp lit; @@ -272,39 +273,39 @@ let tseitin (self:t) (lit:Lit.t) (lit_t:term) (b:term builtin) : unit = List.iter (fun sub -> let sublit = Lit.atom sub in - self.acts.Theory.propagate sublit [lit]) + A.propagate sublit [lit]) subs ) else ( (* propagate [¬lit => ∨_i ¬ subs_i] *) let c = Lit.neg lit :: List.map (Lit.atom ~sign:false) subs in - self.acts.Theory.add_local_axiom (IArray.of_list c) + A.add_local_axiom (IArray.of_list c) ) | B_or subs -> if Lit.sign lit then ( (* propagate [lit => ∨_i subs_i] *) let c = Lit.neg lit :: List.map (Lit.atom ~sign:true) subs in - self.acts.Theory.add_local_axiom (IArray.of_list c) + A.add_local_axiom (IArray.of_list c) ) else ( (* propagate [¬lit => ¬subs_i] *) List.iter (fun sub -> let sublit = Lit.atom ~sign:false sub in - self.acts.Theory.propagate sublit [lit]) + A.propagate sublit [lit]) subs ) | B_imply (guard,concl) -> if Lit.sign lit then ( (* propagate [lit => ∨_i ¬guard_i ∨ concl] *) let c = Lit.atom concl :: Lit.neg lit :: List.map (Lit.atom ~sign:false) guard in - self.acts.Theory.add_local_axiom (IArray.of_list c) + A.add_local_axiom (IArray.of_list c) ) else ( (* propagate [¬lit => ¬concl] *) - self.acts.Theory.propagate (Lit.atom ~sign:false concl) [lit]; + A.propagate (Lit.atom ~sign:false concl) [lit]; (* propagate [¬lit => ∧_i guard_i] *) List.iter (fun sub -> let sublit = Lit.atom ~sign:true sub in - self.acts.Theory.propagate sublit [lit]) + A.propagate sublit [lit]) guard )