diff --git a/src/core/Internal.ml b/src/core/Internal.ml index cee6eae4..7e704403 100644 --- a/src/core/Internal.ml +++ b/src/core/Internal.ml @@ -831,7 +831,9 @@ module Make(Plugin : PLUGIN) mutable clause_incr : float; (* increment for clauses' activity *) - mutable on_conflict : (atom array -> unit); + mutable on_conflict : (atom array -> unit) option; + mutable on_decision : (atom -> unit) option; + mutable on_new_atom: (atom -> unit) option; } type solver = t @@ -868,12 +870,20 @@ module Make(Plugin : PLUGIN) var_incr = 1.; clause_incr = 1.; store_proof; - on_conflict = _nop_on_conflict; + on_conflict = None; + on_decision= None; + on_new_atom = None; } - let create ?(store_proof=true) ?(size=`Big) (th:theory) : t = + let create + ?on_conflict ?on_decision ?on_new_atom + ?(store_proof=true) ?(size=`Big) (th:theory) : t = let st = create_st ~size () in - create_ ~st ~store_proof th + let st = create_ ~st ~store_proof th in + st.on_new_atom <- on_new_atom; + st.on_decision <- on_decision; + st.on_conflict <- on_conflict; + st let[@inline] st t = t.st let[@inline] nb_clauses st = Vec.size st.clauses_hyps @@ -947,13 +957,14 @@ module Make(Plugin : PLUGIN) let make_term st t = let l = Lit.make st.st t in if l.l_level < 0 then ( - insert_elt_order st (E_lit l) + insert_elt_order st (E_lit l); ) let make_atom st (p:formula) : atom = let a = mk_atom st p in if a.var.v_level < 0 then ( insert_elt_order st (E_var a.var); + (match st.on_new_atom with Some f -> f a | None -> ()); ) else ( assert (a.is_true || a.neg.is_true); ); @@ -1923,14 +1934,16 @@ module Make(Plugin : PLUGIN) | Solver_intf.Unknown -> new_decision_level st; let current_level = decision_level st in - enqueue_bool st atom ~level:current_level Decision + enqueue_bool st atom ~level:current_level Decision; + (match st.on_decision with Some f -> f atom | None -> ()); | Solver_intf.Valued (b, l) -> let a = if b then atom else atom.neg in enqueue_semantic st a l ) else ( new_decision_level st; let current_level = decision_level st in - enqueue_bool st atom ~level:current_level Decision + enqueue_bool st atom ~level:current_level Decision; + (match st.on_decision with Some f -> f atom | None -> ()); ) and pick_branch_lit st = @@ -1986,7 +1999,7 @@ module Make(Plugin : PLUGIN) ) else ( add_clause_ st confl ); - st.on_conflict confl.atoms; + (match st.on_conflict with Some f -> f confl.atoms | None -> ()); | None -> (* No Conflict *) assert (st.elt_head = Vec.size st.trail); @@ -2063,7 +2076,7 @@ module Make(Plugin : PLUGIN) check_is_conflict_ c; Array.iter (fun a -> insert_elt_order st (Elt.of_var a.var)) c.atoms; Log.debugf info (fun k -> k "(@[sat.theory-conflict-clause@ %a@])" Clause.debug c); - st.on_conflict c.atoms; + (match st.on_conflict with Some f -> f c.atoms | None -> ()); Vec.push st.clauses_to_add c; flush_clauses st; end; @@ -2187,20 +2200,14 @@ module Make(Plugin : PLUGIN) | E_unsat (US_false c) -> st.unsat_at_0 <- Some c - let solve ?on_conflict ?(assumptions=[]) (st:t) : res = + let solve ?(assumptions=[]) (st:t) : res = cancel_until st 0; Vec.clear st.assumptions; List.iter (Vec.push st.assumptions) assumptions; - begin match on_conflict with - | None -> () - | Some f -> st.on_conflict <- f; - end; try solve_ st; - st.on_conflict <- _nop_on_conflict; Sat (mk_sat st) with E_unsat us -> - st.on_conflict <- _nop_on_conflict; Unsat (mk_unsat st us) let true_at_level0 st a = diff --git a/src/core/Solver_intf.ml b/src/core/Solver_intf.ml index a772e7c9..ab375a4d 100644 --- a/src/core/Solver_intf.ml +++ b/src/core/Solver_intf.ml @@ -444,6 +444,9 @@ module type S = sig (** Main solver type, containing all state for solving. *) val create : + ?on_conflict:(atom array -> unit) -> + ?on_decision:(atom -> unit) -> + ?on_new_atom:(atom -> unit) -> ?store_proof:bool -> ?size:[`Tiny|`Small|`Big] -> theory -> @@ -482,7 +485,6 @@ module type S = sig (** Lower level addition of clauses *) val solve : - ?on_conflict:(atom array -> unit) -> ?assumptions:atom list -> t -> res (** Try and solves the current set of clauses.