refactor(preprocess): introduce Find_foreign, runs after preprocess

This commit is contained in:
Simon Cruanes 2022-09-10 14:10:36 -04:00
parent 3d95fc16c4
commit 721ed2eac0
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
11 changed files with 173 additions and 60 deletions

View file

@ -15,6 +15,7 @@ module Solver = Solver
module Theory = Theory
module Theory_id = Theory_id
module Preprocess = Preprocess
module Find_foreign = Find_foreign
type theory = Theory.t
type solver = Solver.t

37
src/smt/find_foreign.ml Normal file
View file

@ -0,0 +1,37 @@
open Sidekick_core
module type ACTIONS = sig
val declare_need_th_combination : Term.t -> unit
(** Declare that this term is a foreign variable in some other subterm. *)
val add_lit_for_bool_term : ?default_pol:bool -> Term.t -> unit
(** Add the (boolean) term to the SAT solver *)
end
type actions = (module ACTIONS)
type hook = actions -> is_sub:bool -> Term.t -> unit
type t = { seen: unit Term.Tbl.t; mutable hooks: hook list }
let create () : t = { hooks = []; seen = Term.Tbl.create 8 }
let add_hook self h = self.hooks <- h :: self.hooks
let traverse_term (self : t) ((module A) as acts : actions) (t : Term.t) : unit
=
let rec loop ~is_sub t =
if (not (Term.is_a_type t)) && not (Term.Tbl.mem self.seen t) then (
Term.Tbl.add self.seen t ();
Log.debugf 10 (fun k -> k "(@[find-foreign-in@ %a@])" Term.pp t);
(* boolean subterm: need a literal *)
if Term.is_bool (Term.ty t) then A.add_lit_for_bool_term t;
(* call hooks *)
List.iter (fun (h : hook) -> h acts ~is_sub t) self.hooks;
match Term.open_eq t with
| Some (_, _) when not is_sub ->
Term.iter_shallow t ~f:(fun _ u -> loop ~is_sub:false u)
| _ -> Term.iter_shallow t ~f:(fun _ u -> loop ~is_sub:true u)
)
in
loop ~is_sub:false t

28
src/smt/find_foreign.mli Normal file
View file

@ -0,0 +1,28 @@
(** Find foreign variables.
This module is a modular discoverer of foreign variables (and boolean terms).
It should run after preprocessing of terms.
*)
open Sidekick_core
module type ACTIONS = sig
val declare_need_th_combination : Term.t -> unit
(** Declare that this term is a foreign variable in some other subterm. *)
val add_lit_for_bool_term : ?default_pol:bool -> Term.t -> unit
(** Add the (boolean) term to the SAT solver *)
end
type actions = (module ACTIONS)
type t
type hook = actions -> is_sub:bool -> Term.t -> unit
val create : unit -> t
val add_hook : t -> hook -> unit
(** Register a hook to detect foreign subterms *)
val traverse_term : t -> actions -> Term.t -> unit
(** Traverse subterms of this term to detect foreign variables
and boolean subterms. *)

View file

@ -1,4 +1,3 @@
open Sidekick_core
open Sigs
module T = Term

View file

