refactor: use 1st class for theory actions

This commit is contained in:
Simon Cruanes 2018-05-25 20:20:43 -05:00
parent edeb28c8ad
commit 47ddce5960
8 changed files with 153 additions and 115 deletions

View file

@ -1345,23 +1345,33 @@ module Make (Th : Theory_intf.S) = struct
(Util.pp_list Atom.debug) l
)
let current_slice st head = Theory_intf.Slice_acts {
slice_iter = slice_iter st head (Vec.size st.trail);
}
let current_slice st head : formula Theory_intf.slice_actions =
let module A = struct
type form = formula
let slice_iter = slice_iter st head (Vec.size st.trail)
end in
(module A)
(* full slice, for [if_sat] final check *)
let full_slice st = Theory_intf.Slice_acts {
slice_iter = slice_iter st 0 (Vec.size st.trail);
}
let full_slice st : formula Theory_intf.slice_actions =
let module A = struct
type form = formula
let slice_iter = slice_iter st 0 (Vec.size st.trail)
end in
(module A)
let act_at_level_0 st () = at_level_0 st
let actions st = Theory_intf.Actions {
push_persistent = act_push_persistent st;
push_local = act_push_local st;
on_backtrack = on_backtrack st;
propagate = act_propagate st;
}
let actions st: (formula,lemma) Theory_intf.actions =
let module A = struct
type nonrec formula = formula
type proof = lemma
let push_persistent = act_push_persistent st
let push_local = act_push_local st
let on_backtrack = on_backtrack st
let propagate = act_propagate st
end in
(module A)
let create ?(size=`Big) () : t =
let size_map, size_vars, size_trail, size_lvl = match size with

View file

@ -34,16 +34,8 @@ type 'clause export = 'clause Solver_intf.export = {
clauses : 'clause Vec.t;
}
type ('form, 'proof) actions = ('form,'proof) Theory_intf.actions = Actions of {
push_persistent : 'form IArray.t -> 'proof -> unit;
push_local : 'form IArray.t -> 'proof -> unit;
on_backtrack: (unit -> unit) -> unit;
propagate : 'form -> 'form list -> 'proof -> unit;
}
type ('form, 'proof) slice_actions = ('form, 'proof) Theory_intf.slice_actions = Slice_acts of {
slice_iter : ('form -> unit) -> unit;
}
type ('form, 'proof) actions = ('form,'proof) Theory_intf.actions
type 'form slice_actions = 'form Theory_intf.slice_actions
module Make = Solver.Make

View file

@ -39,30 +39,40 @@ type ('formula, 'proof) res =
theory tautology (with its proof), for which every literal is false
under the current assumptions. *)
(** Actions given to the theory during initialization, to be used
at any time *)
type ('form, 'proof) actions = Actions of {
push_persistent : 'form IArray.t -> 'proof -> unit;
module type ACTIONS = sig
type formula
type proof
val push_persistent : formula IArray.t -> proof -> unit
(** Allows to add a persistent clause to the solver. *)
push_local : 'form IArray.t -> 'proof -> unit;
val push_local : formula IArray.t -> proof -> unit
(** Allows to add a local clause to the solver. The clause
will be removed after backtracking. *)
on_backtrack: (unit -> unit) -> unit;
val on_backtrack: (unit -> unit) -> unit
(** [on_backtrack f] calls [f] when the main solver backtracks *)
propagate : 'form -> 'form list -> 'proof -> unit;
val propagate : formula -> formula list -> proof -> unit
(** [propagate lit causes proof] informs the solver to propagate [lit], with the reason
that the clause [causes => lit] is a theory tautology. It is faster than pushing
the associated clause but the clause will not be remembered by the sat solver,
i.e it will not be used by the solver to do boolean propagation. *)
}
end
type ('form, 'proof) slice_actions = Slice_acts of {
slice_iter : ('form -> unit) -> unit;
(** Actions given to the theory during initialization, to be used
at any time *)
type ('form, 'proof) actions =
(module ACTIONS with type formula = 'form and type proof = 'proof)
module type SLICE_ACTIONS = sig
type form
val slice_iter : (form -> unit) -> unit
(** iterate on the slice of the trail *)
}
end
type 'form slice_actions = (module SLICE_ACTIONS with type form = 'form)
(** The type for a slice. Slices are some kind of view of the current
propagation queue. They allow to look at the propagated literals,
and to add new clauses to the solver. *)
@ -110,11 +120,11 @@ module type S = sig
val create : (formula, proof) actions -> t
(** Create a new instance of the theory *)
val assume : t -> (formula, proof) slice_actions -> (formula, proof) res
val assume : t -> formula slice_actions -> (formula, proof) res
(** Assume the formulas in the slice, possibly pushing new formulas to be propagated,
and returns the result of the new assumptions. *)
val if_sat : t -> (formula, proof) slice_actions -> (formula, proof) res
val if_sat : t -> formula slice_actions -> (formula, proof) res
(** Called at the end of the search in case a model has been found. If no new clause is
pushed, then 'sat' is returned, else search is resumed. *)
end

View file

@ -19,19 +19,21 @@ module Sig_tbl = CCHashtbl.Make(Signature)
type merge_op = node * node * explanation
(* a merge operation to perform *)
type actions = {
on_backtrack:(unit -> unit) -> unit;
module type ACTIONS = sig
val on_backtrack: (unit -> unit) -> unit
(** Register a callback to be invoked upon backtracking below the current level *)
on_merge:repr -> repr -> explanation -> unit;
val on_merge: repr -> repr -> explanation -> unit
(** Call this when two classes are merged *)
raise_conflict: 'a. conflict -> 'a;
val raise_conflict: conflict -> 'a
(** Report a conflict *)
propagate: Lit.t -> Lit.t list -> unit;
val propagate: Lit.t -> Lit.t list -> unit
(** Propagate a literal *)
}
end
type actions = (module ACTIONS)
type t = {
tst: Term.state;
@ -63,7 +65,9 @@ type t = {
several times.
See "fast congruence closure and extensions", Nieuwenhis&al, page 14 *)
let[@inline] on_backtrack cc f : unit = cc.acts.on_backtrack f
let[@inline] on_backtrack cc f : unit =
let (module A) = cc.acts in
A.on_backtrack f
let[@inline] is_root_ (n:node) : bool = n.n_root == n
@ -192,7 +196,8 @@ let rec reroot_expl (cc:t) (n:node): unit =
end
let[@inline] raise_conflict (cc:t) (e:conflict): _ =
cc.acts.raise_conflict e
let (module A) = cc.acts in
A.raise_conflict e
let[@inline] all_classes cc : repr Sequence.t =
Term.Tbl.values cc.tbl
@ -482,7 +487,8 @@ and update_combine cc =
Side effect: also pushes sub-tasks *)
and notify_merge cc (ra:repr) ~into:(rb:repr) (e:explanation): unit =
assert (is_root_ rb);
cc.acts.on_merge ra rb e
let (module A) = cc.acts in
A.on_merge ra rb e
(* FIXME: callback?

View file

@ -13,19 +13,21 @@ type repr = Equiv_class.t
type conflict = Theory.conflict
type actions = {
on_backtrack:(unit -> unit) -> unit;
module type ACTIONS = sig
val on_backtrack: (unit -> unit) -> unit
(** Register a callback to be invoked upon backtracking below the current level *)
on_merge:repr -> repr -> explanation -> unit;
val on_merge: repr -> repr -> explanation -> unit
(** Call this when two classes are merged *)
raise_conflict: 'a. conflict -> 'a;
val raise_conflict: conflict -> 'a
(** Report a conflict *)
propagate: Lit.t -> Lit.t list -> unit;
val propagate: Lit.t -> Lit.t list -> unit
(** Propagate a literal *)
}
end
type actions = (module ACTIONS)
val create :
?size:int ->

View file

@ -14,58 +14,65 @@ end = struct
)
end
(** Runtime state of a theory, with all the operations it provides.
['a] is the internal state *)
type state = State : {
mutable st: 'a;
on_merge: 'a -> Equiv_class.t -> Equiv_class.t -> Explanation.t -> unit;
module type STATE = sig
type t
val state : t
val on_merge: t -> Equiv_class.t -> Equiv_class.t -> Explanation.t -> unit
(** Called when two classes are merged *)
on_assert: 'a -> Lit.t -> unit;
val on_assert: t -> Lit.t -> unit
(** Called when a literal becomes true *)
final_check: 'a -> Lit.t Sequence.t -> unit;
val final_check: t -> Lit.t Sequence.t -> unit
(** Final check, must be complete (i.e. must raise a conflict
if the set of literals is not satisfiable) *)
} -> state
end
(** Runtime state of a theory, with all the operations it provides. *)
type state = (module STATE)
(** Unsatisfiable conjunction.
Its negation will become a conflict clause *)
type conflict = Lit.t list
(** Actions available to a theory during its lifetime *)
type actions = {
on_backtrack: (unit -> unit) -> unit;
module type ACTIONS = sig
val on_backtrack: (unit -> unit) -> unit
(** Register an action to do when we backtrack *)
raise_conflict: 'a. conflict -> 'a;
val raise_conflict: conflict -> 'a
(** Give a conflict clause to the solver *)
propagate_eq: Term.t -> Term.t -> Lit.t list -> unit;
val propagate_eq: Term.t -> Term.t -> Lit.t list -> unit
(** Propagate an equality [t = u] because [e] *)
propagate_distinct: Term.t list -> neq:Term.t -> Lit.t -> unit;
val propagate_distinct: Term.t list -> neq:Term.t -> Lit.t -> unit
(** Propagate a [distinct l] because [e] (where [e = neq] *)
propagate: Lit.t -> Lit.t list -> unit;
val propagate: Lit.t -> Lit.t list -> unit
(** Propagate a boolean using a unit clause.
[expl => lit] must be a theory lemma, that is, a T-tautology *)
add_local_axiom: Lit.t IArray.t -> unit;
val add_local_axiom: Lit.t IArray.t -> unit
(** Add local clause to the SAT solver. This clause will be
removed when the solver backtracks. *)
add_persistent_axiom: Lit.t IArray.t -> unit;
val add_persistent_axiom: Lit.t IArray.t -> unit
(** Add toplevel clause to the SAT solver. This clause will
not be backtracked. *)
find: Term.t -> Equiv_class.t;
val find: Term.t -> Equiv_class.t
(** Find representative of this term *)
all_classes: Equiv_class.t Sequence.t;
val all_classes: Equiv_class.t Sequence.t
(** All current equivalence classes
(caution: linear in the number of terms existing in the solver) *)
}
end
type actions = (module ACTIONS)
type t = {
name: string;
@ -75,9 +82,17 @@ type t = {
let make ~name ~make () : t = {name;make}
let make_st
(type st)
?(on_merge=fun _ _ _ _ -> ())
?(on_assert=fun _ _ -> ())
~final_check
~st
() : state =
State { st; on_merge; on_assert; final_check }
let module A = struct
type nonrec t = st
let state = st
let on_merge = on_merge
let on_assert = on_assert
let final_check = final_check
end in
(module A : STATE)

View file

@ -53,7 +53,7 @@ let assume_lit (self:t) (lit:Lit.t) : unit =
| Lit_atom _ ->
(* transmit to theories. *)
Congruence_closure.assert_lit (cc self) lit;
theories self (fun (Theory.State th) -> th.on_assert th.st lit);
theories self (fun (module Th) -> Th.on_assert Th.state lit);
end
(* return result to the SAT solver *)
@ -84,13 +84,13 @@ let with_conflict_catch self f =
cdcl_return_res self
(* propagation from the bool solver *)
let assume_real (self:t) (slice:_ Sat_solver.slice_actions) =
let assume_real (self:t) (slice:Lit.t Sat_solver.slice_actions) =
(* TODO if Config.progress then print_progress(); *)
let Sat_solver.Slice_acts slice = slice in
Log.debugf 5 (fun k->k "(th_combine.assume :len %d)" (Sequence.length slice.slice_iter));
let (module SA) = slice in
Log.debugf 5 (fun k->k "(th_combine.assume :len %d)" (Sequence.length @@ SA.slice_iter));
with_conflict_catch self
(fun () ->
slice.slice_iter (assume_lit self);
SA.slice_iter (assume_lit self);
(* now check satisfiability *)
check self
)
@ -104,28 +104,28 @@ let assume (self:t) (slice:_ Sat_solver.slice_actions) =
cdcl_return_res self
(* perform final check of the model *)
let if_sat (self:t) (slice:_) : _ Sat_solver.res =
let if_sat (self:t) (slice:Lit.t Sat_solver.slice_actions) : _ Sat_solver.res =
Congruence_closure.final_check (cc self);
(* all formulas in the SAT solver's trail *)
let forms =
let Sat_solver.Slice_acts r = slice in
r.slice_iter
let (module SA) = slice in
SA.slice_iter
in
(* final check for each theory *)
with_conflict_catch self
(fun () ->
theories self
(fun (Theory.State th) -> th.final_check th.st forms))
(fun (module Th) -> Th.final_check Th.state forms))
(** {2 Various helpers} *)
(* forward propagations from CC or theories directly to the SMT core *)
let act_propagate (self:t) f guard : unit =
let Sat_solver.Actions r = self.cdcl_acts in
let (module A) = self.cdcl_acts in
Sat_solver.Log.debugf 2
(fun k->k "(@[@{<green>propagate@}@ %a@ :guard %a@])"
Lit.pp f (Util.pp_list Lit.pp) guard);
r.propagate f guard Proof.default
A.propagate f guard Proof.default
(** {2 Interface to Congruence Closure} *)
@ -134,16 +134,17 @@ let act_raise_conflict e = raise (Exn_conflict e)
(* when CC decided to merge [r1] and [r2], notify theories *)
let on_merge_from_cc (self:t) r1 r2 e : unit =
theories self
(fun (Theory.State th) -> th.on_merge th.st r1 r2 e)
(fun (module Th) -> Th.on_merge Th.state r1 r2 e)
let mk_cc_actions (self:t) : Congruence_closure.actions =
let Sat_solver.Actions r = self.cdcl_acts in
{ Congruence_closure.
on_backtrack = r.on_backtrack;
on_merge = on_merge_from_cc self;
raise_conflict = act_raise_conflict;
propagate = act_propagate self;
}
let (module A) = self.cdcl_acts in
let module R = struct
let on_backtrack = A.on_backtrack
let on_merge = on_merge_from_cc self
let raise_conflict = act_raise_conflict
let propagate = act_propagate self
end in
(module R)
(** {2 Main} *)
@ -180,29 +181,30 @@ let act_find self t =
let act_add_local_axiom self c : unit =
Sat_solver.Log.debugf 5 (fun k->k "(@[<2>th_combine.push_local_lemma@ %a@])" Theory.Clause.pp c);
let Sat_solver.Actions r = self.cdcl_acts in
r.push_local c Proof.default
let (module A) = self.cdcl_acts in
A.push_local c Proof.default
(* push one clause into [M], in the current level (not a lemma but
an axiom) *)
let act_add_persistent_axiom self c : unit =
Sat_solver.Log.debugf 5 (fun k->k "(@[<2>th_combine.push_persistent_lemma@ %a@])" Theory.Clause.pp c);
let Sat_solver.Actions r = self.cdcl_acts in
r.push_persistent c Proof.default
let (module A) = self.cdcl_acts in
A.push_persistent c Proof.default
let mk_theory_actions (self:t) : Theory.actions =
let Sat_solver.Actions r = self.cdcl_acts in
{ Theory.
on_backtrack = r.on_backtrack;
raise_conflict = act_raise_conflict;
propagate = act_propagate self;
all_classes = act_all_classes self;
propagate_eq = act_propagate_eq self;
propagate_distinct = act_propagate_distinct self;
add_local_axiom = act_add_local_axiom self;
add_persistent_axiom = act_add_persistent_axiom self;
find = act_find self;
}
let (module A) = self.cdcl_acts in
let module R = struct
let on_backtrack = A.on_backtrack
let raise_conflict = act_raise_conflict
let propagate = act_propagate self
let all_classes = act_all_classes self
let propagate_eq = act_propagate_eq self
let propagate_distinct = act_propagate_distinct self
let add_local_axiom = act_add_local_axiom self
let add_persistent_axiom = act_add_persistent_axiom self
let find = act_find self
end
in (module R)
let add_theory (self:t) (th:Theory.t) : unit =
Sat_solver.Log.debugf 2

View file

@ -251,17 +251,18 @@ type t = {
let tseitin (self:t) (lit:Lit.t) (lit_t:term) (b:term builtin) : unit =
Log.debugf 5 (fun k->k "(@[th_bool.tseitin@ %a@])" Lit.pp lit);
let (module A) = self.acts in
match b with
| B_not _ -> assert false (* normalized *)
| B_eq (t,u) ->
if Lit.sign lit then (
self.acts.Theory.propagate_eq t u [lit]
A.propagate_eq t u [lit]
) else (
self.acts.Theory.propagate_distinct [t;u] ~neq:lit_t lit
A.propagate_distinct [t;u] ~neq:lit_t lit
)
| B_distinct l ->
if Lit.sign lit then (
self.acts.Theory.propagate_distinct l ~neq:lit_t lit
A.propagate_distinct l ~neq:lit_t lit
) else (
(* TODO: propagate pairwise equalities? *)
Error.errorf "cannot process negative distinct lit %a" Lit.pp lit;
@ -272,39 +273,39 @@ let tseitin (self:t) (lit:Lit.t) (lit_t:term) (b:term builtin) : unit =
List.iter
(fun sub ->
let sublit = Lit.atom sub in
self.acts.Theory.propagate sublit [lit])
A.propagate sublit [lit])
subs
) else (
(* propagate [¬lit => _i ¬ subs_i] *)
let c = Lit.neg lit :: List.map (Lit.atom ~sign:false) subs in
self.acts.Theory.add_local_axiom (IArray.of_list c)
A.add_local_axiom (IArray.of_list c)
)
| B_or subs ->
if Lit.sign lit then (
(* propagate [lit => _i subs_i] *)
let c = Lit.neg lit :: List.map (Lit.atom ~sign:true) subs in
self.acts.Theory.add_local_axiom (IArray.of_list c)
A.add_local_axiom (IArray.of_list c)
) else (
(* propagate [¬lit => ¬subs_i] *)
List.iter
(fun sub ->
let sublit = Lit.atom ~sign:false sub in
self.acts.Theory.propagate sublit [lit])
A.propagate sublit [lit])
subs
)
| B_imply (guard,concl) ->
if Lit.sign lit then (
(* propagate [lit => _i ¬guard_i concl] *)
let c = Lit.atom concl :: Lit.neg lit :: List.map (Lit.atom ~sign:false) guard in
self.acts.Theory.add_local_axiom (IArray.of_list c)
A.add_local_axiom (IArray.of_list c)
) else (
(* propagate [¬lit => ¬concl] *)
self.acts.Theory.propagate (Lit.atom ~sign:false concl) [lit];
A.propagate (Lit.atom ~sign:false concl) [lit];
(* propagate [¬lit => ∧_i guard_i] *)
List.iter
(fun sub ->
let sublit = Lit.atom ~sign:true sub in
self.acts.Theory.propagate sublit [lit])
A.propagate sublit [lit])
guard
)