From e74439cf2acb23cee28ac98003756ea5a48fb200 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Thu, 1 Sep 2022 22:34:27 -0400 Subject: [PATCH] wip: new attempt at theory combination --- src/base/Sidekick_base.ml | 4 +- src/base/dune | 3 +- src/main/main.ml | 2 +- src/smt/solver.ml | 1 + src/smt/solver.mli | 2 + src/smt/solver_internal.ml | 68 ++++++++++++++++++++++++++++++--- src/smt/solver_internal.mli | 13 +++++-- src/smtlib/Process.ml | 6 ++- src/smtlib/Process.mli | 2 +- src/th-data/Sidekick_th_data.ml | 6 ++- src/th-lra/sidekick_th_lra.ml | 45 +++++++++++----------- 11 files changed, 112 insertions(+), 40 deletions(-) diff --git a/src/base/Sidekick_base.ml b/src/base/Sidekick_base.ml index 1411d4c2..5d7df622 100644 --- a/src/base/Sidekick_base.ml +++ b/src/base/Sidekick_base.ml @@ -33,7 +33,7 @@ module LRA_term = LRA_term module Th_data = Th_data module Th_bool = Th_bool module Th_lra = Th_lra -module Th_uf = Th_uf +module Th_ty_unin = Th_ty_unin let k_th_bool_config = Th_bool.k_config let th_bool = Th_bool.theory @@ -41,4 +41,4 @@ let th_bool_dyn : Solver.theory = Th_bool.theory_dyn let th_bool_static : Solver.theory = Th_bool.theory_static let th_data : Solver.theory = Th_data.theory let th_lra : Solver.theory = Th_lra.theory -let th_uf : Solver.theory = Th_uf.theory +let th_ty_unin : Solver.theory = Th_ty_unin.theory diff --git a/src/base/dune b/src/base/dune index 1e1c0c7c..a55ab04b 100644 --- a/src/base/dune +++ b/src/base/dune @@ -4,5 +4,6 @@ (synopsis "Base term definitions for the standalone SMT solver and library") (libraries containers iter sidekick.core sidekick.util sidekick.smt-solver sidekick.cc sidekick.quip sidekick.th-lra sidekick.th-bool-static - sidekick.th-bool-dyn sidekick.th-data sidekick.zarith zarith) + sidekick.th-ty-unin sidekick.th-bool-dyn sidekick.th-data sidekick.zarith + zarith) (flags :standard -w +32 -open Sidekick_util)) diff --git a/src/main/main.ml b/src/main/main.ml index 625897a5..87ce53d0 100644 --- a/src/main/main.ml +++ b/src/main/main.ml @@ -177,7 +177,7 @@ let main_smt ~config () : _ result = Log.debugf 1 (fun k -> k "(@[main.th-bool.pick@ %S@])" (Sidekick_smt_solver.Theory.name th_bool)); - [ th_bool; Process.th_uf; Process.th_data; Process.th_lra ] + [ th_bool; Process.th_ty_unin; Process.th_data; Process.th_lra ] in Process.Solver.create_default ~proof ~theories tst in diff --git a/src/smt/solver.ml b/src/smt/solver.ml index 235967f6..673a3dc4 100644 --- a/src/smt/solver.ml +++ b/src/smt/solver.ml @@ -168,6 +168,7 @@ let assert_terms self c = add_clause_l self c pr_c let assert_term self t = assert_terms self [ t ] +let add_ty (self : t) ty = SI.add_ty self.si ~ty let solve ?(on_exit = []) ?(check = true) ?(on_progress = fun _ -> ()) ?(should_stop = fun _ _ -> false) ~assumptions (self : t) : res = diff --git a/src/smt/solver.mli b/src/smt/solver.mli index dd08c07e..41024d78 100644 --- a/src/smt/solver.mli +++ b/src/smt/solver.mli @@ -111,6 +111,8 @@ val assert_term : t -> term -> unit (** Helper that turns the term into an atom, before adding the result to the solver as a unit clause assertion *) +val add_ty : t -> ty -> unit + (** Result of solving for the current set of clauses *) type res = | Sat of Model.t (** Satisfiable *) diff --git a/src/smt/solver_internal.ml b/src/smt/solver_internal.ml index 2596b227..fe2216bf 100644 --- a/src/smt/solver_internal.ml +++ b/src/smt/solver_internal.ml @@ -38,7 +38,9 @@ type t = { cc: CC.t; (** congruence closure *) proof: proof_trace; (** proof logger *) registry: Registry.t; + seen_types: Term.Weak_set.t; (** types we've seen so far *) on_progress: (unit, unit) Event.Emitter.t; + on_new_ty: (ty, unit) Event.Emitter.t; th_comb: Th_combination.t; mutable on_partial_check: (t -> theory_actions -> lit Iter.t -> unit) list; mutable on_final_check: (t -> theory_actions -> lit Iter.t -> unit) list; @@ -82,7 +84,9 @@ let[@inline] has_delayed_actions self = not (Queue.is_empty self.delayed_actions) let on_preprocess self f = self.preprocess <- f :: self.preprocess -let claim_term self ~th_id t = Th_combination.claim_term self.th_comb ~th_id t + +let claim_sort self ~th_id ~ty = + Th_combination.claim_sort self.th_comb ~th_id ~ty let on_model ?ask ?complete self = Option.iter (fun f -> self.model_ask <- f :: self.model_ask) ask; @@ -135,6 +139,14 @@ let preprocess_term_ (self : t) (t0 : term) : unit = if not (Term.Tbl.mem self.preprocessed t) then ( Term.Tbl.add self.preprocessed t (); + (* see if this is a new type *) + let ty = Term.ty t in + if not (Term.Weak_set.mem self.seen_types ty) then ( + Log.debugf 5 (fun k -> k "(@[solver.seen-new-type@ %a@])" Term.pp ty); + Term.Weak_set.add self.seen_types ty; + Event.Emitter.emit self.on_new_ty ty + ); + (* process sub-terms first *) Term.iter_shallow t ~f:(fun _inb u -> preproc_rec_ u); @@ -166,6 +178,7 @@ let preprocess_term_ (self : t) (t0 : term) : unit = let simplify_and_preproc_lit (self : t) (lit : Lit.t) : Lit.t * step_id option = let t = Lit.term lit in let sign = Lit.sign lit in + let u, pr = match simplify_t self t with | None -> t, None @@ -274,6 +287,12 @@ let[@inline] add_clause_permanent self _acts c (proof : step_id) : unit = let[@inline] mk_lit self ?sign t : lit = Lit.atom ?sign self.tst t +let add_ty self ~ty : unit = + if not (Term.Weak_set.mem self.seen_types ty) then ( + Term.Weak_set.add self.seen_types ty; + Event.Emitter.emit self.on_new_ty ty + ) + let[@inline] add_lit self _acts ?default_pol lit = delayed_add_lit self ?default_pol lit @@ -288,6 +307,7 @@ let on_partial_check self f = self.on_partial_check <- f :: self.on_partial_check let on_progress self = Event.of_emitter self.on_progress +let on_new_ty self = Event.of_emitter self.on_new_ty let on_cc_new_term self f = Event.on (CC.on_new_term (cc self)) ~f let on_cc_pre_merge self f = Event.on (CC.on_pre_merge (cc self)) ~f let on_cc_post_merge self f = Event.on (CC.on_post_merge (cc self)) ~f @@ -334,6 +354,9 @@ let mk_model_ (self : t) (lits : lit Iter.t) : Model.t = let model = Model_builder.create tst in + Model_builder.add model (Term.true_ tst) (Term.true_ tst); + Model_builder.add model (Term.false_ tst) (Term.false_ tst); + (* first, add all literals to the model using the given propositional model induced by the trail [lits]. *) lits (fun lit -> @@ -363,6 +386,9 @@ let mk_model_ (self : t) (lits : lit Iter.t) : Model.t = let rec compute_fixpoint () = match MB.pop_required model with | None -> () + | Some t when Term.is_pi (Term.ty t) -> + (* TODO: when we support lambdas? *) + () | Some t -> (* compute a value for [t] *) Log.debugf 5 (fun k -> @@ -371,11 +397,9 @@ let mk_model_ (self : t) (lits : lit Iter.t) : Model.t = (* try each model hook *) let rec try_hooks_ = function | [] -> - let c = MB.gensym model ~pre:"@c" ~ty:(Term.ty t) in - Log.debugf 10 (fun k -> - k "(@[model.fixpoint.pick-default-val@ %a@ :for %a@])" Term.pp c - Term.pp t); - MB.add model t c + (* should not happen *) + Error.errorf "cannot build a value for term@ `%a`@ of type `%a`" + Term.pp t Term.pp (Term.ty t) | h :: hooks -> (match h self model t with | None -> try_hooks_ hooks @@ -392,6 +416,32 @@ let mk_model_ (self : t) (lits : lit Iter.t) : Model.t = compute_fixpoint (); MB.to_model model +(* theory combination: find terms occurring as foreign variables in + other terms *) +let theory_comb_register_new_term (self : t) (t : term) : unit = + Log.debugf 50 (fun k -> k "(@[solver.th-comb-register@ %a@])" Term.pp t); + match Th_combination.claimed_by self.th_comb ~ty:(Term.ty t) with + | None -> () + | Some theory_for_t -> + let args = + let _f, args = Term.unfold_app t in + match Term.view _f, args, Term.view t with + | Term.E_const { Const.c_ops = (module OP); c_view; _ }, _, _ + when OP.opaque_to_cc c_view -> + [] + | _, [], Term.E_app_fold { args; _ } -> args + | _ -> args + in + List.iter + (fun arg -> + match Th_combination.claimed_by self.th_comb ~ty:(Term.ty arg) with + | Some theory_for_arg + when not (Theory_id.equal theory_for_t theory_for_arg) -> + (* [arg] is foreign *) + Th_combination.add_term_needing_combination self.th_comb arg + | _ -> ()) + args + (* call congruence closure, perform the actions it scheduled *) let check_cc_with_acts_ (self : t) (acts : theory_actions) = let (module A) = acts in @@ -529,8 +579,10 @@ let create (module A : ARG) ~stat ~proof (tst : Term.store) () : t = stat; simp = Simplify.create tst ~proof; last_model = None; + seen_types = Term.Weak_set.create 8; th_comb = Th_combination.create ~stat tst; on_progress = Event.Emitter.create (); + on_new_ty = Event.Emitter.create (); preprocess = []; model_ask = []; model_complete = []; @@ -547,4 +599,8 @@ let create (module A : ARG) ~stat ~proof (tst : Term.store) () : t = complete = true; } in + (* observe new terms in the CC *) + on_cc_new_term self (fun (_, _, t) -> + theory_comb_register_new_term self t; + []); self diff --git a/src/smt/solver_internal.mli b/src/smt/solver_internal.mli index aee6ec58..b0b28280 100644 --- a/src/smt/solver_internal.mli +++ b/src/smt/solver_internal.mli @@ -98,9 +98,10 @@ val preprocess_clause_array : t -> lit array -> step_id -> lit array * step_id val simplify_and_preproc_lit : t -> lit -> lit * step_id option (** Simplify literal then preprocess it *) -val claim_term : t -> th_id:Theory_id.t -> term -> unit -(** Claim a term, for a theory that might decide or merge it with another - term. This is useful for theory combination. *) +val claim_sort : t -> th_id:Theory_id.t -> ty:ty -> unit +(** Claim a sort, to be called by the theory with id [th_id] which is + responsible for this sort in models. This is useful for theory combination. + *) (** {3 hooks for the theory} *) @@ -130,6 +131,9 @@ val add_clause_permanent : t -> theory_actions -> lit list -> step_id -> unit (** Add toplevel clause to the SAT solver. This clause will not be backtracked. *) +val add_ty : t -> ty:term -> unit +(** Declare a sort for the SMT solver *) + val mk_lit : t -> ?sign:bool -> term -> lit (** Create a literal. This automatically preprocesses the term. *) @@ -204,6 +208,9 @@ val on_cc_propagate : unit (** Callback called on every CC propagation *) +val on_new_ty : t -> (ty, unit) Event.t +(** Add a callback for when new types are added via {!add_ty} *) + val on_partial_check : t -> (t -> theory_actions -> lit Iter.t -> unit) -> unit (** Register callbacked to be called with the slice of literals newly added on the trail. diff --git a/src/smtlib/Process.ml b/src/smtlib/Process.ml index 1ad041e6..88a94ab9 100644 --- a/src/smtlib/Process.ml +++ b/src/smtlib/Process.ml @@ -332,7 +332,9 @@ let process_stmt ?gc ?restarts ?(pp_cnf = false) ?proof_file ?pp_model Fmt.printf "(@[%a@])@." (Util.pp_list pp_pair) l | _ -> Error.errorf "cannot access model"); E.return () - | Statement.Stmt_data _ -> E.return () + | Statement.Stmt_data ds -> + List.iter (fun d -> Solver.add_ty solver (Data_ty.data_as_ty d)) ds; + E.return () | Statement.Stmt_define _ -> Error.errorf "cannot deal with definitions yet" open Sidekick_base @@ -342,4 +344,4 @@ let th_bool_dyn : Solver.theory = Th_bool.theory_dyn let th_bool_static : Solver.theory = Th_bool.theory_static let th_data : Solver.theory = Th_data.theory let th_lra : Solver.theory = Th_lra.theory -let th_uf = Th_uf.theory +let th_ty_unin = Th_ty_unin.theory diff --git a/src/smtlib/Process.mli b/src/smtlib/Process.mli index 54bcb71c..8e0501e0 100644 --- a/src/smtlib/Process.mli +++ b/src/smtlib/Process.mli @@ -8,7 +8,7 @@ val th_bool_static : Solver.theory val th_bool : Config.t -> Solver.theory val th_data : Solver.theory val th_lra : Solver.theory -val th_uf : Solver.theory +val th_ty_unin : Solver.theory type 'a or_error = ('a, string) CCResult.t diff --git a/src/th-data/Sidekick_th_data.ml b/src/th-data/Sidekick_th_data.ml index 805efc78..fcf6a368 100644 --- a/src/th-data/Sidekick_th_data.ml +++ b/src/th-data/Sidekick_th_data.ml @@ -465,7 +465,8 @@ end = struct | T_cstor _ | T_other _ -> [] let on_is_subterm (self : t) (si : SI.t) (_cc, _repr, t) : _ list = - if is_data_ty (Term.ty t) then SI.claim_term si ~th_id:self.th_id t; + if is_data_ty (Term.ty t) then + SI.claim_sort si ~th_id:self.th_id ~ty:(Term.ty t); [] let cstors_of_ty (ty : ty) : A.Cstor.t list = @@ -788,6 +789,9 @@ end = struct Some (c, args)) | None -> None + (* TODO: event/function to declare new datatypes, so we can claim them + early *) + let create_and_setup ~id:th_id (solver : SI.t) : t = let self = { diff --git a/src/th-lra/sidekick_th_lra.ml b/src/th-lra/sidekick_th_lra.ml index 58ee88c1..9f36e2a8 100644 --- a/src/th-lra/sidekick_th_lra.ml +++ b/src/th-lra/sidekick_th_lra.ml @@ -298,15 +298,6 @@ module Make (A : ARG) = (* : S with module A = A *) struct proxy, A.Q.one) - (* look for subterms of type Real, for they will need theory combination *) - let on_subterm (self : state) (si : SI.t) (t : Term.t) : unit = - Log.debugf 50 (fun k -> k "(@[lra.cc-on-subterm@ %a@])" Term.pp_debug t); - match A.view_as_lra t with - | LRA_other _ when not (A.has_ty_real t) -> () - | LRA_pred _ | LRA_const _ -> () - | LRA_op _ | LRA_other _ | LRA_mult _ -> - SI.claim_term si ~th_id:self.th_id t - (* preprocess linear expressions away *) let preproc_lra (self : state) si (module PA : SI.PREPROCESS_ACTS) (t : Term.t) : unit = @@ -314,11 +305,10 @@ module Make (A : ARG) = (* : S with module A = A *) struct let tst = SI.tst si in (* tell the CC this term exists *) - let declare_term_to_cc ~sub t = + let declare_term_to_cc ~sub:_ t = Log.debugf 50 (fun k -> k "(@[lra.declare-term-to-cc@ %a@])" Term.pp_debug t); - ignore (CC.add_term (SI.cc si) t : E_node.t); - if sub then on_subterm self si t + ignore (CC.add_term (SI.cc si) t : E_node.t) in match A.view_as_lra t with @@ -672,14 +662,25 @@ module Make (A : ARG) = (* : S with module A = A *) struct (* help generating model *) let model_ask_ (self : state) _si _model (t : Term.t) : _ option = - match self.last_res with - | Some (SimpSolver.Sat m) -> - Log.debugf 50 (fun k -> k "(@[lra.model-ask@ %a@])" Term.pp_debug t); - (match A.view_as_lra t with - | LRA_const n -> Some n (* always eval constants to themselves *) - | _ -> SimpSolver.V_map.get t m) - |> Option.map (fun t -> t_const self t, []) - | _ -> None + let res = + match self.last_res with + | Some (SimpSolver.Sat m) -> + Log.debugf 50 (fun k -> k "(@[lra.model-ask@ %a@])" Term.pp_debug t); + (match A.view_as_lra t with + | LRA_const n -> Some n (* always eval constants to themselves *) + | _ -> SimpSolver.V_map.get t m) + |> Option.map (fun t -> t_const self t, []) + | _ -> None + in + match res with + | Some _ -> res + | None when A.has_ty_real t -> + (* last resort: return 0 *) + (* NOTE: this should go away maybe? no term should escape the LRA model… *) + Log.debugf 0 (fun k -> k "MODEL TY REAL DEFAULT %a" Term.pp t); + let zero = A.mk_lra self.tst (LRA_const A.Q.zero) in + Some (zero, []) + | None -> None (* help generating model *) let model_complete_ (self : state) _si ~add : unit = @@ -710,9 +711,7 @@ module Make (A : ARG) = (* : S with module A = A *) struct SI.on_final_check si (final_check_ st); (* SI.on_partial_check si (partial_check_ st); *) SI.on_model si ~ask:(model_ask_ st) ~complete:(model_complete_ st); - SI.on_cc_is_subterm si (fun (_, _, t) -> - on_subterm st si t; - []); + SI.claim_sort si ~th_id:id ~ty:(A.ty_real (SI.tst si)); SI.on_cc_pre_merge si (fun (_cc, n1, n2, expl) -> match as_const_ (E_node.term n1), as_const_ (E_node.term n2) with | Some q1, Some q2 when A.Q.(q1 <> q2) ->