@ -5,7 +5,6 @@ module type PREPROCESS_ACTS = sig
val mk_lit : ?sign:bool -> term -> lit
val add_clause : lit list -> step_id -> unit
val add_lit : ?default_pol:bool -> lit -> unit
val declare_need_th_combination : term -> unit
end
type preprocess_actions = (module PREPROCESS_ACTS)
@ -49,7 +48,7 @@ let preprocess_term_ (self : t) ((module A : PREPROCESS_ACTS) as acts)
match Term.Tbl.find_opt self.preprocessed t0 with
| Some u -> u
| None ->
Log.debugf 50 (fun k -> k "(@[smt.preprocess@ %a@])" Term.pp_debug t0);
Log.debugf 50 (fun k -> k "(@[smt.preprocess@ %a@])" Term.pp t0);
(* try hooks first *)
let t =
@ -82,21 +81,6 @@ let preprocess_term_ (self : t) ((module A : PREPROCESS_ACTS) as acts)
in
Term.Tbl.add self.preprocessed t0 t;
(* signal boolean subterms, so as to decide them
in the SAT solver *)
if Term.is_bool (Term.ty t) then (
Log.debugf 5 (fun k ->
k "(@[solver.map-bool-subterm-to-lit@ :subterm %a@])" Term.pp t);
(* ensure that SAT solver has a boolean atom for [t] *)
let lit = Lit.atom self.tst t in
A.add_lit lit;
(* also map [sub] to this atom in the congruence closure, for propagation *)
(* FIXME: use a delayed action "DA_declare_cc_lit (t, lit)" instead *)
CC.set_as_lit self.cc (CC.add_term self.cc t) lit
);
t
in
preproc_rec_ ~is_sub:false t
@ -114,8 +98,7 @@ let simplify_and_preproc_lit (self : t) acts (lit : Lit.t) :
| None -> t, None
| Some (u, pr_t_u) ->
Log.debugf 30 (fun k ->
k "(@[smt-solver.simplify@ :t %a@ :into %a@])" Term.pp_debug t
Term.pp_debug u);
k "(@[smt-solver.simplify@ :t %a@ :into %a@])" Term.pp t Term.pp u);
u, Some pr_t_u
in
let v = preprocess_term_ self acts u in

View file

@ -32,8 +32,6 @@ module type PREPROCESS_ACTS = sig
val add_lit : ?default_pol:bool -> lit -> unit
(** Ensure the literal will be decided/handled by the SAT solver. *)
val declare_need_th_combination : term -> unit
end
type preprocess_actions = (module PREPROCESS_ACTS)

View file

