From ef7333af6d1e57a5d7d6a705cec85fa8b198eb03 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Fri, 29 Dec 2017 16:48:26 +0100 Subject: [PATCH] make state explicit and add `type t` state-wrapper in most modules --- src/backend/Coq.ml | 11 +- src/backend/Dimacs.ml | 16 +- src/backend/Dimacs.mli | 4 +- src/core/Internal.ml | 615 +++++++++++++++++----------------- src/core/Internal.mli | 46 +-- src/core/Res.ml | 5 +- src/core/Res_intf.ml | 2 + src/core/Solver.ml | 64 ++-- src/core/Solver.mli | 3 +- src/core/Solver_intf.ml | 19 +- src/core/Solver_types.ml | 138 ++++---- src/core/Solver_types.mli | 4 +- src/core/Solver_types_intf.ml | 35 +- src/main/main.ml | 32 +- src/mcsat/Minismt_mcsat.ml | 4 +- src/mcsat/Minismt_mcsat.mli | 2 +- src/sat/Minismt_sat.ml | 3 +- src/sat/Minismt_sat.mli | 2 +- src/smt/Minismt_smt.ml | 2 +- src/smt/Minismt_smt.mli | 2 +- src/solver/mcsolver.ml | 6 +- src/solver/mcsolver.mli | 7 +- src/solver/solver.ml | 6 +- src/solver/solver.mli | 5 +- tests/test_api.ml | 17 +- 25 files changed, 558 insertions(+), 492 deletions(-) diff --git a/src/backend/Coq.ml b/src/backend/Coq.ml index 47fa6e3e..f33f0a2c 100644 --- a/src/backend/Coq.ml +++ b/src/backend/Coq.ml @@ -127,21 +127,18 @@ module Make(S : Res.S)(A : Arg with type hyp := S.clause | S.Lemma _ -> A.prove_lemma fmt (name clause) clause | S.Duplicate (p, l) -> - let p' = S.expand p in - let c = p'.S.conclusion in + let c = S.conclusion p in let () = elim_duplicate fmt clause c l in clean t fmt [c] | S.Resolution (p1, p2, a) -> - let c1 = (S.expand p1).S.conclusion in - let c2 = (S.expand p2).S.conclusion in + let c1 = S.conclusion p1 in + let c2 = S.conclusion p2 in if resolution fmt clause c1 c2 a then clean t fmt [c1; c2] let count_uses p = let h = S.H.create 4013 in let aux () node = - List.iter (fun p' -> - incr_use h S.((expand p').conclusion)) - (S.parents node.S.step) + List.iter (fun p' -> incr_use h S.(conclusion p')) (S.parents node.S.step) in let () = S.fold aux () p in h diff --git a/src/backend/Dimacs.ml b/src/backend/Dimacs.ml index a4e35b0f..07b1ee02 100644 --- a/src/backend/Dimacs.ml +++ b/src/backend/Dimacs.ml @@ -7,18 +7,18 @@ Copyright 2014 Simon Cruanes open Msat module type S = sig + type st type clause + (** The type of clauses *) val export : + st -> Format.formatter -> hyps:clause Vec.t -> history:clause Vec.t -> local:clause Vec.t -> unit - (** Export the given clause vectors to the dimacs format. - The arguments should be transmitted directly from the corresponding - function of the {Internal} module. *) val export_icnf : Format.formatter -> @@ -26,13 +26,11 @@ module type S = sig history:clause Vec.t -> local:clause Vec.t -> unit - (** Export the given clause vectors to the dimacs format. - The arguments should be transmitted directly from the corresponding - function of the {Internal} module. *) end -module Make(St : Solver_types_intf.S)(Dummy: sig end) = struct +module Make(St : Solver_types_intf.S) = struct + type st = St.t (* Dimacs & iCNF export *) let export_vec name fmt vec = @@ -76,7 +74,7 @@ module Make(St : Solver_types_intf.S)(Dummy: sig end) = struct ) learnt; lemmas - let export fmt ~hyps ~history ~local = + let export st fmt ~hyps ~history ~local = assert (Vec.for_all (fun c -> St.Clause.premise c = St.Hyp) hyps); (* Learnt clauses, then filtered to only keep only the theory lemmas; all other learnt clauses should be logical @@ -85,7 +83,7 @@ module Make(St : Solver_types_intf.S)(Dummy: sig end) = struct (* Local assertions *) assert (Vec.for_all (fun c -> St.Local = St.Clause.premise c) local); (* Number of atoms and clauses *) - let n = St.nb_elt () in + let n = St.nb_elt st in let m = Vec.size local + Vec.size hyps + Vec.size lemmas in Format.fprintf fmt "@[p cnf %d %d@,%a%a%a@]@." n m diff --git a/src/backend/Dimacs.mli b/src/backend/Dimacs.mli index 0ad94e63..e829bebd 100644 --- a/src/backend/Dimacs.mli +++ b/src/backend/Dimacs.mli @@ -13,11 +13,13 @@ Copyright 2014 Simon Cruanes open Msat module type S = sig + type st type clause (** The type of clauses *) val export : + st -> Format.formatter -> hyps:clause Vec.t -> history:clause Vec.t -> @@ -42,6 +44,6 @@ module type S = sig end -module Make(St: Solver_types_intf.S)(Dummy: sig end) : S with type clause := St.clause +module Make(St: Solver_types_intf.S) : S with type clause := St.clause and type st = St.t (** Functor to create a module for exporting probems to the dimacs (& iCNF) formats. *) diff --git a/src/core/Internal.ml b/src/core/Internal.ml index b88f8fdd..1c39d704 100644 --- a/src/core/Internal.ml +++ b/src/core/Internal.ml @@ -9,7 +9,6 @@ module Make (Plugin : Plugin_intf.S with type term = St.term and type formula = St.formula and type proof = St.proof) - (Dummy: sig end) = struct module Proof = Res.Make(St) @@ -36,7 +35,8 @@ module Make let debug = 50 (* Singleton type containing the current state *) - type env = { + type t = { + st : St.t; (* Clauses are simplified for eficiency purposes. In the following vectors, the comments actually refer to the original non-simplified @@ -134,7 +134,8 @@ module Make } (* Starting environment. *) - let env = { + let create ?(st=St.create()) () : t = { + st; unsat_conflict = None; next_decision = None; @@ -181,17 +182,19 @@ module Make } (* Misc functions *) - let to_float i = float_of_int i - let to_int f = int_of_float f + let to_float = float_of_int + let to_int = int_of_float - let nb_clauses () = Vec.size env.clauses_hyps + let[@inline] st t = t.st + let[@inline] nb_clauses st = Vec.size st.clauses_hyps (* let nb_vars () = St.nb_elt () *) - let decision_level () = Vec.size env.elt_levels - let base_level () = Vec.size env.user_levels + + let[@inline] decision_level st = Vec.size st.elt_levels + let[@inline] base_level st = Vec.size st.user_levels (* Are the assumptions currently unsat ? *) - let is_unsat () = - match env.unsat_conflict with + let[@inline] is_unsat st = + match st.unsat_conflict with | Some _ -> true | None -> false @@ -210,15 +213,15 @@ module Make (* When we have a new literal, we need to first create the list of its subterms. *) - let atom (f:St.formula) : atom = - let res = Atom.make f in + let mk_atom st (f:St.formula) : atom = + let res = Atom.make st.st f in if St.mcsat then ( begin match res.var.v_assignable with | Some _ -> () | None -> let l = ref [] in Plugin.iter_assignable - (fun t -> l := Lit.make t :: !l) + (fun t -> l := Lit.make st.st t :: !l) res.var.pa.lit; res.var.v_assignable <- Some !l; end; @@ -232,79 +235,79 @@ module Make When we add a variable (which wraps a formula), we also need to add all its subterms. *) - let rec insert_var_order (elt:elt) : unit = - H.insert env.order elt; + let rec insert_var_order st (elt:elt) : unit = + H.insert st.order elt; begin match elt with | E_lit _ -> () - | E_var v -> insert_subterms_order v + | E_var v -> insert_subterms_order st v end - and insert_subterms_order (v:St.var) : unit = - iter_sub (fun t -> insert_var_order (Elt.of_lit t)) v + and insert_subterms_order st (v:St.var) : unit = + iter_sub (fun t -> insert_var_order st (Elt.of_lit t)) v (* Add new litterals/atoms on which to decide on, even if there is no clause that constrains it. We could maybe check if they have already has been decided before inserting them into the heap, if it appears that it helps performance. *) - let new_lit t = - let l = Lit.make t in - insert_var_order (E_lit l) + let new_lit st t = + let l = Lit.make st.st t in + insert_var_order st (E_lit l) - let new_atom p = - let a = atom p in - insert_var_order (E_var a.var) + let new_atom st p = + let a = mk_atom st p in + insert_var_order st (E_var a.var) (* Rather than iterate over all the heap when we want to decrease all the variables/literals activity, we instead increase the value by which we increase the activity of 'interesting' var/lits. *) - let var_decay_activity () = - env.var_incr <- env.var_incr *. env.var_decay + let var_decay_activity st = + st.var_incr <- st.var_incr *. st.var_decay - let clause_decay_activity () = - env.clause_incr <- env.clause_incr *. env.clause_decay + let clause_decay_activity st = + st.clause_incr <- st.clause_incr *. st.clause_decay (* increase activity of [v] *) - let var_bump_activity_aux v = - v.v_weight <- v.v_weight +. env.var_incr; + let var_bump_activity_aux st v = + v.v_weight <- v.v_weight +. st.var_incr; if v.v_weight > 1e100 then ( - for i = 0 to (St.nb_elt ()) - 1 do - Elt.set_weight (St.get_elt i) ((Elt.weight (St.get_elt i)) *. 1e-100) + for i = 0 to St.nb_elt st.st - 1 do + Elt.set_weight (St.get_elt st.st i) ((Elt.weight (St.get_elt st.st i)) *. 1e-100) done; - env.var_incr <- env.var_incr *. 1e-100; + st.var_incr <- st.var_incr *. 1e-100; ); let elt = Elt.of_var v in if H.in_heap elt then ( - H.decrease env.order elt + H.decrease st.order elt ) (* increase activity of literal [l] *) - let lit_bump_activity_aux (l:lit): unit = - l.l_weight <- l.l_weight +. env.var_incr; + let lit_bump_activity_aux st (l:lit): unit = + l.l_weight <- l.l_weight +. st.var_incr; if l.l_weight > 1e100 then ( - for i = 0 to (St.nb_elt ()) - 1 do - Elt.set_weight (St.get_elt i) ((Elt.weight (St.get_elt i)) *. 1e-100) + for i = 0 to St.nb_elt st.st - 1 do + Elt.set_weight (St.get_elt st.st i) ((Elt.weight (St.get_elt st.st i)) *. 1e-100) done; - env.var_incr <- env.var_incr *. 1e-100; + st.var_incr <- st.var_incr *. 1e-100; ); let elt = Elt.of_lit l in if H.in_heap elt then ( - H.decrease env.order elt + H.decrease st.order elt ) (* increase activity of var [v] *) - let var_bump_activity (v:var): unit = - var_bump_activity_aux v; - iter_sub lit_bump_activity_aux v + let var_bump_activity st (v:var): unit = + var_bump_activity_aux st v; + iter_sub (lit_bump_activity_aux st) v (* increase activity of clause [c] *) - let clause_bump_activity (c:clause) : unit = - c.activity <- c.activity +. env.clause_incr; + let clause_bump_activity st (c:clause) : unit = + c.activity <- c.activity +. st.clause_incr; if c.activity > 1e20 then ( - for i = 0 to (Vec.size env.clauses_learnt) - 1 do - (Vec.get env.clauses_learnt i).activity <- - (Vec.get env.clauses_learnt i).activity *. 1e-20; + for i = 0 to Vec.size st.clauses_learnt - 1 do + (Vec.get st.clauses_learnt i).activity <- + (Vec.get st.clauses_learnt i).activity *. 1e-20; done; - env.clause_incr <- env.clause_incr *. 1e-20 + st.clause_incr <- st.clause_incr *. 1e-20 ) (* Simplification of clauses. @@ -326,7 +329,7 @@ module Make else Array.to_list (Array.sub arr i (Array.length arr - i)) (* Eliminates atom doublons in clauses *) - let eliminate_doublons clause : clause = + let eliminate_duplicates clause : clause = let trivial = ref false in let duplicates = ref [] in let res = ref [] in @@ -404,11 +407,11 @@ module Make stack of literals i.e we have indeed reached a propagation fixpoint before making a new decision *) - let new_decision_level() = - assert (env.th_head = Vec.size env.trail); - assert (env.elt_head = Vec.size env.trail); - Vec.push env.elt_levels (Vec.size env.trail); - Vec.push env.th_levels (Plugin.current_level ()); (* save the current theory state *) + let new_decision_level st = + assert (st.th_head = Vec.size st.trail); + assert (st.elt_head = Vec.size st.trail); + Vec.push st.elt_levels (Vec.size st.trail); + Vec.push st.th_levels (Plugin.current_level ()); (* save the current theory state *) () (* Attach/Detach a clause. @@ -429,32 +432,32 @@ module Make Used to backtrack, i.e cancel down to [lvl] excluded, i.e we want to go back to the state the solver was in when decision level [lvl] was created. *) - let cancel_until lvl = - assert (lvl >= base_level ()); + let cancel_until st lvl = + assert (lvl >= base_level st); (* Nothing to do if we try to backtrack to a non-existent level. *) - if decision_level () <= lvl then ( + if decision_level st <= lvl then ( Log.debugf debug (fun k -> k "Already at level <= %d" lvl) ) else ( Log.debugf info (fun k -> k "Backtracking to lvl %d" lvl); (* We set the head of the solver and theory queue to what it was. *) - let head = ref (Vec.get env.elt_levels lvl) in - env.elt_head <- !head; - env.th_head <- !head; + let head = ref (Vec.get st.elt_levels lvl) in + st.elt_head <- !head; + st.th_head <- !head; (* Now we need to cleanup the vars that are not valid anymore (i.e to the right of elt_head in the queue. *) - for c = env.elt_head to Vec.size env.trail - 1 do - match (Vec.get env.trail c) with + for c = st.elt_head to Vec.size st.trail - 1 do + match (Vec.get st.trail c) with (* A literal is unassigned, we nedd to add it back to the heap of potentially assignable literals, unless it has a level lower than [lvl], in which case we just move it back. *) | Lit l -> if l.l_level <= lvl then ( - Vec.set env.trail !head (Trail_elt.of_lit l); + Vec.set st.trail !head (Trail_elt.of_lit l); head := !head + 1 ) else ( l.assigned <- None; l.l_level <- -1; - insert_var_order (Elt.of_lit l) + insert_var_order st (Elt.of_lit l) ) (* A variable is not true/false anymore, one of two things can happen: *) | Atom a -> @@ -462,7 +465,7 @@ module Make (* It is a late propagation, which has a level lower than where we backtrack, so we just move it to the head of the queue, to be propagated again. *) - Vec.set env.trail !head (Trail_elt.of_atom a); + Vec.set st.trail !head (Trail_elt.of_atom a); head := !head + 1 ) else ( (* it is a result of bolean propagation, or a semantic propagation @@ -472,24 +475,24 @@ module Make a.neg.is_true <- false; a.var.v_level <- -1; a.var.reason <- None; - insert_var_order (Elt.of_var a.var) + insert_var_order st (Elt.of_var a.var) ) done; (* Recover the right theory state. *) - Plugin.backtrack (Vec.get env.th_levels lvl); + Plugin.backtrack (Vec.get st.th_levels lvl); (* Resize the vectors according to their new size. *) - Vec.shrink env.trail !head; - Vec.shrink env.elt_levels lvl; - Vec.shrink env.th_levels lvl; + Vec.shrink st.trail !head; + Vec.shrink st.elt_levels lvl; + Vec.shrink st.th_levels lvl; ); - assert (Vec.size env.elt_levels = Vec.size env.th_levels); + assert (Vec.size st.elt_levels = Vec.size st.th_levels); () (* Unsatisfiability is signaled through an exception, since it can happen in multiple places (adding new clauses, or solving for instance). *) - let report_unsat confl : _ = + let report_unsat st confl : _ = Log.debugf info (fun k -> k "@[Unsat conflict: %a@]" Clause.debug confl); - env.unsat_conflict <- Some confl; + st.unsat_conflict <- Some confl; raise Unsat (* Simplification of boolean propagation reasons. @@ -530,7 +533,7 @@ module Make (* Boolean propagation. Wrapper function for adding a new propagated formula. *) - let enqueue_bool a ~level:lvl reason : unit = + let enqueue_bool st a ~level:lvl reason : unit = if a.neg.is_true then ( Log.debugf error (fun k->k "Trying to enqueue a false literal: %a" Atom.debug a); assert false @@ -544,22 +547,22 @@ module Make a.is_true <- true; a.var.v_level <- lvl; a.var.reason <- Some reason; - Vec.push env.trail (Trail_elt.of_atom a); + Vec.push st.trail (Trail_elt.of_atom a); Log.debugf debug - (fun k->k "Enqueue (%d): %a" (Vec.size env.trail) Atom.debug a); + (fun k->k "Enqueue (%d): %a" (Vec.size st.trail) Atom.debug a); () - let enqueue_semantic a terms = + let enqueue_semantic st a terms = if not a.is_true then ( - let l = List.map Lit.make terms in + let l = List.map (Lit.make st.st) terms in let lvl = List.fold_left (fun acc {l_level; _} -> assert (l_level > 0); max acc l_level) 0 l in - H.grow_to_at_least env.order (St.nb_elt ()); - enqueue_bool a ~level:lvl Semantic + H.grow_to_at_least st.order (St.nb_elt st.st); + enqueue_bool st a ~level:lvl Semantic ) (* MCsat semantic assignment *) - let enqueue_assign l value lvl = + let enqueue_assign st l value lvl = match l.assigned with | Some _ -> Log.debugf error @@ -569,9 +572,9 @@ module Make assert (l.l_level < 0); l.assigned <- Some value; l.l_level <- lvl; - Vec.push env.trail (Trail_elt.of_lit l); + Vec.push st.trail (Trail_elt.of_lit l); Log.debugf debug - (fun k -> k "Enqueue (%d): %a" (Vec.size env.trail) Lit.debug l); + (fun k -> k "Enqueue (%d): %a" (Vec.size st.trail) Lit.debug l); () (* swap elements of array *) @@ -610,7 +613,7 @@ module Make (* evaluate an atom for MCsat, if it's not assigned by boolean propagation/decision *) - let th_eval a : bool option = + let th_eval st a : bool option = if a.is_true || a.neg.is_true then None else match Plugin.eval a.lit with | Plugin_intf.Unknown -> None @@ -620,25 +623,25 @@ module Make Format.asprintf "msat:core/internal.ml: %s" "semantic propagation at level 0 are currently forbidden")); let atom = if b then a else a.neg in - enqueue_semantic atom l; + enqueue_semantic st atom l; Some b (* find which level to backtrack to, given a conflict clause and a boolean stating whether it is a UIP ("Unique Implication Point") precond: the atom list is sorted by decreasing decision level *) - let backtrack_lvl : atom list -> int * bool = function + let backtrack_lvl st : atom list -> int * bool = function | [] | [_] -> 0, true | a :: b :: _ -> - assert(a.var.v_level > base_level ()); + assert(a.var.v_level > base_level st); if a.var.v_level > b.var.v_level then ( (* backtrack below [a], so we can propagate [not a] *) b.var.v_level, true ) else ( assert (a.var.v_level = b.var.v_level); - assert (a.var.v_level >= base_level ()); - max (a.var.v_level - 1) (base_level ()), false + assert (a.var.v_level >= base_level st); + max (a.var.v_level - 1) (base_level st), false ) (* result of conflict analysis, containing the learnt clause and some @@ -654,24 +657,24 @@ module Make cr_is_uip: bool; (* conflict is UIP? *) } - let get_atom i = - match Vec.get env.trail i with + let get_atom st i = + match Vec.get st.trail i with | Lit _ -> assert false | Atom x -> x (* conflict analysis for SAT Same idea as the mcsat analyze function (without semantic propagations), except we look the the Last UIP (TODO: check ?), and do it in an imperative and efficient manner. *) - let analyze_sat c_clause : conflict_res = + let analyze_sat st c_clause : conflict_res = let pathC = ref 0 in let learnt = ref [] in let cond = ref true in let blevel = ref 0 in let seen = ref [] in let c = ref (Some c_clause) in - let tr_ind = ref (Vec.size env.trail - 1) in + let tr_ind = ref (Vec.size st.trail - 1) in let history = ref [] in - assert (decision_level () > 0); + assert (decision_level st > 0); let conflict_level = Array.fold_left (fun acc p -> max acc p.var.v_level) 0 c_clause.atoms in @@ -684,7 +687,7 @@ module Make | Some clause -> Log.debugf debug (fun k->k" Resolving clause: %a" Clause.debug clause); begin match clause.cpremise with - | History _ -> clause_bump_activity clause + | History _ -> clause_bump_activity st clause | Hyp | Local | Lemma _ -> () end; history := clause :: !history; @@ -703,7 +706,7 @@ module Make Atom.mark q.neg; seen := q :: !seen; if q.var.v_level > 0 then ( - var_bump_activity q.var; + var_bump_activity st q.var; if q.var.v_level >= conflict_level then ( incr pathC; ) else ( @@ -717,7 +720,7 @@ module Make (* look for the next node to expand *) while - let a = Vec.get env.trail !tr_ind in + let a = Vec.get st.trail !tr_ind in Log.debugf debug (fun k -> k " looking at: %a" Trail_elt.debug a); match a with | Atom q -> @@ -727,7 +730,7 @@ module Make do decr tr_ind; done; - let p = get_atom !tr_ind in + let p = get_atom st !tr_ind in decr pathC; decr tr_ind; match !pathC, p.var.reason with @@ -746,15 +749,15 @@ module Make done; List.iter (fun q -> Var.clear q.var) !seen; let l = List.fast_sort (fun p q -> compare q.var.v_level p.var.v_level) !learnt in - let level, is_uip = backtrack_lvl l in + let level, is_uip = backtrack_lvl st l in { cr_backtrack_lvl = level; cr_learnt = l; cr_history = List.rev !history; cr_is_uip = is_uip; } - let analyze c_clause : conflict_res = - analyze_sat c_clause + let[@inline] analyze st c_clause : conflict_res = + analyze_sat st c_clause (* if St.mcsat then analyze_mcsat c_clause @@ -762,67 +765,69 @@ module Make *) (* add the learnt clause to the clause database, propagate, etc. *) - let record_learnt_clause (confl:clause) (cr:conflict_res): unit = + let record_learnt_clause st (confl:clause) (cr:conflict_res): unit = begin match cr.cr_learnt with | [] -> assert false | [fuip] -> assert (cr.cr_backtrack_lvl = 0); if fuip.neg.is_true then ( - report_unsat confl + report_unsat st confl ) else ( let uclause = Clause.make cr.cr_learnt (History cr.cr_history) in - Vec.push env.clauses_learnt uclause; + Vec.push st.clauses_learnt uclause; (* no need to attach [uclause], it is true at level 0 *) - enqueue_bool fuip ~level:0 (Bcp uclause) + enqueue_bool st fuip ~level:0 (Bcp uclause) ) | fuip :: _ -> let lclause = Clause.make cr.cr_learnt (History cr.cr_history) in - Vec.push env.clauses_learnt lclause; + Vec.push st.clauses_learnt lclause; attach_clause lclause; - clause_bump_activity lclause; + clause_bump_activity st lclause; if cr.cr_is_uip then ( - enqueue_bool fuip ~level:cr.cr_backtrack_lvl (Bcp lclause) + enqueue_bool st fuip ~level:cr.cr_backtrack_lvl (Bcp lclause) ) else ( - env.next_decision <- Some fuip.neg + st.next_decision <- Some fuip.neg ) end; - var_decay_activity (); - clause_decay_activity () + var_decay_activity st; + clause_decay_activity st (* process a conflict: - learn clause - backtrack - report unsat if conflict at level 0 *) - let add_boolean_conflict (confl:clause): unit = + let add_boolean_conflict st (confl:clause): unit = Log.debugf info (fun k -> k "Boolean conflict: %a" Clause.debug confl); - env.next_decision <- None; - env.conflicts <- env.conflicts + 1; - assert (decision_level() >= base_level ()); - if decision_level() = base_level () - || Array.for_all (fun a -> a.var.v_level <= base_level ()) confl.atoms then - report_unsat confl; (* Top-level conflict *) - let cr = analyze confl in - cancel_until (max cr.cr_backtrack_lvl (base_level ())); - record_learnt_clause confl cr + st.next_decision <- None; + st.conflicts <- st.conflicts + 1; + assert (decision_level st >= base_level st); + if decision_level st = base_level st || + Array.for_all (fun a -> a.var.v_level <= base_level st) confl.atoms then ( + (* Top-level conflict *) + report_unsat st confl; + ); + let cr = analyze st confl in + cancel_until st (max cr.cr_backtrack_lvl (base_level st)); + record_learnt_clause st confl cr (* Get the correct vector to insert a clause in. *) - let clause_vector c = + let clause_vector st c = match c.cpremise with - | Hyp -> env.clauses_hyps - | Local -> env.clauses_temp - | Lemma _ | History _ -> env.clauses_learnt + | Hyp -> st.clauses_hyps + | Local -> st.clauses_temp + | Lemma _ | History _ -> st.clauses_learnt (* Add a new clause, simplifying, propagating, and backtracking if the clause is false in the current trail *) - let add_clause (init:clause) : unit = + let add_clause st (init:clause) : unit = Log.debugf debug (fun k -> k "Adding clause: @[%a@]" Clause.debug init); (* Insertion of new lits is done before simplification. Indeed, else a lit in a trivial clause could end up being not decided on, which is a bug. *) - Array.iter (fun x -> insert_var_order (Elt.of_var x.var)) init.atoms; - let vec = clause_vector init in + Array.iter (fun x -> insert_var_order st (Elt.of_var x.var)) init.atoms; + let vec = clause_vector st init in try - let c = eliminate_doublons init in + let c = eliminate_duplicates init in Log.debugf debug (fun k -> k "Doublons eliminated: %a" Clause.debug c); let atoms, history = partition c.atoms in let clause = @@ -840,30 +845,30 @@ module Make (* Report_unsat will raise, and the current clause will be lost if we do not store it somewhere. Since the proof search will end, any of env.clauses_to_add or env.clauses_root is adequate. *) - Stack.push clause env.clauses_root; - report_unsat clause + Stack.push clause st.clauses_root; + report_unsat st clause | [a] -> - cancel_until (base_level ()); + cancel_until st (base_level st); if a.neg.is_true then ( (* Since we cannot propagate the atom [a], in order to not lose the information that [a] must be true, we add clause to the list of clauses to add, so that it will be e-examined later. *) Log.debug debug "Unit clause, adding to clauses to add"; - Stack.push clause env.clauses_to_add; - report_unsat clause + Stack.push clause st.clauses_to_add; + report_unsat st clause ) else if a.is_true then ( (* If the atom is already true, then it should be because of a local hyp. However it means we can't propagate it at level 0. In order to not lose that information, we store the clause in a stack of clauses that we will add to the solver at the next pop. *) Log.debug debug "Unit clause, adding to root clauses"; - assert (0 < a.var.v_level && a.var.v_level <= base_level ()); - Stack.push clause env.clauses_root; + assert (0 < a.var.v_level && a.var.v_level <= base_level st); + Stack.push clause st.clauses_root; () ) else ( Log.debugf debug (fun k->k "Unit clause, propagating: %a" Atom.debug a); Vec.push vec clause; - enqueue_bool a ~level:0 (Bcp clause) + enqueue_bool st a ~level:0 (Bcp clause) ) | a::b::_ -> Vec.push vec clause; @@ -872,30 +877,30 @@ module Make or we might watch the wrong literals. *) put_high_level_atoms_first clause.atoms; attach_clause clause; - add_boolean_conflict clause + add_boolean_conflict st clause ) else ( attach_clause clause; if b.neg.is_true && not a.is_true && not a.neg.is_true then ( let lvl = List.fold_left (fun m a -> max m a.var.v_level) 0 atoms in - cancel_until (max lvl (base_level ())); - enqueue_bool a ~level:lvl (Bcp clause) + cancel_until st (max lvl (base_level st)); + enqueue_bool st a ~level:lvl (Bcp clause) ) ) with Trivial -> Vec.push vec init; Log.debugf info (fun k->k "Trivial clause ignored : @[%a@]" Clause.debug init) - let flush_clauses () = - if not (Stack.is_empty env.clauses_to_add) then begin - let nbv = St.nb_elt () in - let nbc = env.nb_init_clauses + Stack.length env.clauses_to_add in - H.grow_to_at_least env.order nbv; - Vec.grow_to_at_least env.clauses_hyps nbc; - Vec.grow_to_at_least env.clauses_learnt nbc; - env.nb_init_clauses <- nbc; - while not (Stack.is_empty env.clauses_to_add) do - let c = Stack.pop env.clauses_to_add in - add_clause c + let flush_clauses st = + if not (Stack.is_empty st.clauses_to_add) then begin + let nbv = St.nb_elt st.st in + let nbc = st.nb_init_clauses + Stack.length st.clauses_to_add in + H.grow_to_at_least st.order nbv; + Vec.grow_to_at_least st.clauses_hyps nbc; + Vec.grow_to_at_least st.clauses_learnt nbc; + st.nb_init_clauses <- nbc; + while not (Stack.is_empty st.clauses_to_add) do + let c = Stack.pop st.clauses_to_add in + add_clause st c done end @@ -908,7 +913,7 @@ module Make [i] is the index of [c] in [a.watched] @return whether [c] was removed from [a.watched] *) - let propagate_in_clause (a:atom) (c:clause) (i:int): watch_res = + let propagate_in_clause st (a:atom) (c:clause) (i:int): watch_res = let atoms = c.atoms in let first = atoms.(0) in if first == a.neg then ( @@ -937,15 +942,15 @@ module Make (* no watch lit found *) if first.neg.is_true then ( (* clause is false *) - env.elt_head <- Vec.size env.trail; + st.elt_head <- Vec.size st.trail; raise (Conflict c) ) else ( - match th_eval first with + match th_eval st first with | None -> (* clause is unit, keep the same watches, but propagate *) - enqueue_bool first ~level:(decision_level ()) (Bcp c) + enqueue_bool st first ~level:(decision_level st) (Bcp c) | Some true -> () | Some false -> - env.elt_head <- Vec.size env.trail; + st.elt_head <- Vec.size st.trail; raise (Conflict c) ); Watch_kept @@ -957,7 +962,7 @@ module Make clause watching [a] to see if the clause is false, unit, or has other possible watches @param res the optional conflict clause that the propagation might trigger *) - let propagate_atom a (res:clause option ref) : unit = + let propagate_atom st a (res:clause option ref) : unit = let watched = a.watched in begin try @@ -966,7 +971,7 @@ module Make else ( let c = Vec.get watched i in assert (Clause.attached c); - let j = match propagate_in_clause a c i with + let j = match propagate_in_clause st a c i with | Watch_kept -> i+1 | Watch_removed -> i (* clause at this index changed *) in @@ -981,88 +986,87 @@ module Make () (* Propagation (boolean and theory) *) - let create_atom f = - let a = atom f in - ignore (th_eval a); + let create_atom st f = + let a = mk_atom st f in + ignore (th_eval st a); a - let slice_get i = - match Vec.get env.trail i with + let slice_get st i = + match Vec.get st.trail i with | Atom a -> Plugin_intf.Lit a.lit | Lit {term; assigned = Some v; _} -> Plugin_intf.Assign (term, v) | Lit _ -> assert false - let slice_push (l:formula list) (lemma:proof): unit = - let atoms = List.rev_map create_atom l in + let slice_push st (l:formula list) (lemma:proof): unit = + let atoms = List.rev_map (create_atom st) l in let c = Clause.make atoms (Lemma lemma) in Log.debugf info (fun k->k "Pushing clause %a" Clause.debug c); - Stack.push c env.clauses_to_add + Stack.push c st.clauses_to_add - let slice_propagate f = function + let slice_propagate (st:t) f = function | Plugin_intf.Eval l -> - let a = atom f in - enqueue_semantic a l + let a = mk_atom st f in + enqueue_semantic st a l | Plugin_intf.Consequence (causes, proof) -> - let l = List.rev_map atom causes in + let l = List.rev_map (mk_atom st) causes in if List.for_all (fun a -> a.is_true) l then ( - let p = atom f in + let p = mk_atom st f in let c = Clause.make (p :: List.map Atom.neg l) (Lemma proof) in if p.is_true then () else if p.neg.is_true then ( - Stack.push c env.clauses_to_add + Stack.push c st.clauses_to_add ) else ( - H.grow_to_at_least env.order (St.nb_elt ()); - insert_subterms_order p.var; - let lvl = List.fold_left (fun acc a -> max acc a.var.v_level) 0 l in - enqueue_bool p ~level:lvl (Bcp c) + H.grow_to_at_least st.order (St.nb_elt st.st); + insert_subterms_order st p.var; + let level = List.fold_left (fun acc a -> max acc a.var.v_level) 0 l in + enqueue_bool st p ~level (Bcp c) ) ) else ( invalid_arg "Msat.Internal.slice_propagate" ) - let current_slice (): (_,_,_) Plugin_intf.slice = { - Plugin_intf.start = env.th_head; - length = (Vec.size env.trail) - env.th_head; - get = slice_get; - push = slice_push; - propagate = slice_propagate; + let current_slice st : (_,_,_) Plugin_intf.slice = { + Plugin_intf.start = st.th_head; + length = (Vec.size st.trail) - st.th_head; + get = slice_get st; + push = slice_push st; + propagate = slice_propagate st; } (* full slice, for [if_sat] final check *) - let full_slice () : (_,_,_) Plugin_intf.slice = { + let full_slice st : (_,_,_) Plugin_intf.slice = { Plugin_intf.start = 0; - length = Vec.size env.trail; - get = slice_get; - push = slice_push; + length = Vec.size st.trail; + get = slice_get st; + push = slice_push st; propagate = (fun _ -> assert false); } (* some boolean literals were decided/propagated within Msat. Now we need to inform the theory of those assumptions, so it can do its job. @return the conflict clause, if the theory detects unsatisfiability *) - let rec theory_propagate (): clause option = - assert (env.elt_head = Vec.size env.trail); - assert (env.th_head <= env.elt_head); - if env.th_head = env.elt_head then ( + let rec theory_propagate st : clause option = + assert (st.elt_head = Vec.size st.trail); + assert (st.th_head <= st.elt_head); + if st.th_head = st.elt_head then ( None (* fixpoint/no propagation *) ) else ( - let slice = current_slice () in - env.th_head <- env.elt_head; (* catch up *) + let slice = current_slice st in + st.th_head <- st.elt_head; (* catch up *) match Plugin.assume slice with | Plugin_intf.Sat -> - propagate () + propagate st | Plugin_intf.Unsat (l, p) -> (* conflict *) - let l = List.rev_map create_atom l in + let l = List.rev_map (create_atom st) l in (* Assert that the conflcit is indeeed a conflict *) if not @@ List.for_all (fun a -> a.neg.is_true) l then ( raise (Invalid_argument "msat:core/internal: invalid conflict"); ); - (* Insert elements for decision (and ensure the heap is big enough) *) - H.grow_to_at_least env.order (St.nb_elt ()); - List.iter (fun a -> insert_var_order (Elt.of_var a.var)) l; + H.grow_to_at_least st.order (St.nb_elt st.st); + List.iter (fun a -> insert_var_order st (Elt.of_var a.var)) l; (* Create the clause and return it. *) let c = St.Clause.make l (Lemma p) in Some c @@ -1070,31 +1074,31 @@ module Make (* fixpoint between boolean propagation and theory propagation @return a conflict clause, if any *) - and propagate (): clause option = + and propagate (st:t) : clause option = (* First, treat the stack of lemmas added by the theory, if any *) - flush_clauses (); + flush_clauses st; (* Now, check that the situation is sane *) - assert (env.elt_head <= Vec.size env.trail); - if env.elt_head = Vec.size env.trail then - theory_propagate () - else begin + assert (st.elt_head <= Vec.size st.trail); + if st.elt_head = Vec.size st.trail then + theory_propagate st + else ( let num_props = ref 0 in let res = ref None in - while env.elt_head < Vec.size env.trail do - begin match Vec.get env.trail env.elt_head with + while st.elt_head < Vec.size st.trail do + begin match Vec.get st.trail st.elt_head with | Lit _ -> () | Atom a -> incr num_props; - propagate_atom a res + propagate_atom st a res end; - env.elt_head <- env.elt_head + 1; + st.elt_head <- st.elt_head + 1; done; - env.propagations <- env.propagations + !num_props; - env.simpDB_props <- env.simpDB_props - !num_props; + st.propagations <- st.propagations + !num_props; + st.simpDB_props <- st.simpDB_props - !num_props; match !res with - | None -> theory_propagate () + | None -> theory_propagate st | _ -> !res - end + ) (* remove some learnt clauses NOTE: so far we do not forget learnt clauses. We could, as long as @@ -1102,50 +1106,50 @@ module Make let reduce_db () = () (* Decide on a new literal, and enqueue it into the trail *) - let rec pick_branch_aux atom: unit = + let rec pick_branch_aux st atom : unit = let v = atom.var in if v.v_level >= 0 then ( assert (v.pa.is_true || v.na.is_true); - pick_branch_lit () + pick_branch_lit st ) else match Plugin.eval atom.lit with | Plugin_intf.Unknown -> - env.decisions <- env.decisions + 1; - new_decision_level(); - let current_level = decision_level () in - enqueue_bool atom ~level:current_level Decision + st.decisions <- st.decisions + 1; + new_decision_level st; + let current_level = decision_level st in + enqueue_bool st atom ~level:current_level Decision | Plugin_intf.Valued (b, l) -> let a = if b then atom else atom.neg in - enqueue_semantic a l + enqueue_semantic st a l - and pick_branch_lit () = - match env.next_decision with + and pick_branch_lit st = + match st.next_decision with | Some atom -> - env.next_decision <- None; - pick_branch_aux atom + st.next_decision <- None; + pick_branch_aux st atom | None -> - begin match H.remove_min env.order with + begin match H.remove_min st.order with | E_lit l -> if Lit.level l >= 0 then - pick_branch_lit () + pick_branch_lit st else ( let value = Plugin.assign l.term in - env.decisions <- env.decisions + 1; - new_decision_level(); - let current_level = decision_level () in - enqueue_assign l value current_level + st.decisions <- st.decisions + 1; + new_decision_level st; + let current_level = decision_level st in + enqueue_assign st l value current_level ) | E_var v -> - pick_branch_aux v.pa + pick_branch_aux st v.pa | exception Not_found -> raise Sat end (* do some amount of search, until the number of conflicts or clause learnt reaches the given parameters *) - let search n_of_conflicts n_of_learnts: unit = + let search (st:t) n_of_conflicts n_of_learnts : unit = let conflictC = ref 0 in - env.starts <- env.starts + 1; + st.starts <- st.starts + 1; while true do - match propagate () with + match propagate st with | Some confl -> (* Conflict *) incr conflictC; (* When the theory has raised Unsat, add_boolean_conflict @@ -1153,31 +1157,31 @@ module Make analyzed backtrack clause. So in those case, we use add_clause to make sure the initial conflict clause is also added. *) if Clause.attached confl then - add_boolean_conflict confl + add_boolean_conflict st confl else - add_clause confl + add_clause st confl | None -> (* No Conflict *) - assert (env.elt_head = Vec.size env.trail); - assert (env.elt_head = env.th_head); - if Vec.size env.trail = St.nb_elt () + assert (st.elt_head = Vec.size st.trail); + assert (st.elt_head = st.th_head); + if Vec.size st.trail = St.nb_elt st.st then raise Sat; if n_of_conflicts > 0 && !conflictC >= n_of_conflicts then ( Log.debug info "Restarting..."; - cancel_until (base_level ()); + cancel_until st (base_level st); raise Restart ); (* if decision_level() = 0 then simplify (); *) if n_of_learnts >= 0 && - Vec.size env.clauses_learnt - Vec.size env.trail >= n_of_learnts + Vec.size st.clauses_learnt - Vec.size st.trail >= n_of_learnts then reduce_db(); - pick_branch_lit () + pick_branch_lit st done - let eval_level lit = - let var, negated = Var.make lit in + let eval_level (st:t) lit = + let var, negated = Var.make st.st lit in if not var.pa.is_true && not var.na.is_true then raise UndecidedLit else assert (var.v_level >= 0); @@ -1188,128 +1192,128 @@ module Make in value, var.v_level - let eval lit = fst (eval_level lit) + let eval st lit = fst (eval_level st lit) - let unsat_conflict () = env.unsat_conflict + let[@inline] unsat_conflict st = st.unsat_conflict - let model () : (term * term) list = + let model (st:t) : (term * term) list = let opt = function Some a -> a | None -> assert false in Vec.fold (fun acc e -> match e with | Lit v -> (v.term, opt v.assigned) :: acc | Atom _ -> acc) - [] env.trail + [] st.trail (* fixpoint of propagation and decisions until a model is found, or a conflict is reached *) - let solve (): unit = + let solve (st:t) : unit = Log.debug 5 "solve"; - if is_unsat () then raise Unsat; - let n_of_conflicts = ref (to_float env.restart_first) in - let n_of_learnts = ref ((to_float (nb_clauses())) *. env.learntsize_factor) in + if is_unsat st then raise Unsat; + let n_of_conflicts = ref (to_float st.restart_first) in + let n_of_learnts = ref ((to_float (nb_clauses st)) *. st.learntsize_factor) in try while true do begin try - search (to_int !n_of_conflicts) (to_int !n_of_learnts) + search st (to_int !n_of_conflicts) (to_int !n_of_learnts) with | Restart -> - n_of_conflicts := !n_of_conflicts *. env.restart_inc; - n_of_learnts := !n_of_learnts *. env.learntsize_inc + n_of_conflicts := !n_of_conflicts *. st.restart_inc; + n_of_learnts := !n_of_learnts *. st.learntsize_inc | Sat -> - assert (env.elt_head = Vec.size env.trail); - begin match Plugin.if_sat (full_slice ()) with + assert (st.elt_head = Vec.size st.trail); + begin match Plugin.if_sat (full_slice st) with | Plugin_intf.Sat -> () | Plugin_intf.Unsat (l, p) -> - let atoms = List.rev_map create_atom l in + let atoms = List.rev_map (create_atom st) l in let c = Clause.make atoms (Lemma p) in Log.debugf info (fun k -> k "Theory conflict clause: %a" Clause.debug c); - Stack.push c env.clauses_to_add + Stack.push c st.clauses_to_add end; - if Stack.is_empty env.clauses_to_add then raise Sat + if Stack.is_empty st.clauses_to_add then raise Sat end done with Sat -> () - let assume ?tag cnf = + let assume st ?tag cnf = List.iter (fun l -> - let atoms = List.rev_map atom l in + let atoms = List.rev_map (mk_atom st) l in let c = Clause.make ?tag atoms Hyp in Log.debugf debug (fun k -> k "Assuming clause: @[%a@]" Clause.debug c); - Stack.push c env.clauses_to_add) + Stack.push c st.clauses_to_add) cnf (* create a factice decision level for local assumptions *) - let push (): unit = + let push st : unit = Log.debug debug "Pushing a new user level"; - match env.unsat_conflict with + match st.unsat_conflict with | Some _ -> raise Unsat | None -> - cancel_until (base_level ()); + cancel_until st (base_level st); Log.debugf debug (fun k -> k "@[Status:@,@[trail: %d - %d@,%a@]" - env.elt_head env.th_head (Vec.print ~sep:"" Trail_elt.debug) env.trail); - begin match propagate () with + st.elt_head st.th_head (Vec.print ~sep:"" Trail_elt.debug) st.trail); + begin match propagate st with | Some confl -> - report_unsat confl + report_unsat st confl | None -> Log.debugf debug (fun k -> k "@[Current trail:@,@[%a@]@]" - (Vec.print ~sep:"" Trail_elt.debug) env.trail); + (Vec.print ~sep:"" Trail_elt.debug) st.trail); Log.debug info "Creating new user level"; - new_decision_level (); - Vec.push env.user_levels (Vec.size env.clauses_temp); - assert (decision_level () = base_level ()) + new_decision_level st; + Vec.push st.user_levels (Vec.size st.clauses_temp); + assert (decision_level st = base_level st) end (* pop the last factice decision level *) - let pop (): unit = - if base_level () = 0 then + let pop st : unit = + if base_level st = 0 then Log.debug warn "Cannot pop (already at level 0)" else ( Log.debug info "Popping user level"; - assert (base_level () > 0); - env.unsat_conflict <- None; - let n = Vec.last env.user_levels in - Vec.pop env.user_levels; (* before the [cancel_until]! *) + assert (base_level st > 0); + st.unsat_conflict <- None; + let n = Vec.last st.user_levels in + Vec.pop st.user_levels; (* before the [cancel_until]! *) (* Add the root clauses to the clauses to add *) - Stack.iter (fun c -> Stack.push c env.clauses_to_add) env.clauses_root; - Stack.clear env.clauses_root; + Stack.iter (fun c -> Stack.push c st.clauses_to_add) st.clauses_root; + Stack.clear st.clauses_root; (* remove from env.clauses_temp the now invalid caluses. *) - Vec.shrink env.clauses_temp n; - assert (Vec.for_all (fun c -> Array.length c.atoms = 1) env.clauses_temp); - assert (Vec.for_all (fun c -> c.atoms.(0).var.v_level <= base_level ()) env.clauses_temp); - cancel_until (base_level ()) + Vec.shrink st.clauses_temp n; + assert (Vec.for_all (fun c -> Array.length c.atoms = 1) st.clauses_temp); + assert (Vec.for_all (fun c -> c.atoms.(0).var.v_level <= base_level st) st.clauses_temp); + cancel_until st (base_level st) ) (* Add local hyps to the current decision level *) - let local l = + let local (st:t) (l:_ list) : unit = let aux lit = - let a = atom lit in + let a = mk_atom st lit in Log.debugf info (fun k-> k "Local assumption: @[%a@]" Atom.debug a); - assert (decision_level () = base_level ()); + assert (decision_level st = base_level st); if not a.is_true then ( let c = Clause.make [a] Local in Log.debugf debug (fun k -> k "Temp clause: @[%a@]" Clause.debug c); - Vec.push env.clauses_temp c; + Vec.push st.clauses_temp c; if a.neg.is_true then ( (* conflict between assumptions: UNSAT *) - report_unsat c; + report_unsat st c; ) else ( (* Grow the heap, because when the lit is backtracked, it will be added to the heap. *) - H.grow_to_at_least env.order (St.nb_elt ()); + H.grow_to_at_least st.order (St.nb_elt st.st); (* make a decision, propagate *) - let level = decision_level() in - enqueue_bool a ~level (Bcp c); + let level = decision_level st in + enqueue_bool st a ~level (Bcp c); ) ) in - assert (base_level () > 0); - match env.unsat_conflict with + assert (base_level st > 0); + match st.unsat_conflict with | None -> Log.debug info "Adding local assumption"; - cancel_until (base_level ()); + cancel_until st (base_level st); List.iter aux l | Some _ -> Log.debug warn "Cannot add local assumption (already unsat)" @@ -1338,22 +1342,23 @@ module Make with Exit -> false - let check () = - Stack.is_empty env.clauses_to_add && - check_stack env.clauses_root && - check_vec env.clauses_hyps && - check_vec env.clauses_learnt && - check_vec env.clauses_temp + let check st : bool = + Stack.is_empty st.clauses_to_add && + check_stack st.clauses_root && + check_vec st.clauses_hyps && + check_vec st.clauses_learnt && + check_vec st.clauses_temp (* Unsafe access to internal data *) - let hyps () = env.clauses_hyps + let hyps env = env.clauses_hyps - let history () = env.clauses_learnt + let history env = env.clauses_learnt - let temp () = env.clauses_temp + let temp env = env.clauses_temp - let trail () = env.trail + let trail env = env.trail end +[@@inline] diff --git a/src/core/Internal.mli b/src/core/Internal.mli index bdbbb29e..e25a675c 100644 --- a/src/core/Internal.mli +++ b/src/core/Internal.mli @@ -14,8 +14,8 @@ Copyright 2014 Simon Cruanes module Make (St : Solver_types.S) - (Th : Plugin_intf.S with type term = St.term and type formula = St.formula and type proof = St.proof) - (Dummy: sig end) + (Th : Plugin_intf.S with type term = St.term + and type formula = St.formula and type proof = St.proof) : sig (** Functor to create a solver parametrised by the atomic formulas and a theory. *) @@ -24,55 +24,63 @@ module Make exception Unsat exception UndecidedLit - val solve : unit -> unit + type t + (** Solver *) + + val create : ?st:St.t -> unit -> t + + val st : t -> St.t + (** Underlying state *) + + val solve : t -> unit (** Try and solves the current set of assumptions. @return () if the current set of clauses is satisfiable @raise Unsat if a toplevel conflict is found *) - val assume : ?tag:int -> St.formula list list -> unit + val assume : t -> ?tag:int -> St.formula list list -> unit (** Add the list of clauses to the current set of assumptions. Modifies the sat solver state in place. *) - val new_lit : St.term -> unit + val new_lit : t -> St.term -> unit (** Add a new litteral (i.e term) to the solver. This term will be decided on at some point during solving, wether it appears in clauses or not. *) - val new_atom : St.formula -> unit + val new_atom : t -> St.formula -> unit (** Add a new atom (i.e propositional formula) to the solver. This formula will be decided on at some point during solving, wether it appears in clauses or not. *) - val push : unit -> unit + val push : t -> unit (** Create a decision level for local assumptions. @raise Unsat if a conflict is detected in the current state. *) - val pop : unit -> unit + val pop : t -> unit (** Pop a decision level for local assumptions. *) - val local : St.formula list -> unit + val local : t -> St.formula list -> unit (** Add local assumptions @param assumptions list of additional local assumptions to make, removed after the callback returns a value *) (** {2 Propositional models} *) - val eval : St.formula -> bool + val eval : t -> St.formula -> bool (** Returns the valuation of a formula in the current state of the sat solver. @raise UndecidedLit if the literal is not decided *) - val eval_level : St.formula -> bool * int + val eval_level : t -> St.formula -> bool * int (** Return the current assignement of the literals, as well as its decision level. If the level is 0, then it is necessary for the atom to have this value; otherwise it is due to choices that can potentially be backtracked. @raise UndecidedLit if the literal is not decided *) - val model : unit -> (St.term * St.term) list + val model : t -> (St.term * St.term) list (** Returns the model found if the formula is satisfiable. *) - val check : unit -> bool + val check : t -> bool (** Check the satisfiability of the current model. Only has meaning if the solver finished proof search and has returned [Sat]. *) @@ -80,11 +88,11 @@ module Make module Proof : Res.S with module St = St - val unsat_conflict : unit -> St.clause option + val unsat_conflict : t -> St.clause option (** Returns the unsat clause found at the toplevel, if it exists (i.e if [solve] has raised [Unsat]) *) - val full_slice : unit -> (St.term, St.formula, St.proof) Plugin_intf.slice + val full_slice : t -> (St.term, St.formula, St.proof) Plugin_intf.slice (** View the current state of the trail as a slice. Mainly useful when the solver has reached a SAT conclusion. *) @@ -92,21 +100,21 @@ module Make These functions expose some internal data stored by the solver, as such great care should be taken to ensure not to mess with the values returned. *) - val trail : unit -> St.trail_elt Vec.t + val trail : t -> St.trail_elt Vec.t (** Returns the current trail. *DO NOT MUTATE* *) - val hyps : unit -> St.clause Vec.t + val hyps : t -> St.clause Vec.t (** Returns the vector of assumptions used by the solver. May be slightly different from the clauses assumed because of top-level simplification of clauses. *DO NOT MUTATE* *) - val temp : unit -> St.clause Vec.t + val temp : t -> St.clause Vec.t (** Returns the clauses coreesponding to the local assumptions. All clauses in this vec are assured to be unit clauses. *DO NOT MUTATE* *) - val history : unit -> St.clause Vec.t + val history : t -> St.clause Vec.t (** Returns the history of learnt clauses, with no guarantees on order. *DO NOT MUTATE* *) diff --git a/src/core/Res.ml b/src/core/Res.ml index f164b940..4fbde74c 100644 --- a/src/core/Res.ml +++ b/src/core/Res.ml @@ -178,12 +178,15 @@ module Make(St : Solver_types.S) = struct end | _ -> Log.debugf error - (fun k -> k "While resolving clauses:@[%a@\n%a@]" St.Clause.debug c St.Clause.debug d); + (fun k -> k "While resolving clauses:@[%a@\n%a@]" + St.Clause.debug c St.Clause.debug d); raise (Resolution_error "Clause mismatch") end | _ -> raise (Resolution_error "Bad history") + let[@inline] conclusion (p:proof) : clause = p + let expand conclusion = Log.debugf debug (fun k -> k "Expanding : @[%a@]" St.Clause.debug conclusion); match conclusion.St.cpremise with diff --git a/src/core/Res_intf.ml b/src/core/Res_intf.ml index 5b443bd0..ce5f1479 100644 --- a/src/core/Res_intf.ml +++ b/src/core/Res_intf.ml @@ -95,6 +95,8 @@ module type S = sig val expand : proof -> proof_node (** Return the proof step at the root of a given proof. *) + val conclusion : proof -> clause + val fold : ('a -> proof_node -> 'a) -> 'a -> proof -> 'a (** [fold f acc p], fold [f] over the proof [p] and all its node. It is guaranteed that [f] is executed exactly once on each proof node in the tree, and that the execution of diff --git a/src/core/Solver.ml b/src/core/Solver.ml index 3dab855e..1a633694 100644 --- a/src/core/Solver.ml +++ b/src/core/Solver.ml @@ -13,12 +13,11 @@ module Make (Th : Plugin_intf.S with type term = St.term and type formula = St.formula and type proof = St.proof) - () = struct module St = St - module S = Internal.Make(St)(Th)(struct end) + module S = Internal.Make(St)(Th) module Proof = S.Proof @@ -26,25 +25,30 @@ module Make type atom = St.formula + type t = S.t + + let create = S.create + (* Result type *) type res = | Sat of (St.term,St.formula) sat_state | Unsat of (St.clause,Proof.proof) unsat_state - let pp_all lvl status = + let pp_all st lvl status = Log.debugf lvl (fun k -> k - "@[%s - Full resume:@,@[Trail:@\n%a@]@,@[Temp:@\n%a@]@,@[Hyps:@\n%a@]@,@[Lemmas:@\n%a@]@,@]@." + "@[%s - Full resume:@,@[Trail:@\n%a@]@,\ + @[Temp:@\n%a@]@,@[Hyps:@\n%a@]@,@[Lemmas:@\n%a@]@,@]@." status - (Vec.print ~sep:"" St.Trail_elt.debug) (S.trail ()) - (Vec.print ~sep:"" St.Clause.debug) (S.temp ()) - (Vec.print ~sep:"" St.Clause.debug) (S.hyps ()) - (Vec.print ~sep:"" St.Clause.debug) (S.history ()) + (Vec.print ~sep:"" St.Trail_elt.debug) (S.trail st) + (Vec.print ~sep:"" St.Clause.debug) (S.temp st) + (Vec.print ~sep:"" St.Clause.debug) (S.hyps st) + (Vec.print ~sep:"" St.Clause.debug) (S.history st) ) - let mk_sat () : (_,_) sat_state = - pp_all 99 "SAT"; - let t = S.trail () in + let mk_sat (st:S.t) : (_,_) sat_state = + pp_all st 99 "SAT"; + let t = S.trail st in let iter f f' = Vec.iter (function | St.Atom a -> f a.St.lit @@ -52,16 +56,16 @@ module Make t in { - eval = S.eval; - eval_level = S.eval_level; + eval = S.eval st; + eval_level = S.eval_level st; iter_trail = iter; - model = S.model; + model = (fun () -> S.model st); } - let mk_unsat () : (_,_) unsat_state = - pp_all 99 "UNSAT"; + let mk_unsat (st:S.t) : (_,_) unsat_state = + pp_all st 99 "UNSAT"; let unsat_conflict () = - match S.unsat_conflict () with + match S.unsat_conflict st with | None -> assert false | Some c -> c in @@ -74,21 +78,21 @@ module Make (* Wrappers around internal functions*) let assume = S.assume - let solve ?(assumptions=[]) () = + let solve (st:t) ?(assumptions=[]) () = try - S.pop (); (* FIXME: what?! *) - S.push (); - S.local assumptions; - S.solve (); - Sat (mk_sat()) + S.pop st; (* FIXME: what?! *) + S.push st; + S.local st assumptions; + S.solve st; + Sat (mk_sat st) with S.Unsat -> - Unsat (mk_unsat()) + Unsat (mk_unsat st) let unsat_core = S.Proof.unsat_core - let true_at_level0 a = + let true_at_level0 st a = try - let b, lev = S.eval_level a in + let b, lev = S.eval_level st a in b && lev = 0 with S.UndecidedLit -> false @@ -97,9 +101,9 @@ module Make let new_lit = S.new_lit let new_atom = S.new_atom - let export () : St.clause export = - let hyps = S.hyps () in - let history = S.history () in - let local = S.temp () in + let export (st:t) : St.clause export = + let hyps = S.hyps st in + let history = S.history st in + let local = S.temp st in {hyps; history; local} end diff --git a/src/core/Solver.mli b/src/core/Solver.mli index 38c2ca9e..c46dc61c 100644 --- a/src/core/Solver.mli +++ b/src/core/Solver.mli @@ -18,8 +18,7 @@ module Make (Th : Plugin_intf.S with type term = St.term and type formula = St.formula and type proof = St.proof) - () : - S with module St = St + : S with module St = St (** Functor to make a safe external interface. *) diff --git a/src/core/Solver_intf.ml b/src/core/Solver_intf.ml index 39e47ef9..4fafcf18 100644 --- a/src/core/Solver_intf.ml +++ b/src/core/Solver_intf.ml @@ -61,6 +61,13 @@ module type S = sig module Proof : Res.S with module St = St (** A module to manipulate proofs. *) + type t + (** Main solver type, containing all state *) + + val create : ?st:St.t -> unit -> t + (** Create new solver *) + (* TODO: add size hint, callbacks, etc. *) + (** {2 Types} *) type atom = St.formula @@ -77,19 +84,19 @@ module type S = sig (** {2 Base operations} *) - val assume : ?tag:int -> atom list list -> unit + val assume : t -> ?tag:int -> atom list list -> unit (** Add the list of clauses to the current set of assumptions. Modifies the sat solver state in place. *) - val solve : ?assumptions:atom list -> unit -> res + val solve : t -> ?assumptions:atom list -> unit -> res (** Try and solves the current set of assumptions. *) - val new_lit : St.term -> unit + val new_lit : t -> St.term -> unit (** Add a new litteral (i.e term) to the solver. This term will be decided on at some point during solving, wether it appears in clauses or not. *) - val new_atom : atom -> unit + val new_atom : t -> atom -> unit (** Add a new atom (i.e propositional formula) to the solver. This formula will be decided on at some point during solving, wether it appears in clauses or not. *) @@ -97,13 +104,13 @@ module type S = sig val unsat_core : Proof.proof -> St.clause list (** Returns the unsat core of a given proof. *) - val true_at_level0 : atom -> bool + val true_at_level0 : t -> atom -> bool (** [true_at_level0 a] returns [true] if [a] was proved at level0, i.e. it must hold in all models *) val get_tag : St.clause -> int option (** Recover tag from a clause, if any *) - val export : unit -> St.clause export + val export : t -> St.clause export end diff --git a/src/core/Solver_types.ml b/src/core/Solver_types.ml index 9f46d043..c22d2df2 100644 --- a/src/core/Solver_types.ml +++ b/src/core/Solver_types.ml @@ -27,7 +27,7 @@ let () = Var_fields.freeze() (* Solver types for McSat Solving *) (* ************************************************************************ *) -module McMake (E : Expr_intf.S)() = struct +module McMake (E : Expr_intf.S) = struct (* Flag for Mcsat v.s Pure Sat *) let mcsat = true @@ -36,6 +36,8 @@ module McMake (E : Expr_intf.S)() = struct type formula = E.Formula.t type proof = E.proof + let pp_form = E.Formula.dummy + type seen = | Nope | Both @@ -136,16 +138,29 @@ module McMake (E : Expr_intf.S)() = struct module MF = Hashtbl.Make(E.Formula) module MT = Hashtbl.Make(E.Term) + type t = { + t_map: lit MT.t; + f_map: var MF.t; + vars: elt Vec.t; + mutable cpt_mk_var: int; + mutable cpt_mk_clause: int; + } + + type state = t + + let create() : t = { + f_map = MF.create 4096; + t_map = MT.create 4096; + vars = Vec.make 107 (E_var dummy_var); + cpt_mk_var = 0; + cpt_mk_clause = 0; + } + (* TODO: embed a state `t` with these inside *) - let f_map = MF.create 4096 - let t_map = MT.create 4096 - let vars = Vec.make 107 (E_var dummy_var) - let nb_elt () = Vec.size vars - let get_elt i = Vec.get vars i - let iter_elt f = Vec.iter f vars - - let cpt_mk_var = ref 0 + let nb_elt st = Vec.size st.vars + let get_elt st i = Vec.get st.vars i + let iter_elt st f = Vec.iter f st.vars let name_of_clause c = match c.cpremise with | Hyp -> "H" ^ string_of_int c.name @@ -165,20 +180,20 @@ module McMake (E : Expr_intf.S)() = struct let[@inline] weight l = l.l_weight let[@inline] set_weight l w = l.l_weight <- w - let make t = - try MT.find t_map t + let make (st:state) (t:term) : t = + try MT.find st.t_map t with Not_found -> let res = { - lid = !cpt_mk_var; + lid = st.cpt_mk_var; term = t; l_weight = 1.; l_idx= -1; l_level = -1; assigned = None; } in - incr cpt_mk_var; - MT.add t_map t res; - Vec.push vars (E_lit res); + st.cpt_mk_var <- st.cpt_mk_var + 1; + MT.add st.t_map t res; + Vec.push st.vars (E_lit res); res let debug_assign fmt v = @@ -208,42 +223,41 @@ module McMake (E : Expr_intf.S)() = struct let[@inline] weight v = v.v_weight let[@inline] set_weight v w = v.v_weight <- w - let make : formula -> var * Expr_intf.negated = - fun t -> - let lit, negated = E.Formula.norm t in - try - MF.find f_map lit, negated - with Not_found -> - let cpt_fois_2 = !cpt_mk_var lsl 1 in - let rec var = - { vid = !cpt_mk_var; - pa = pa; - na = na; - v_fields = Var_fields.empty; - v_level = -1; - v_idx= -1; - v_weight = 0.; - v_assignable = None; - reason = None; - } - and pa = - { var = var; - lit = lit; - watched = Vec.make 10 dummy_clause; - neg = na; - is_true = false; - aid = cpt_fois_2 (* aid = vid*2 *) } - and na = - { var = var; - lit = E.Formula.neg lit; - watched = Vec.make 10 dummy_clause; - neg = pa; - is_true = false; - aid = cpt_fois_2 + 1 (* aid = vid*2+1 *) } in - MF.add f_map lit var; - incr cpt_mk_var; - Vec.push vars (E_var var); - var, negated + let make (st:state) (t:formula) : var * Expr_intf.negated = + let lit, negated = E.Formula.norm t in + try + MF.find st.f_map lit, negated + with Not_found -> + let cpt_double = st.cpt_mk_var lsl 1 in + let rec var = + { vid = st.cpt_mk_var; + pa = pa; + na = na; + v_fields = Var_fields.empty; + v_level = -1; + v_idx= -1; + v_weight = 0.; + v_assignable = None; + reason = None; + } + and pa = + { var = var; + lit = lit; + watched = Vec.make 10 dummy_clause; + neg = na; + is_true = false; + aid = cpt_double (* aid = vid*2 *) } + and na = + { var = var; + lit = E.Formula.neg lit; + watched = Vec.make 10 dummy_clause; + neg = pa; + is_true = false; + aid = cpt_double + 1 (* aid = vid*2+1 *) } in + MF.add st.f_map lit var; + st.cpt_mk_var <- st.cpt_mk_var + 1; + Vec.push st.vars (E_var var); + var, negated (* Marking helpers *) let[@inline] clear v = @@ -281,8 +295,8 @@ module McMake (E : Expr_intf.S)() = struct then a.var.v_fields <- Var_fields.set v_field_seen_pos true a.var.v_fields else a.var.v_fields <- Var_fields.set v_field_seen_neg true a.var.v_fields - let[@inline] make lit = - let var, negated = Var.make lit in + let[@inline] make st lit = + let var, negated = Var.make st lit in match negated with | Formula_intf.Negated -> var.na | Formula_intf.Same_sign -> var.pa @@ -427,19 +441,29 @@ module McMake (E : Expr_intf.S)() = struct in Format.fprintf fmt "%a0" aux atoms end -end + + module Term = struct + include E.Term + let pp = print + end + + module Formula = struct + include E.Formula + let pp = print + end +end[@@inline] (* Solver types for pure SAT Solving *) (* ************************************************************************ *) -module SatMake (E : Formula_intf.S)() = struct +module SatMake (E : Formula_intf.S) = struct include McMake(struct include E module Term = E module Formula = E - end)(struct end) + end) let mcsat = false -end +end[@@inline] diff --git a/src/core/Solver_types.mli b/src/core/Solver_types.mli index 62f1baaa..4d12d93d 100644 --- a/src/core/Solver_types.mli +++ b/src/core/Solver_types.mli @@ -30,11 +30,11 @@ module type S = Solver_types_intf.S module Var_fields = Solver_types_intf.Var_fields -module McMake (E : Expr_intf.S)(): +module McMake (E : Expr_intf.S): S with type term = E.Term.t and type formula = E.Formula.t and type proof = E.proof (** Functor to instantiate the types of clauses for a solver. *) -module SatMake (E : Formula_intf.S)(): +module SatMake (E : Formula_intf.S): S with type term = E.t and type formula = E.t and type proof = E.proof (** Functor to instantiate the types of clauses for a solver. *) diff --git a/src/core/Solver_types_intf.ml b/src/core/Solver_types_intf.ml index 08f011cc..b311f4ea 100644 --- a/src/core/Solver_types_intf.ml +++ b/src/core/Solver_types_intf.ml @@ -32,6 +32,12 @@ module type S = sig val mcsat : bool (** TODO:deprecate. *) + type t + (** State for creating new terms, literals, clauses *) + + (* TODO: add size hint *) + val create: unit -> t + (** {2 Type definitions} *) type term @@ -138,17 +144,19 @@ module type S = sig | E_var of var (**) (** Either a lit of a var *) - val nb_elt : unit -> int - val get_elt : int -> elt - val iter_elt : (elt -> unit) -> unit + val nb_elt : t -> int + val get_elt : t -> int -> elt + val iter_elt : t -> (elt -> unit) -> unit (** Read access to the vector of variables created *) (** {2 Variables, Literals & Clauses } *) + type state = t + module Lit : sig type t = lit val term : t -> term - val make : term -> t + val make : state -> term -> t (** Returns the variable associated with the term *) val level : t -> int @@ -167,7 +175,6 @@ module type S = sig type t = var val dummy : t - val pos : t -> atom val neg : t -> atom @@ -180,7 +187,7 @@ module type S = sig val weight : t -> float val set_weight : t -> float -> unit - val make : formula -> t * Formula_intf.negated + val make : state -> formula -> t * Formula_intf.negated (** Returns the variable linked with the given formula, and whether the atom associated with the formula is [var.pa] or [var.na] *) @@ -207,7 +214,7 @@ module type S = sig val is_true : t -> bool val is_false : t -> bool - val make : formula -> t + val make : state -> formula -> t (** Returns the atom associated with the given formula *) val mark : t -> unit @@ -274,5 +281,19 @@ module type S = sig (** Constructors and destructors *) val debug : t printer end + + module Term : sig + type t = term + val equal : t -> t -> bool + val hash : t -> int + val pp : t printer + end + + module Formula : sig + type t = formula + val equal : t -> t -> bool + val hash : t -> int + val pp : t printer + end end diff --git a/src/main/main.ml b/src/main/main.ml index 4de98a45..0d3d0301 100644 --- a/src/main/main.ml +++ b/src/main/main.ml @@ -37,19 +37,21 @@ module Make let hyps = ref [] - let check_model state = + let st = S.create() + + let check_model sat = let check_clause c = let l = List.map (function a -> Log.debugf 99 - (fun k -> k "Checking value of %a" S.St.Atom.debug (S.St.Atom.make a)); - state.Msat.eval a) c in + (fun k -> k "Checking value of %a" S.St.Formula.pp a); + sat.Msat.eval a) c in List.exists (fun x -> x) l in let l = List.map check_clause !hyps in List.for_all (fun x -> x) l - let prove ~assumptions = - let res = S.solve ~assumptions () in + let prove ~assumptions () = + let res = S.solve st ~assumptions () in let t = Sys.time () in begin match res with | S.Sat state -> @@ -78,26 +80,26 @@ module Make | Dolmen.Statement.Clause l -> let cnf = T.antecedent (Dolmen.Term.or_ l) in hyps := cnf @ !hyps; - S.assume cnf + S.assume st cnf | Dolmen.Statement.Consequent t -> let cnf = T.consequent t in hyps := cnf @ !hyps; - S.assume cnf + S.assume st cnf | Dolmen.Statement.Antecedent t -> let cnf = T.antecedent t in hyps := cnf @ !hyps; - S.assume cnf + S.assume st cnf | Dolmen.Statement.Pack [ { Dolmen.Statement.descr = Dolmen.Statement.Push 1;_ }; { Dolmen.Statement.descr = Dolmen.Statement.Antecedent f;_ }; - { Dolmen.Statement.descr = Dolmen.Statement.Prove []; }; + { Dolmen.Statement.descr = Dolmen.Statement.Prove [];_ }; { Dolmen.Statement.descr = Dolmen.Statement.Pop 1;_ }; ] -> let assumptions = T.assumptions f in - prove ~assumptions + prove ~assumptions () | Dolmen.Statement.Prove l -> - let assumptions = List.map T.assumptions l in - prove ~assumptions + let assumptions = List.map T.assumptions l |> List.flatten in + prove ~assumptions () | Dolmen.Statement.Set_info _ | Dolmen.Statement.Set_logic _ -> () | Dolmen.Statement.Exit -> exit 0 @@ -106,9 +108,9 @@ module Make Dolmen.Statement.print s end -module Sat = Make(Minismt_sat.Make(struct end))(Minismt_sat.Type) -module Smt = Make(Minismt_smt.Make(struct end))(Minismt_smt.Type) -module Mcsat = Make(Minismt_mcsat.Make(struct end))(Minismt_smt.Type) +module Sat = Make(Minismt_sat)(Minismt_sat.Type) +module Smt = Make(Minismt_smt)(Minismt_smt.Type) +module Mcsat = Make(Minismt_mcsat)(Minismt_smt.Type) let solver = ref (module Sat : S) let solver_list = [ diff --git a/src/mcsat/Minismt_mcsat.ml b/src/mcsat/Minismt_mcsat.ml index fcd928e3..a7f01e92 100644 --- a/src/mcsat/Minismt_mcsat.ml +++ b/src/mcsat/Minismt_mcsat.ml @@ -4,10 +4,10 @@ Copyright 2014 Guillaume Bury Copyright 2014 Simon Cruanes *) -module Make() = +include Minismt.Mcsolver.Make(struct type proof = unit module Term = Minismt_smt.Expr.Term module Formula = Minismt_smt.Expr.Atom - end)(Plugin_mcsat)() + end)(Plugin_mcsat) diff --git a/src/mcsat/Minismt_mcsat.mli b/src/mcsat/Minismt_mcsat.mli index 254699a6..1499f361 100644 --- a/src/mcsat/Minismt_mcsat.mli +++ b/src/mcsat/Minismt_mcsat.mli @@ -4,5 +4,5 @@ Copyright 2014 Guillaume Bury Copyright 2014 Simon Cruanes *) -module Make() : Minismt.Solver.S with type St.formula = Minismt_smt.Expr.atom +include Minismt.Solver.S with type St.formula = Minismt_smt.Expr.atom diff --git a/src/sat/Minismt_sat.ml b/src/sat/Minismt_sat.ml index d1e8b933..cbb2b082 100644 --- a/src/sat/Minismt_sat.ml +++ b/src/sat/Minismt_sat.ml @@ -6,6 +6,5 @@ Copyright 2016 Guillaume Bury module Expr = Expr_sat module Type = Type_sat -module Make() = - Minismt.Solver.Make(Expr)(Minismt.Solver.DummyTheory(Expr))() +include Minismt.Solver.Make(Expr)(Minismt.Solver.DummyTheory(Expr)) diff --git a/src/sat/Minismt_sat.mli b/src/sat/Minismt_sat.mli index 8c34b184..ffc5034f 100644 --- a/src/sat/Minismt_sat.mli +++ b/src/sat/Minismt_sat.mli @@ -12,6 +12,6 @@ Copyright 2016 Guillaume Bury module Expr = Expr_sat module Type = Type_sat -module Make() : Minismt.Solver.S with type St.formula = Expr.t +include Minismt.Solver.S with type St.formula = Expr.t (** A functor that can generate as many solvers as needed. *) diff --git a/src/smt/Minismt_smt.ml b/src/smt/Minismt_smt.ml index 6d25bf9b..473b12de 100644 --- a/src/smt/Minismt_smt.ml +++ b/src/smt/Minismt_smt.ml @@ -9,5 +9,5 @@ module Type = Type_smt module Th = Minismt.Solver.DummyTheory(Expr.Atom) -module Make() = Minismt.Solver.Make(Expr.Atom)(Th)() +include Minismt.Solver.Make(Expr.Atom)(Th) diff --git a/src/smt/Minismt_smt.mli b/src/smt/Minismt_smt.mli index d2e99f39..4850929d 100644 --- a/src/smt/Minismt_smt.mli +++ b/src/smt/Minismt_smt.mli @@ -7,5 +7,5 @@ Copyright 2014 Simon Cruanes module Expr = Expr_smt module Type = Type_smt -module Make() : Minismt.Solver.S with type St.formula = Expr_smt.atom +include Minismt.Solver.S with type St.formula = Expr_smt.atom diff --git a/src/solver/mcsolver.ml b/src/solver/mcsolver.ml index a81bf7ae..f6eb0c68 100644 --- a/src/solver/mcsolver.ml +++ b/src/solver/mcsolver.ml @@ -10,10 +10,6 @@ module Make (E : Expr_intf.S) (Th : Plugin_intf.S with type term = E.Term.t and type formula = E.Formula.t and type proof = E.proof) - () = - Msat.Make - (Make_mcsat_expr(E)()) - (Th) - () + = Msat.Make (Make_mcsat_expr(E)) (Th) diff --git a/src/solver/mcsolver.mli b/src/solver/mcsolver.mli index 7ae92fad..04727d35 100644 --- a/src/solver/mcsolver.mli +++ b/src/solver/mcsolver.mli @@ -16,9 +16,8 @@ module Make (E : Expr_intf.S) (Th : Plugin_intf.S with type term = E.Term.t and type formula = E.Formula.t and type proof = E.proof) - () : - S with type St.term = E.Term.t - and type St.formula = E.Formula.t - and type St.proof = E.proof + : S with type St.term = E.Term.t + and type St.formula = E.Formula.t + and type St.proof = E.proof (** Functor to create a solver parametrised by the atomic formulas and a theory. *) diff --git a/src/solver/solver.ml b/src/solver/solver.ml index 8ad787c5..d06de2e8 100644 --- a/src/solver/solver.ml +++ b/src/solver/solver.ml @@ -76,10 +76,6 @@ end module Make (E : Formula_intf.S) (Th : Theory_intf.S with type formula = E.t and type proof = E.proof) - () = - Msat.Make - (Make_smt_expr(E)(struct end)) - (Plugin(E)(Th)) - () + = Msat.Make (Make_smt_expr(E)) (Plugin(E)(Th)) diff --git a/src/solver/solver.mli b/src/solver/solver.mli index 37840eb7..fef0160e 100644 --- a/src/solver/solver.mli +++ b/src/solver/solver.mli @@ -23,9 +23,8 @@ module DummyTheory(F : Formula_intf.S) : module Make (F : Formula_intf.S) (Th : Theory_intf.S with type formula = F.t and type proof = F.proof) - () : - S with type St.formula = F.t - and type St.proof = F.proof + : S with type St.formula = F.t + and type St.proof = F.proof (** Functor to create a SMT Solver parametrised by the atomic formulas and a theory. *) diff --git a/tests/test_api.ml b/tests/test_api.ml index 69c1e526..838ceed0 100644 --- a/tests/test_api.ml +++ b/tests/test_api.ml @@ -40,14 +40,18 @@ type solver_res = exception Incorrect_model module type BASIC_SOLVER = sig - val solve : ?assumptions:F.t list -> unit -> solver_res - val assume : ?tag:int -> F.t list list -> unit + type t + val create : unit -> t + val solve : t -> ?assumptions:F.t list -> unit -> solver_res + val assume : t -> ?tag:int -> F.t list list -> unit end let mk_solver (): (module BASIC_SOLVER) = let module S = struct - include Minismt_sat.Make(struct end) - let solve ?assumptions ()= match solve ?assumptions() with + include Minismt_sat + let create() = create() + let solve st ?assumptions () = + match solve st ?assumptions() with | Sat _ -> R_sat | Unsat us -> @@ -86,13 +90,14 @@ module Test = struct let run (t:t): result = (* Interesting stuff happening *) let (module S: BASIC_SOLVER) = mk_solver () in + let st = S.create() in try List.iter (function | A_assume cs -> - S.assume cs + S.assume st cs | A_solve (assumptions, expect) -> - match S.solve ~assumptions (), expect with + match S.solve st ~assumptions (), expect with | R_sat, `Expect_sat | R_unsat, `Expect_unsat -> () | R_unsat, `Expect_sat ->