wip: new attempt at theory combination

This commit is contained in:
Simon Cruanes 2022-09-01 22:34:27 -04:00
parent 47a0b075f0
commit e74439cf2a
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
11 changed files with 112 additions and 40 deletions

View file

@ -33,7 +33,7 @@ module LRA_term = LRA_term
module Th_data = Th_data module Th_data = Th_data
module Th_bool = Th_bool module Th_bool = Th_bool
module Th_lra = Th_lra module Th_lra = Th_lra
module Th_uf = Th_uf module Th_ty_unin = Th_ty_unin
let k_th_bool_config = Th_bool.k_config let k_th_bool_config = Th_bool.k_config
let th_bool = Th_bool.theory let th_bool = Th_bool.theory
@ -41,4 +41,4 @@ let th_bool_dyn : Solver.theory = Th_bool.theory_dyn
let th_bool_static : Solver.theory = Th_bool.theory_static let th_bool_static : Solver.theory = Th_bool.theory_static
let th_data : Solver.theory = Th_data.theory let th_data : Solver.theory = Th_data.theory
let th_lra : Solver.theory = Th_lra.theory let th_lra : Solver.theory = Th_lra.theory
let th_uf : Solver.theory = Th_uf.theory let th_ty_unin : Solver.theory = Th_ty_unin.theory

View file

@ -4,5 +4,6 @@
(synopsis "Base term definitions for the standalone SMT solver and library") (synopsis "Base term definitions for the standalone SMT solver and library")
(libraries containers iter sidekick.core sidekick.util sidekick.smt-solver (libraries containers iter sidekick.core sidekick.util sidekick.smt-solver
sidekick.cc sidekick.quip sidekick.th-lra sidekick.th-bool-static sidekick.cc sidekick.quip sidekick.th-lra sidekick.th-bool-static
sidekick.th-bool-dyn sidekick.th-data sidekick.zarith zarith) sidekick.th-ty-unin sidekick.th-bool-dyn sidekick.th-data sidekick.zarith
zarith)
(flags :standard -w +32 -open Sidekick_util)) (flags :standard -w +32 -open Sidekick_util))

View file

@ -177,7 +177,7 @@ let main_smt ~config () : _ result =
Log.debugf 1 (fun k -> Log.debugf 1 (fun k ->
k "(@[main.th-bool.pick@ %S@])" k "(@[main.th-bool.pick@ %S@])"
(Sidekick_smt_solver.Theory.name th_bool)); (Sidekick_smt_solver.Theory.name th_bool));
[ th_bool; Process.th_uf; Process.th_data; Process.th_lra ] [ th_bool; Process.th_ty_unin; Process.th_data; Process.th_lra ]
in in
Process.Solver.create_default ~proof ~theories tst Process.Solver.create_default ~proof ~theories tst
in in

View file

@ -168,6 +168,7 @@ let assert_terms self c =
add_clause_l self c pr_c add_clause_l self c pr_c
let assert_term self t = assert_terms self [ t ] let assert_term self t = assert_terms self [ t ]
let add_ty (self : t) ty = SI.add_ty self.si ~ty
let solve ?(on_exit = []) ?(check = true) ?(on_progress = fun _ -> ()) let solve ?(on_exit = []) ?(check = true) ?(on_progress = fun _ -> ())
?(should_stop = fun _ _ -> false) ~assumptions (self : t) : res = ?(should_stop = fun _ _ -> false) ~assumptions (self : t) : res =

View file

@ -111,6 +111,8 @@ val assert_term : t -> term -> unit
(** Helper that turns the term into an atom, before adding the result (** Helper that turns the term into an atom, before adding the result
to the solver as a unit clause assertion *) to the solver as a unit clause assertion *)
val add_ty : t -> ty -> unit
(** Result of solving for the current set of clauses *) (** Result of solving for the current set of clauses *)
type res = type res =
| Sat of Model.t (** Satisfiable *) | Sat of Model.t (** Satisfiable *)

View file