@ -1,5 +1,4 @@
open Sigs
module Ty = Term
type th_states =
| Ths_nil
@ -27,6 +26,7 @@ module Registry = Registry
type delayed_action =
| DA_add_clause of { c: lit list; pr: step_id; keep: bool }
| DA_add_lit of { default_pol: bool option; lit: lit }
| DA_add_preprocessed_lit of { default_pol: bool option; lit: lit }
type preprocess_hook =
Preprocess.t ->
@ -48,6 +48,7 @@ type t = {
mutable on_partial_check: (t -> theory_actions -> lit Iter.t -> unit) list;
mutable on_final_check: (t -> theory_actions -> lit Iter.t -> unit) list;
preprocess: Preprocess.t;
find_foreign: Find_foreign.t;
mutable model_ask: model_ask_hook list;
mutable model_complete: model_completion_hook list;
simp: Simplify.t;
@ -83,6 +84,7 @@ let[@inline] has_delayed_actions self =
not (Queue.is_empty self.delayed_actions)
let on_preprocess self f = Preprocess.on_preprocess self.preprocess f
let on_find_foreign self f = Find_foreign.add_hook self.find_foreign f
let on_model ?ask ?complete self =
Option.iter (fun f -> self.model_ask <- f :: self.model_ask) ask;
@ -118,6 +120,9 @@ let add_sat_lit_ _self ?default_pol (acts : theory_actions) (lit : Lit.t) : unit
let delayed_add_lit (self : t) ?default_pol (lit : Lit.t) : unit =
Queue.push (DA_add_lit { default_pol; lit }) self.delayed_actions
let delayed_add_preprocessed_lit (self : t) ?default_pol (lit : Lit.t) : unit =
Queue.push (DA_add_preprocessed_lit { default_pol; lit }) self.delayed_actions
let delayed_add_clause (self : t) ~keep (c : Lit.t list) (pr : step_id) : unit =
Queue.push (DA_add_clause { c; pr; keep }) self.delayed_actions
@ -127,11 +132,34 @@ let preprocess_acts (self : t) : Preprocess.preprocess_actions =
let mk_lit ?sign t : Lit.t = Lit.atom ?sign self.tst t
let add_clause c pr = delayed_add_clause self ~keep:true c pr
let add_lit ?default_pol lit = delayed_add_lit self ?default_pol lit
end)
let find_foreign_acts (self : t) : Find_foreign.actions =
(module struct
let add_lit_for_bool_term ?default_pol t =
let lit = Lit.atom self.tst t in
(* [lit] has already been preprocessed, do not preprocess it
again lest we meet an infinite recursion *)
delayed_add_preprocessed_lit self ?default_pol lit
let declare_need_th_combination t =
Th_combination.add_term_needing_combination self.th_comb t
end)
(* find boolean subterms/foreign variables in [t] *)
let find_foreign_vars_in (self : t) (t : Term.t) : unit =
let acts = find_foreign_acts self in
Find_foreign.traverse_term self.find_foreign acts t
let find_foreign_vars_in_lit (self : t) (lit : Lit.t) =
find_foreign_vars_in self (Lit.term lit)
let find_foreign_vars_in_lits (self : t) (c : Lit.t list) =
List.iter (find_foreign_vars_in_lit self) c
let find_foreign_vars_in_lit_arr (self : t) (c : Lit.t array) =
Array.iter (find_foreign_vars_in_lit self) c
let push_decision (self : t) (acts : theory_actions) (lit : lit) : unit =
let (module A) = acts in
(* make sure the literal is preprocessed *)
@ -139,6 +167,7 @@ let push_decision (self : t) (acts : theory_actions) (lit : lit) : unit =
Preprocess.simplify_and_preproc_lit self.preprocess (preprocess_acts self)
lit
in
find_foreign_vars_in_lit self lit;
let sign = Lit.sign lit in
A.add_decision_lit (Lit.abs lit) sign
@ -160,12 +189,22 @@ module Perform_delayed (A : PERFORM_ACTS) = struct
Preprocess.preprocess_clause self.preprocess (preprocess_acts self) c
pr_c
in
find_foreign_vars_in_lits self c';
A.add_clause self acts ~keep c' pr_c'
| DA_add_lit { default_pol; lit } ->
let lit, _ =
Preprocess.simplify_and_preproc_lit self.preprocess
(preprocess_acts self) lit
in
let t = Lit.term lit in
find_foreign_vars_in_lit self lit;
CC.set_as_lit self.cc (CC.add_term self.cc t) lit;
A.add_lit self acts ?default_pol lit
| DA_add_preprocessed_lit { default_pol; lit } ->
let t = Lit.term lit in
Log.debugf 5 (fun k ->
k "(@[solver.map-bool-subterm-to-lit@ :subterm %a@])" Term.pp t);
CC.set_as_lit self.cc (CC.add_term self.cc t) lit;
A.add_lit self acts ?default_pol lit
done
end
@ -182,26 +221,43 @@ module Perform_delayed_th = Perform_delayed (struct
end)
let[@inline] preprocess self = self.preprocess
let[@inline] find_foreign self = self.find_foreign
let preprocess_clause self c pr =
Preprocess.preprocess_clause self.preprocess (preprocess_acts self) c pr
let c, pr =
Preprocess.preprocess_clause self.preprocess (preprocess_acts self) c pr
in
find_foreign_vars_in_lits self c;
c, pr
let preprocess_clause_array self c pr =
Preprocess.preprocess_clause_array self.preprocess (preprocess_acts self) c pr
let c, pr =
Preprocess.preprocess_clause_array self.preprocess (preprocess_acts self) c
pr
in
find_foreign_vars_in_lit_arr self c;
c, pr
let simplify_and_preproc_lit self lit =
Preprocess.simplify_and_preproc_lit self.preprocess (preprocess_acts self) lit
let lit, pr =
Preprocess.simplify_and_preproc_lit self.preprocess (preprocess_acts self)
lit
in
find_foreign_vars_in_lit self lit;
lit, pr
let[@inline] add_clause_temp self _acts c (proof : step_id) : unit =
let add_clause_temp self _acts c (proof : step_id) : unit =
let c, proof =
Preprocess.preprocess_clause self.preprocess (preprocess_acts self) c proof
in
find_foreign_vars_in_lits self c;
delayed_add_clause self ~keep:false c proof
let[@inline] add_clause_permanent self _acts c (proof : step_id) : unit =
let add_clause_permanent self _acts c (proof : step_id) : unit =
let c, proof =
Preprocess.preprocess_clause self.preprocess (preprocess_acts self) c proof
in
find_foreign_vars_in_lits self c;
delayed_add_clause self ~keep:true c proof
let[@inline] mk_lit self ?sign t : lit = Lit.atom ?sign self.tst t
@ -221,6 +277,7 @@ let add_lit_t self _acts ?sign t =
Preprocess.simplify_and_preproc_lit self.preprocess (preprocess_acts self)
lit
in
find_foreign_vars_in_lit self lit;
delayed_add_lit self lit
let on_final_check self f = self.on_final_check <- f :: self.on_final_check
@ -388,10 +445,8 @@ let assert_lits_ ~final (self : t) (acts : theory_actions) (lits : Lit.t Iter.t)
if new_intf_eqns <> [] then (
let (module A) = acts in
List.iter (fun lit -> A.add_lit ~default_pol:false lit) new_intf_eqns
);
(* if theory combination didn't add new clauses, compute a model *)
if not (has_delayed_actions self) then (
) else if not (has_delayed_actions self) then (
(* if theory combination didn't add new clauses, compute a model *)
let m = mk_model_ self lits in
self.last_model <- Some m
)
@ -469,6 +524,7 @@ let create (module A : ARG) ~stat ~proof (tst : Term.store) () : t =
let simp = Simplify.create tst ~proof in
let cc = CC.create (module A : CC.ARG) ~size:`Big tst proof in
let preprocess = Preprocess.create ~stat ~proof ~cc ~simplify:simp tst in
let find_foreign = Find_foreign.create () in
let self =
{
tst;
@ -478,6 +534,7 @@ let create (module A : ARG) ~stat ~proof (tst : Term.store) () : t =
stat;
simp;
preprocess;
find_foreign;
last_model = None;
seen_types = Term.Weak_set.create 8;
th_comb = Th_combination.create ~stat tst;

View file

@ -94,6 +94,13 @@ val preprocess_clause_array : t -> lit array -> step_id -> lit array * step_id
val simplify_and_preproc_lit : t -> lit -> lit * step_id option
(** Simplify literal then preprocess it *)
(** {3 Finding foreign variables} *)
val find_foreign : t -> Find_foreign.t
val on_find_foreign : t -> Find_foreign.hook -> unit
(** Add a hook for finding foreign variables *)
(** {3 hooks for the theory} *)
val raise_conflict : t -> theory_actions -> lit list -> step_id -> 'a

View file

@ -25,7 +25,7 @@ let processed_ (self : t) t : bool =
| Some set -> T.Set.mem t set
let add_term_needing_combination (self : t) (t : T.t) : unit =
if not (processed_ self t) then (
if not (processed_ self t) && not (T.is_bool @@ T.ty t) then (
Log.debugf 50 (fun k -> k "(@[th.comb.add-term-needing-comb@ %a@])" T.pp t);
Vec.push self.unprocessed t
)

View file

@ -148,8 +148,7 @@ end = struct
let fresh_term self ~for_t ~pre ty =
let u = Gensym.fresh_term self.gensym ~pre ty in
Log.debugf 20 (fun k ->
k "(@[sidekick.bool.proxy@ :t %a@ :for %a@])" T.pp_debug u T.pp_debug
for_t);
k "(@[sidekick.bool.proxy@ :t %a@ :for %a@])" T.pp u T.pp for_t);
assert (Term.equal ty (T.ty u));
u
@ -160,7 +159,7 @@ end = struct
(* TODO: polarity? *)
let cnf (self : state) (_preproc : SMT.Preprocess.t) ~is_sub:_ ~recurse
(module PA : SI.PREPROCESS_ACTS) (t : T.t) : T.t option =
Log.debugf 50 (fun k -> k "(@[th-bool.cnf@ %a@])" T.pp_debug t);
Log.debugf 50 (fun k -> k "(@[th-bool.cnf@ %a@])" T.pp t);
let[@inline] mk_step_ r = Proof_trace.add_step PA.proof r in
(* handle boolean equality *)

View file

@ -29,7 +29,7 @@ module SimpVar : Linear_expr.VAR with type t = Term.t and type lit = Tag.t =
struct
type t = Term.t
let pp = Term.pp_debug
let pp = Term.pp
let compare = Term.compare
type lit = Tag.t
@ -62,7 +62,7 @@ module Make (A : ARG) = (* : S with module A = A *) struct
match A.view_as_lra t with
| LRA_other _ -> LE.monomial1 t
| LRA_pred _ ->
Error.errorf "type error: in linexp, LRA predicate %a" Term.pp_debug t
Error.errorf "type error: in linexp, LRA predicate %a" Term.pp t
| LRA_op (op, t1, t2) ->
let t1 = as_linexp t1 in
let t2 = as_linexp t2 in
@ -247,7 +247,7 @@ module Make (A : ARG) = (* : S with module A = A *) struct
self.encoded_le <- Comb_map.add le_comb proxy self.encoded_le;
Log.debugf 50 (fun k ->
k "(@[lra.encode-linexp@ `@[%a@]`@ :into-var %a@])" LE_.Comb.pp
le_comb Term.pp_debug proxy);
le_comb Term.pp proxy);
LE_.Comb.iter (fun v _ -> SimpSolver.add_var self.simplex v) le_comb;
SimpSolver.define self.simplex proxy (LE_.Comb.to_list le_comb);
@ -293,20 +293,19 @@ module Make (A : ARG) = (* : S with module A = A *) struct
Log.debugf 50 (fun k ->
k "(@[lra.encode-linexp.to-term@ `@[%a@]`@ :new-t %a@])" LE_.Comb.pp
le_comb Term.pp_debug proxy);
le_comb Term.pp proxy);
proxy, A.Q.one)
(* preprocess linear expressions away *)
let preproc_lra (self : state) _preproc ~is_sub ~recurse
let preproc_lra (self : state) _preproc ~is_sub:_ ~recurse
(module PA : SI.PREPROCESS_ACTS) (t : Term.t) : Term.t option =
Log.debugf 50 (fun k -> k "(@[lra.preprocess@ %a@])" Term.pp_debug t);
Log.debugf 50 (fun k -> k "(@[lra.preprocess@ %a@])" Term.pp t);
let tst = self.tst in
(* tell the CC this term exists *)
let declare_term_to_cc ~sub:_ t =
Log.debugf 50 (fun k ->
k "(@[lra.declare-term-to-cc@ %a@])" Term.pp_debug t);
Log.debugf 50 (fun k -> k "(@[lra.declare-term-to-cc@ %a@])" Term.pp t);
ignore (CC.add_term (SMT.Preprocess.cc _preproc) t : E_node.t)
in
@ -331,7 +330,8 @@ module Make (A : ARG) = (* : S with module A = A *) struct
(* TODO: box [t], recurse on [t1 <= t2] and [t1 >= t2],
add 3 atomic clauses, return [box t] *)
let _, t = Term.abs self.tst t in
if not (Term.Tbl.mem self.encoded_eqs t) then (
let box_t = Box.box self.tst t in
if not (Term.Tbl.mem self.encoded_eqs box_t) then (
(* preprocess t1, t2 recursively *)
let t1 = recurse t1 in
let t2 = recurse t2 in
@ -339,10 +339,10 @@ module Make (A : ARG) = (* : S with module A = A *) struct
let u1 = A.mk_lra tst (LRA_pred (Leq, t1, t2)) in
let u2 = A.mk_lra tst (LRA_pred (Geq, t1, t2)) in
Term.Tbl.add self.encoded_eqs t ();
Term.Tbl.add self.encoded_eqs box_t ();
(* encode [t <=> (u1 /\ u2)] *)
let lit_t = PA.mk_lit t in
let lit_t = PA.mk_lit box_t in
let lit_u1 = PA.mk_lit u1 in
let lit_u2 = PA.mk_lit u2 in
add_clause_lra_ (module PA) [ Lit.neg lit_t; lit_u1 ];
@ -383,14 +383,14 @@ module Make (A : ARG) = (* : S with module A = A *) struct
op
in
let lit = fresh_lit self ~mk_lit:PA.mk_lit ~pre:"$lra" in
let lit = PA.mk_lit ~sign:true box_t in
let constr = SimpSolver.Constraint.mk v op q in
SimpSolver.declare_bound self.simplex constr (Tag.Lit lit);
Term.Tbl.add self.simp_preds (Lit.term lit) (v, op, q);
Term.Tbl.add self.encoded_lits (Lit.term lit) box_t;
Term.Tbl.add self.simp_preds box_t (v, op, q);
Term.Tbl.add self.encoded_lits box_t box_t;
Log.debugf 50 (fun k ->
k "(@[lra.preproc@ :t %a@ :to-constr %a@])" Term.pp_debug t
k "(@[lra.preproc@ :t %a@ :to-constr %a@])" Term.pp t
SimpSolver.Constraint.pp constr);
Some box_t
@ -404,11 +404,15 @@ module Make (A : ARG) = (* : S with module A = A *) struct
Term.Tbl.add self.simp_defined t (box_t, le);
Some box_t)
| LRA_const _n -> None
| LRA_other t when A.has_ty_real t && is_sub ->
PA.declare_need_th_combination t;
None
| LRA_other _ -> None
let find_foreign (acts : SMT.Find_foreign.actions) ~is_sub (t : Term.t) : unit
=
if A.has_ty_real t && is_sub then (
let (module FA : SMT.Find_foreign.ACTIONS) = acts in
FA.declare_need_th_combination t
)
let simplify (self : state) (_recurse : _) (t : Term.t) :
(Term.t * Proof_step.id Iter.t) option =
let proof_eq t u =
@ -526,7 +530,7 @@ module Make (A : ARG) = (* : S with module A = A *) struct
let add_local_eq_t (self : state) si acts t1 t2 ~tag : unit =
Log.debugf 20 (fun k ->
k "(@[lra.add-local-eq@ %a@ %a@])" Term.pp_debug t1 Term.pp_debug t2);
k "(@[lra.add-local-eq@ %a@ %a@])" Term.pp t1 Term.pp t2);
reset_res_ self;
let t1, t2 =
if Term.compare t1 t2 > 0 then
@ -618,8 +622,8 @@ module Make (A : ARG) = (* : S with module A = A *) struct
match Term.Tbl.get self.simp_preds lit_t, A.view_as_lra lit_t with
| Some (v, op, q), _ ->
Log.debugf 50 (fun k ->
k "(@[lra.partial-check.add@ :lit %a@ :lit-t %a@])" Lit.pp lit
Term.pp_debug lit_t);
k "(@[lra.partial-check.add@ :lit %a@ :lit-t %a@])" Lit.pp lit Term.pp
lit_t);
(* need to account for the literal's sign *)
let op =
@ -686,7 +690,7 @@ module Make (A : ARG) = (* : S with module A = A *) struct
let res =
match self.last_res with
| Some (SimpSolver.Sat m) ->
Log.debugf 50 (fun k -> k "(@[lra.model-ask@ %a@])" Term.pp_debug t);
Log.debugf 50 (fun k -> k "(@[lra.model-ask@ %a@])" Term.pp t);
(match A.view_as_lra t with
| LRA_const n -> Some n (* always eval constants to themselves *)
| _ -> SimpSolver.V_map.get t m)
@ -709,8 +713,7 @@ module Make (A : ARG) = (* : S with module A = A *) struct
match self.last_res with
| Some (SimpSolver.Sat m) when Term.Tbl.length self.in_model > 0 ->
Log.debugf 50 (fun k ->
k "(@[lra.in_model@ %a@])"
(Util.pp_iter Term.pp_debug)
k "(@[lra.in_model@ %a@])" (Util.pp_iter Term.pp)
(Term.Tbl.keys self.in_model));
let add_t t () =
@ -729,6 +732,7 @@ module Make (A : ARG) = (* : S with module A = A *) struct
SMT.Registry.set (SI.registry si) k_state st;
SI.add_simplifier si (simplify st);
SI.on_preprocess si (preproc_lra st);
SI.on_find_foreign si find_foreign;
SI.on_final_check si (final_check_ st);
(* SI.on_partial_check si (partial_check_ st); *)
SI.on_model si ~ask:(model_ask_ st) ~complete:(model_complete_ st);