@ -38,7 +38,9 @@ type t = {
cc: CC.t; (** congruence closure *) cc: CC.t; (** congruence closure *)
proof: proof_trace; (** proof logger *) proof: proof_trace; (** proof logger *)
registry: Registry.t; registry: Registry.t;
seen_types: Term.Weak_set.t; (** types we've seen so far *)
on_progress: (unit, unit) Event.Emitter.t; on_progress: (unit, unit) Event.Emitter.t;
on_new_ty: (ty, unit) Event.Emitter.t;
th_comb: Th_combination.t; th_comb: Th_combination.t;
mutable on_partial_check: (t -> theory_actions -> lit Iter.t -> unit) list; mutable on_partial_check: (t -> theory_actions -> lit Iter.t -> unit) list;
mutable on_final_check: (t -> theory_actions -> lit Iter.t -> unit) list; mutable on_final_check: (t -> theory_actions -> lit Iter.t -> unit) list;
@ -82,7 +84,9 @@ let[@inline] has_delayed_actions self =
not (Queue.is_empty self.delayed_actions) not (Queue.is_empty self.delayed_actions)
let on_preprocess self f = self.preprocess <- f :: self.preprocess let on_preprocess self f = self.preprocess <- f :: self.preprocess
let claim_term self ~th_id t = Th_combination.claim_term self.th_comb ~th_id t
let claim_sort self ~th_id ~ty =
Th_combination.claim_sort self.th_comb ~th_id ~ty
let on_model ?ask ?complete self = let on_model ?ask ?complete self =
Option.iter (fun f -> self.model_ask <- f :: self.model_ask) ask; Option.iter (fun f -> self.model_ask <- f :: self.model_ask) ask;
@ -135,6 +139,14 @@ let preprocess_term_ (self : t) (t0 : term) : unit =
if not (Term.Tbl.mem self.preprocessed t) then ( if not (Term.Tbl.mem self.preprocessed t) then (
Term.Tbl.add self.preprocessed t (); Term.Tbl.add self.preprocessed t ();
(* see if this is a new type *)
let ty = Term.ty t in
if not (Term.Weak_set.mem self.seen_types ty) then (
Log.debugf 5 (fun k -> k "(@[solver.seen-new-type@ %a@])" Term.pp ty);
Term.Weak_set.add self.seen_types ty;
Event.Emitter.emit self.on_new_ty ty
);
(* process sub-terms first *) (* process sub-terms first *)
Term.iter_shallow t ~f:(fun _inb u -> preproc_rec_ u); Term.iter_shallow t ~f:(fun _inb u -> preproc_rec_ u);
@ -166,6 +178,7 @@ let preprocess_term_ (self : t) (t0 : term) : unit =
let simplify_and_preproc_lit (self : t) (lit : Lit.t) : Lit.t * step_id option = let simplify_and_preproc_lit (self : t) (lit : Lit.t) : Lit.t * step_id option =
let t = Lit.term lit in let t = Lit.term lit in
let sign = Lit.sign lit in let sign = Lit.sign lit in
let u, pr = let u, pr =
match simplify_t self t with match simplify_t self t with
| None -> t, None | None -> t, None
@ -274,6 +287,12 @@ let[@inline] add_clause_permanent self _acts c (proof : step_id) : unit =
let[@inline] mk_lit self ?sign t : lit = Lit.atom ?sign self.tst t let[@inline] mk_lit self ?sign t : lit = Lit.atom ?sign self.tst t
let add_ty self ~ty : unit =
if not (Term.Weak_set.mem self.seen_types ty) then (
Term.Weak_set.add self.seen_types ty;
Event.Emitter.emit self.on_new_ty ty
)
let[@inline] add_lit self _acts ?default_pol lit = let[@inline] add_lit self _acts ?default_pol lit =
delayed_add_lit self ?default_pol lit delayed_add_lit self ?default_pol lit
@ -288,6 +307,7 @@ let on_partial_check self f =
self.on_partial_check <- f :: self.on_partial_check self.on_partial_check <- f :: self.on_partial_check
let on_progress self = Event.of_emitter self.on_progress let on_progress self = Event.of_emitter self.on_progress
let on_new_ty self = Event.of_emitter self.on_new_ty
let on_cc_new_term self f = Event.on (CC.on_new_term (cc self)) ~f let on_cc_new_term self f = Event.on (CC.on_new_term (cc self)) ~f
let on_cc_pre_merge self f = Event.on (CC.on_pre_merge (cc self)) ~f let on_cc_pre_merge self f = Event.on (CC.on_pre_merge (cc self)) ~f
let on_cc_post_merge self f = Event.on (CC.on_post_merge (cc self)) ~f let on_cc_post_merge self f = Event.on (CC.on_post_merge (cc self)) ~f
@ -334,6 +354,9 @@ let mk_model_ (self : t) (lits : lit Iter.t) : Model.t =
let model = Model_builder.create tst in let model = Model_builder.create tst in
Model_builder.add model (Term.true_ tst) (Term.true_ tst);
Model_builder.add model (Term.false_ tst) (Term.false_ tst);
(* first, add all literals to the model using the given propositional model (* first, add all literals to the model using the given propositional model
induced by the trail [lits]. *) induced by the trail [lits]. *)
lits (fun lit -> lits (fun lit ->
@ -363,6 +386,9 @@ let mk_model_ (self : t) (lits : lit Iter.t) : Model.t =
let rec compute_fixpoint () = let rec compute_fixpoint () =
match MB.pop_required model with match MB.pop_required model with
| None -> () | None -> ()
| Some t when Term.is_pi (Term.ty t) ->
(* TODO: when we support lambdas? *)
()
| Some t -> | Some t ->
(* compute a value for [t] *) (* compute a value for [t] *)
Log.debugf 5 (fun k -> Log.debugf 5 (fun k ->
@ -371,11 +397,9 @@ let mk_model_ (self : t) (lits : lit Iter.t) : Model.t =
(* try each model hook *) (* try each model hook *)
let rec try_hooks_ = function let rec try_hooks_ = function
| [] -> | [] ->
let c = MB.gensym model ~pre:"@c" ~ty:(Term.ty t) in (* should not happen *)
Log.debugf 10 (fun k -> Error.errorf "cannot build a value for term@ `%a`@ of type `%a`"
k "(@[model.fixpoint.pick-default-val@ %a@ :for %a@])" Term.pp c Term.pp t Term.pp (Term.ty t)
Term.pp t);
MB.add model t c
| h :: hooks -> | h :: hooks ->
(match h self model t with (match h self model t with
| None -> try_hooks_ hooks | None -> try_hooks_ hooks
@ -392,6 +416,32 @@ let mk_model_ (self : t) (lits : lit Iter.t) : Model.t =
compute_fixpoint (); compute_fixpoint ();
MB.to_model model MB.to_model model
(* theory combination: find terms occurring as foreign variables in
other terms *)
let theory_comb_register_new_term (self : t) (t : term) : unit =
Log.debugf 50 (fun k -> k "(@[solver.th-comb-register@ %a@])" Term.pp t);
match Th_combination.claimed_by self.th_comb ~ty:(Term.ty t) with
| None -> ()
| Some theory_for_t ->
let args =
let _f, args = Term.unfold_app t in
match Term.view _f, args, Term.view t with
| Term.E_const { Const.c_ops = (module OP); c_view; _ }, _, _
when OP.opaque_to_cc c_view ->
[]
| _, [], Term.E_app_fold { args; _ } -> args
| _ -> args
in
List.iter
(fun arg ->
match Th_combination.claimed_by self.th_comb ~ty:(Term.ty arg) with
| Some theory_for_arg
when not (Theory_id.equal theory_for_t theory_for_arg) ->
(* [arg] is foreign *)
Th_combination.add_term_needing_combination self.th_comb arg
| _ -> ())
args
(* call congruence closure, perform the actions it scheduled *) (* call congruence closure, perform the actions it scheduled *)
let check_cc_with_acts_ (self : t) (acts : theory_actions) = let check_cc_with_acts_ (self : t) (acts : theory_actions) =
let (module A) = acts in let (module A) = acts in
@ -529,8 +579,10 @@ let create (module A : ARG) ~stat ~proof (tst : Term.store) () : t =
stat; stat;
simp = Simplify.create tst ~proof; simp = Simplify.create tst ~proof;
last_model = None; last_model = None;
seen_types = Term.Weak_set.create 8;
th_comb = Th_combination.create ~stat tst; th_comb = Th_combination.create ~stat tst;
on_progress = Event.Emitter.create (); on_progress = Event.Emitter.create ();
on_new_ty = Event.Emitter.create ();
preprocess = []; preprocess = [];
model_ask = []; model_ask = [];
model_complete = []; model_complete = [];
@ -547,4 +599,8 @@ let create (module A : ARG) ~stat ~proof (tst : Term.store) () : t =
complete = true; complete = true;
} }
in in
(* observe new terms in the CC *)
on_cc_new_term self (fun (_, _, t) ->
theory_comb_register_new_term self t;
[]);
self self

View file

@ -98,9 +98,10 @@ val preprocess_clause_array : t -> lit array -> step_id -> lit array * step_id
val simplify_and_preproc_lit : t -> lit -> lit * step_id option val simplify_and_preproc_lit : t -> lit -> lit * step_id option
(** Simplify literal then preprocess it *) (** Simplify literal then preprocess it *)
val claim_term : t -> th_id:Theory_id.t -> term -> unit val claim_sort : t -> th_id:Theory_id.t -> ty:ty -> unit
(** Claim a term, for a theory that might decide or merge it with another (** Claim a sort, to be called by the theory with id [th_id] which is
term. This is useful for theory combination. *) responsible for this sort in models. This is useful for theory combination.
*)
(** {3 hooks for the theory} *) (** {3 hooks for the theory} *)
@ -130,6 +131,9 @@ val add_clause_permanent : t -> theory_actions -> lit list -> step_id -> unit
(** Add toplevel clause to the SAT solver. This clause will (** Add toplevel clause to the SAT solver. This clause will
not be backtracked. *) not be backtracked. *)
val add_ty : t -> ty:term -> unit
(** Declare a sort for the SMT solver *)
val mk_lit : t -> ?sign:bool -> term -> lit val mk_lit : t -> ?sign:bool -> term -> lit
(** Create a literal. This automatically preprocesses the term. *) (** Create a literal. This automatically preprocesses the term. *)
@ -204,6 +208,9 @@ val on_cc_propagate :
unit unit
(** Callback called on every CC propagation *) (** Callback called on every CC propagation *)
val on_new_ty : t -> (ty, unit) Event.t
(** Add a callback for when new types are added via {!add_ty} *)
val on_partial_check : t -> (t -> theory_actions -> lit Iter.t -> unit) -> unit val on_partial_check : t -> (t -> theory_actions -> lit Iter.t -> unit) -> unit
(** Register callbacked to be called with the slice of literals (** Register callbacked to be called with the slice of literals
newly added on the trail. newly added on the trail.

View file

@ -332,7 +332,9 @@ let process_stmt ?gc ?restarts ?(pp_cnf = false) ?proof_file ?pp_model
Fmt.printf "(@[%a@])@." (Util.pp_list pp_pair) l Fmt.printf "(@[%a@])@." (Util.pp_list pp_pair) l
| _ -> Error.errorf "cannot access model"); | _ -> Error.errorf "cannot access model");
E.return () E.return ()
| Statement.Stmt_data _ -> E.return () | Statement.Stmt_data ds ->
List.iter (fun d -> Solver.add_ty solver (Data_ty.data_as_ty d)) ds;
E.return ()
| Statement.Stmt_define _ -> Error.errorf "cannot deal with definitions yet" | Statement.Stmt_define _ -> Error.errorf "cannot deal with definitions yet"
open Sidekick_base open Sidekick_base
@ -342,4 +344,4 @@ let th_bool_dyn : Solver.theory = Th_bool.theory_dyn
let th_bool_static : Solver.theory = Th_bool.theory_static let th_bool_static : Solver.theory = Th_bool.theory_static
let th_data : Solver.theory = Th_data.theory let th_data : Solver.theory = Th_data.theory
let th_lra : Solver.theory = Th_lra.theory let th_lra : Solver.theory = Th_lra.theory
let th_uf = Th_uf.theory let th_ty_unin = Th_ty_unin.theory

View file

@ -8,7 +8,7 @@ val th_bool_static : Solver.theory
val th_bool : Config.t -> Solver.theory val th_bool : Config.t -> Solver.theory
val th_data : Solver.theory val th_data : Solver.theory
val th_lra : Solver.theory val th_lra : Solver.theory
val th_uf : Solver.theory val th_ty_unin : Solver.theory
type 'a or_error = ('a, string) CCResult.t type 'a or_error = ('a, string) CCResult.t

View file

@ -465,7 +465,8 @@ end = struct
| T_cstor _ | T_other _ -> [] | T_cstor _ | T_other _ -> []
let on_is_subterm (self : t) (si : SI.t) (_cc, _repr, t) : _ list = let on_is_subterm (self : t) (si : SI.t) (_cc, _repr, t) : _ list =
if is_data_ty (Term.ty t) then SI.claim_term si ~th_id:self.th_id t; if is_data_ty (Term.ty t) then
SI.claim_sort si ~th_id:self.th_id ~ty:(Term.ty t);
[] []
let cstors_of_ty (ty : ty) : A.Cstor.t list = let cstors_of_ty (ty : ty) : A.Cstor.t list =
@ -788,6 +789,9 @@ end = struct
Some (c, args)) Some (c, args))
| None -> None | None -> None
(* TODO: event/function to declare new datatypes, so we can claim them
early *)
let create_and_setup ~id:th_id (solver : SI.t) : t = let create_and_setup ~id:th_id (solver : SI.t) : t =
let self = let self =
{ {

View file

@ -298,15 +298,6 @@ module Make (A : ARG) = (* : S with module A = A *) struct
proxy, A.Q.one) proxy, A.Q.one)
(* look for subterms of type Real, for they will need theory combination *)
let on_subterm (self : state) (si : SI.t) (t : Term.t) : unit =
Log.debugf 50 (fun k -> k "(@[lra.cc-on-subterm@ %a@])" Term.pp_debug t);
match A.view_as_lra t with
| LRA_other _ when not (A.has_ty_real t) -> ()
| LRA_pred _ | LRA_const _ -> ()
| LRA_op _ | LRA_other _ | LRA_mult _ ->
SI.claim_term si ~th_id:self.th_id t
(* preprocess linear expressions away *) (* preprocess linear expressions away *)
let preproc_lra (self : state) si (module PA : SI.PREPROCESS_ACTS) let preproc_lra (self : state) si (module PA : SI.PREPROCESS_ACTS)
(t : Term.t) : unit = (t : Term.t) : unit =
@ -314,11 +305,10 @@ module Make (A : ARG) = (* : S with module A = A *) struct
let tst = SI.tst si in let tst = SI.tst si in
(* tell the CC this term exists *) (* tell the CC this term exists *)
let declare_term_to_cc ~sub t = let declare_term_to_cc ~sub:_ t =
Log.debugf 50 (fun k -> Log.debugf 50 (fun k ->
k "(@[lra.declare-term-to-cc@ %a@])" Term.pp_debug t); k "(@[lra.declare-term-to-cc@ %a@])" Term.pp_debug t);
ignore (CC.add_term (SI.cc si) t : E_node.t); ignore (CC.add_term (SI.cc si) t : E_node.t)
if sub then on_subterm self si t
in in
match A.view_as_lra t with match A.view_as_lra t with
@ -672,6 +662,7 @@ module Make (A : ARG) = (* : S with module A = A *) struct
(* help generating model *) (* help generating model *)
let model_ask_ (self : state) _si _model (t : Term.t) : _ option = let model_ask_ (self : state) _si _model (t : Term.t) : _ option =
let res =
match self.last_res with match self.last_res with
| Some (SimpSolver.Sat m) -> | Some (SimpSolver.Sat m) ->
Log.debugf 50 (fun k -> k "(@[lra.model-ask@ %a@])" Term.pp_debug t); Log.debugf 50 (fun k -> k "(@[lra.model-ask@ %a@])" Term.pp_debug t);
@ -680,6 +671,16 @@ module Make (A : ARG) = (* : S with module A = A *) struct
| _ -> SimpSolver.V_map.get t m) | _ -> SimpSolver.V_map.get t m)
|> Option.map (fun t -> t_const self t, []) |> Option.map (fun t -> t_const self t, [])
| _ -> None | _ -> None
in
match res with
| Some _ -> res
| None when A.has_ty_real t ->
(* last resort: return 0 *)
(* NOTE: this should go away maybe? no term should escape the LRA model… *)
Log.debugf 0 (fun k -> k "MODEL TY REAL DEFAULT %a" Term.pp t);
let zero = A.mk_lra self.tst (LRA_const A.Q.zero) in
Some (zero, [])
| None -> None
(* help generating model *) (* help generating model *)
let model_complete_ (self : state) _si ~add : unit = let model_complete_ (self : state) _si ~add : unit =
@ -710,9 +711,7 @@ module Make (A : ARG) = (* : S with module A = A *) struct
SI.on_final_check si (final_check_ st); SI.on_final_check si (final_check_ st);
(* SI.on_partial_check si (partial_check_ st); *) (* SI.on_partial_check si (partial_check_ st); *)
SI.on_model si ~ask:(model_ask_ st) ~complete:(model_complete_ st); SI.on_model si ~ask:(model_ask_ st) ~complete:(model_complete_ st);
SI.on_cc_is_subterm si (fun (_, _, t) -> SI.claim_sort si ~th_id:id ~ty:(A.ty_real (SI.tst si));
on_subterm st si t;
[]);
SI.on_cc_pre_merge si (fun (_cc, n1, n2, expl) -> SI.on_cc_pre_merge si (fun (_cc, n1, n2, expl) ->
match as_const_ (E_node.term n1), as_const_ (E_node.term n2) with match as_const_ (E_node.term n1), as_const_ (E_node.term n2) with
| Some q1, Some q2 when A.Q.(q1 <> q2) -> | Some q1, Some q2 when A.Q.(q1 <> q2) ->