From 221ed7dcdb9fa3f468fde17328cf530651bf3a4d Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Thu, 1 Feb 2018 22:44:40 -0600 Subject: [PATCH] continue large refactoring, progress in theory combination - first draft of theory combination - theory interface - have the project compile --- TODO.md | 27 ++ src/core/CDCL.ml | 4 +- src/core/Internal.ml | 16 +- src/core/Res.ml | 2 +- src/core/Res_intf.ml | 2 +- src/core/Theory_intf.ml | 18 +- src/smt/Ast.ml | 382 +++++++++++++++ src/smt/Ast.mli | 186 ++++++++ src/smt/Clause.ml | 30 ++ src/smt/Clause.mli | 12 + src/smt/Config.ml | 11 + src/smt/Config.mli | 48 ++ src/smt/Congruence_closure.ml | 133 +++--- src/smt/Congruence_closure.mli | 42 +- src/smt/Cst.ml | 1 - src/smt/Cst.mli | 1 - src/smt/Het_map.ml | 191 ++++++++ src/smt/Het_map.mli | 85 ++++ src/smt/Lit.ml | 1 - src/smt/Lit.mli | 1 - src/smt/Model.ml | 370 +++++++++++++++ src/smt/Model.mli | 29 ++ src/smt/Solver.ml | 840 +++++++++++++++++++++++++++++++++ src/smt/Solver.mli | 57 +++ src/smt/Stat.ml | 18 + src/smt/Term.ml | 3 + src/smt/Term.mli | 3 +- src/smt/Term_cell.ml | 1 - src/smt/Term_cell.mli | 1 - src/smt/Theory.ml | 60 +++ src/smt/Theory_combine.ml | 245 ++++++++++ src/smt/Theory_combine.mli | 21 + src/smt/Util.ml | 3 + src/smt/Util.mli | 2 + 34 files changed, 2724 insertions(+), 122 deletions(-) create mode 100644 src/smt/Ast.ml create mode 100644 src/smt/Ast.mli create mode 100644 src/smt/Clause.ml create mode 100644 src/smt/Clause.mli create mode 100644 src/smt/Config.ml create mode 100644 src/smt/Config.mli create mode 100644 src/smt/Het_map.ml create mode 100644 src/smt/Het_map.mli create mode 100644 src/smt/Model.ml create mode 100644 src/smt/Model.mli create mode 100644 src/smt/Solver.ml create mode 100644 src/smt/Solver.mli create mode 100644 src/smt/Stat.ml create mode 100644 src/smt/Theory.ml create mode 100644 src/smt/Theory_combine.ml create mode 100644 src/smt/Theory_combine.mli diff --git a/TODO.md b/TODO.md index ec05fdb5..060fdf6e 100644 --- a/TODO.md +++ b/TODO.md @@ -1,5 +1,26 @@ # Goals +## TODO + +- typing and translation Ast -> Term +- main executable for SMT solver +- theory of boolean constructs (on the fly Tseitin using local clauses) +- make CC work on QF_UF + * internalize terms on the fly (backtrackable) + * basic notion of activity for `ite`? +- have `CDCL.push_local` work properly + +- write Shostak theory of datatypes (without acyclicity) with local case splits +- design evaluation system (guards + `eval_bool:(term -> bool) option` in custom TC) +- compilation of rec functions to defined constants + +- Shostak theory of eq-ℚ +- datatype acyclicity check + +- abstract domain propagation in CC +- domain propagation (intervals) for ℚ arith +- full ℚ theory: shostak + domains + if-sat simplex + ## Main goals - Add a backend to send proofs to dedukti @@ -15,3 +36,9 @@ - max-sat/max-smt - coq proofs ? + +## Done + +- base types (Term, Lit, …) +- theory combination +- basic design of theories diff --git a/src/core/CDCL.ml b/src/core/CDCL.ml index b1ad22c9..069108ca 100644 --- a/src/core/CDCL.ml +++ b/src/core/CDCL.ml @@ -3,7 +3,6 @@ module Theory_intf = Theory_intf module Solver_types_intf = Solver_types_intf -module Config = Config module Res = Res @@ -42,13 +41,14 @@ type 'clause export = 'clause Solver_intf.export = { type ('form, 'proof) actions = ('form,'proof) Theory_intf.actions = Actions of { push : 'form list -> 'proof -> unit; + push_local : 'form list -> 'proof -> unit; on_backtrack: (unit -> unit) -> unit; at_level_0 : unit -> bool; + propagate : 'form -> 'form list -> 'proof -> unit; } type ('form, 'proof) slice_actions = ('form, 'proof) Theory_intf.slice_actions = Slice_acts of { slice_iter : ('form -> unit) -> unit; - slice_propagate : 'form -> 'form list -> 'proof -> unit; } module Make(E : Theory_intf.S) = Solver.Make(Solver_types.Make(E))(E) diff --git a/src/core/Internal.ml b/src/core/Internal.ml index cd600c6b..43f6e870 100644 --- a/src/core/Internal.ml +++ b/src/core/Internal.ml @@ -848,13 +848,17 @@ module Make f a.lit done - let slice_push st (l:formula list) (lemma:proof): unit = + let act_push st (l:formula list) (lemma:proof): unit = let atoms = List.rev_map (create_atom st) l in let c = Clause.make atoms (Lemma lemma) in - Log.debugf info (fun k->k "Pushing clause %a" Clause.debug c); + Log.debugf info (fun k->k "(@[sat.push_clause@ %a@])" Clause.debug c); Stack.push c st.clauses_to_add - let slice_propagate (st:t) f causes proof : unit = + (* TODO: ensure that the clause is removed upon backtracking *) + let act_push_local = act_push + + (* TODO: ensure that the clause is removed upon backtracking *) + let act_propagate (st:t) f causes proof : unit = let l = List.rev_map (mk_atom st) causes in if List.for_all (fun a -> a.is_true) l then ( let p = mk_atom st f in @@ -879,19 +883,19 @@ module Make let current_slice st = Theory_intf.Slice_acts { slice_iter = slice_iter st; - slice_propagate = slice_propagate st; } (* full slice, for [if_sat] final check *) let full_slice st = Theory_intf.Slice_acts { slice_iter = slice_iter st; - slice_propagate = slice_propagate st; } let actions st = Theory_intf.Actions { - push = slice_push st; + push = act_push st; + push_local = act_push_local st; on_backtrack = slice_on_backtrack st; at_level_0 = slice_at_level_0 st; + propagate = act_propagate st; } let create ?(size=`Big) ?st () : t = diff --git a/src/core/Res.ml b/src/core/Res.ml index 0114dbb7..3a17807a 100644 --- a/src/core/Res.ml +++ b/src/core/Res.ml @@ -16,7 +16,7 @@ module Make(St : Solver_types.S) = struct type clause = St.clause type atom = St.atom - exception Insuficient_hyps + exception Insufficient_hyps exception Resolution_error of string (* Log levels *) diff --git a/src/core/Res_intf.ml b/src/core/Res_intf.ml index fd4afff6..f88035c5 100644 --- a/src/core/Res_intf.ml +++ b/src/core/Res_intf.ml @@ -13,7 +13,7 @@ module type S = sig (** {3 Type declarations} *) - exception Insuficient_hyps + exception Insufficient_hyps (** Raised when a complete resolution derivation cannot be found using the current hypotheses. *) type formula diff --git a/src/core/Theory_intf.ml b/src/core/Theory_intf.ml index 6ab3e9e3..96472a62 100644 --- a/src/core/Theory_intf.ml +++ b/src/core/Theory_intf.ml @@ -43,24 +43,28 @@ type ('formula, 'proof) res = at any time *) type ('form, 'proof) actions = Actions of { push : 'form list -> 'proof -> unit; - (** Allows to add a clause to the solver. *) + (** Allows to add a persistent clause to the solver. *) + + push_local : 'form list -> 'proof -> unit; + (** Allows to add a local clause to the solver. The clause + will be removed after backtracking. *) on_backtrack: (unit -> unit) -> unit; (** [on_backtrack f] calls [f] when the main solver backtracks *) at_level_0 : unit -> bool; (** Are we at level 0? *) + + propagate : 'form -> 'form list -> 'proof -> unit; + (** [propagate lit causes proof] informs the solver to propagate [lit], with the reason + that the clause [causes => lit] is a theory tautology. It is faster than pushing + the associated clause but the clause will not be remembered by the sat solver, + i.e it will not be used by the solver to do boolean propagation. *) } type ('form, 'proof) slice_actions = Slice_acts of { slice_iter : ('form -> unit) -> unit; (** iterate on the slice of the trail *) - - slice_propagate : 'form -> 'form list -> 'proof -> unit; - (** [propagate lit causes proof] informs the solver to propagate [lit], with the reason - that the clause [causes => lit] is a theory tautology. It is faster than pushing - the associated clause but the clause will not be remembered by the sat solver, - i.e it will not be used by the solver to do boolean propagation. *) } (** The type for a slice. Slices are some kind of view of the current propagation queue. They allow to look at the propagated literals, diff --git a/src/smt/Ast.ml b/src/smt/Ast.ml new file mode 100644 index 00000000..71697e11 --- /dev/null +++ b/src/smt/Ast.ml @@ -0,0 +1,382 @@ + +(* This file is free software. See file "license" for more details. *) + +(** {1 Preprocessing AST} *) + +module Fmt = CCFormat +module S = CCSexp + +type 'a or_error = ('a, string) CCResult.t + +exception Error of string +exception Ill_typed of string + +let () = Printexc.register_printer + (function + | Error msg -> Some ("ast error: " ^ msg) + | Ill_typed msg -> Some ("ill-typed: " ^ msg) + | _ -> None) + +let errorf msg = + CCFormat.ksprintf ~f:(fun e -> raise (Error e)) msg + +(** {2 Types} *) + +module Var = struct + type 'ty t = { + id: ID.t; + ty: 'ty; + } + + let make id ty = {id;ty} + let makef ~ty fmt = + CCFormat.ksprintf fmt ~f:(fun s -> make (ID.make s) ty) + let copy {id;ty} = {ty; id=ID.copy id} + let id v = v.id + let ty v = v.ty + + let equal a b = ID.equal a.id b.id + let compare a b = ID.compare a.id b.id + let pp out v = ID.pp out v.id +end + +module Ty = struct + type t = + | Prop + | Const of ID.t + | Arrow of t * t + + let prop = Prop + let const id = Const id + let arrow a b = Arrow (a,b) + let arrow_l = List.fold_right arrow + + let to_int_ = function + | Prop -> 0 + | Const _ -> 1 + | Arrow _ -> 2 + + let () = CCOrd.() + + let rec compare a b = match a, b with + | Prop, Prop -> 0 + | Const a, Const b -> ID.compare a b + | Arrow (a1,a2), Arrow (b1,b2) -> + compare a1 b1 (compare, a2,b2) + | Prop, _ + | Const _, _ + | Arrow _, _ -> CCInt.compare (to_int_ a) (to_int_ b) + + let equal a b = compare a b = 0 + + let hash _ = 0 (* TODO *) + + let unfold ty = + let rec aux acc ty = match ty with + | Arrow (a,b) -> aux (a::acc) b + | _ -> List.rev acc, ty + in + aux [] ty + + let rec pp out = function + | Prop -> Fmt.string out "prop" + | Const id -> ID.pp out id + | Arrow _ as ty -> + let args, ret = unfold ty in + Fmt.fprintf out "(@[-> %a@ %a@])" + (Util.pp_list ~sep:" " pp) args pp ret + + (** {2 Datatypes} *) + + type data = { + data_id: ID.t; + data_cstors: t ID.Map.t; + } + + (* FIXME + let data_to_sexp d = + let cstors = + ID.Map.fold + (fun c ty acc -> + let ty_args, _ = unfold ty in + let c_sexp = match ty_args with + | [] -> ID.to_sexp c + | _::_ -> S.of_list (ID.to_sexp c :: List.map to_sexp ty_args) + in + c_sexp :: acc) + d.data_cstors [] + in + S.of_list (ID.to_sexp d.data_id :: cstors) + *) + + module Map = CCMap.Make(struct + type _t = t + type t = _t + let compare = compare + end) + + let ill_typed fmt = + CCFormat.ksprintf + ~f:(fun s -> raise (Ill_typed s)) + fmt +end + +type var = Ty.t Var.t + +type binop = + | And + | Or + | Imply + | Eq + +type binder = + | Fun + | Forall + | Exists + | Mu + +type term = { + term: term_cell; + ty: Ty.t; +} +and term_cell = + | Var of var + | Const of ID.t + | Unknown of var (* meta var *) + | App of term * term list + | If of term * term * term + | Select of select * term + | Match of term * (var list * term) ID.Map.t + | Switch of term * term ID.Map.t (* switch on constants *) + | Bind of binder * var * term + | Let of var * term * term + | Not of term + | Binop of binop * term * term + | Asserting of term * term + | Undefined_value + | Bool of bool + +and select = { + select_name: ID.t lazy_t; + select_cstor: ID.t; + select_i: int; +} + +type definition = ID.t * Ty.t * term + +type statement = + | Data of Ty.data list + | TyDecl of ID.t (* new atomic cstor *) + | Decl of ID.t * Ty.t + | Define of definition list + | Assert of term + | Goal of var list * term + +(** {2 Helper} *) + +let unfold_fun t = + let rec aux acc t = match t.term with + | Bind (Fun, v, t') -> aux (v::acc) t' + | _ -> List.rev acc, t + in + aux [] t + +(* TODO *) + +let pp_term out _ = Fmt.string out "todo:term" + +let pp_ty out _ = Fmt.string out "todo:ty" + +let pp_statement out _ = Fmt.string out "todo:stmt" + +(** {2 Constructors} *) + +let term_view t = t.term + +let rec app_ty_ ty l : Ty.t = match ty, l with + | _, [] -> ty + | Ty.Arrow (ty_a,ty_rest), a::tail -> + if Ty.equal ty_a a.ty + then app_ty_ ty_rest tail + else Ty.ill_typed "expected `@[%a@]`,@ got `@[%a : %a@]`" + Ty.pp ty_a pp_term a Ty.pp a.ty + | (Ty.Prop | Ty.Const _), a::_ -> + Ty.ill_typed "cannot apply ty `@[%a@]`@ to `@[%a@]`" Ty.pp ty pp_term a + +let mk_ term ty = {term; ty} +let ty t = t.ty + +let true_ = mk_ (Bool true) Ty.prop +let false_ = mk_ (Bool false) Ty.prop +let undefined_value ty = mk_ Undefined_value ty + +let asserting t g = + if not (Ty.equal Ty.prop g.ty) then ( + Ty.ill_typed "asserting: test must have type prop, not `@[%a@]`" Ty.pp g.ty; + ); + mk_ (Asserting (t,g)) t.ty + +let var v = mk_ (Var v) (Var.ty v) +let unknown v = mk_ (Unknown v) (Var.ty v) + +let const id ty = mk_ (Const id) ty + +let select (s:select) (t:term) ty = mk_ (Select (s,t)) ty + +let app f l = match f.term, l with + | _, [] -> f + | App (f1, l1), _ -> + let ty = app_ty_ f.ty l in + mk_ (App (f1, l1 @ l)) ty + | _ -> + let ty = app_ty_ f.ty l in + mk_ (App (f, l)) ty + +let app_a f a = app f (Array.to_list a) + +let if_ a b c = + if a.ty <> Ty.Prop + then Ty.ill_typed "if: test must have type prop, not `@[%a@]`" Ty.pp a.ty; + if not (Ty.equal b.ty c.ty) + then Ty.ill_typed + "if: both branches must have same type,@ not `@[%a@]` and `@[%a@]`" + Ty.pp b.ty Ty.pp c.ty; + mk_ (If (a,b,c)) b.ty + +let match_ t m = + let c1, (_, rhs1) = ID.Map.choose m in + ID.Map.iter + (fun c (_, rhs) -> + if not (Ty.equal rhs1.ty rhs.ty) + then Ty.ill_typed + "match: cases %a and %a disagree on return type,@ \ + between %a and %a" + ID.pp c1 ID.pp c Ty.pp rhs1.ty Ty.pp rhs.ty) + m; + mk_ (Match (t,m)) rhs1.ty + +let switch u m = + try + let _, t1 = ID.Map.choose m in + mk_ (Switch (u,m)) t1.ty + with Not_found -> + invalid_arg "Ast.switch: empty list of cases" + +let let_ v t u = + if not (Ty.equal (Var.ty v) t.ty) + then Ty.ill_typed + "let: variable %a : @[%a@]@ and bounded term : %a@ should have same type" + Var.pp v Ty.pp (Var.ty v) Ty.pp t.ty; + mk_ (Let (v,t,u)) u.ty + +let bind ~ty b v t = mk_ (Bind(b,v,t)) ty + +let fun_ v t = + let ty = Ty.arrow (Var.ty v) t.ty in + mk_ (Bind (Fun,v,t)) ty + +let quant_ q v t = + if not (Ty.equal t.ty Ty.prop) then ( + Ty.ill_typed + "quantifier: bounded term : %a@ should have type prop" + Ty.pp t.ty; + ); + let ty = Ty.prop in + mk_ (q v t) ty + +let forall = quant_ (fun v t -> Bind (Forall,v,t)) +let exists = quant_ (fun v t -> Bind (Exists,v,t)) + +let mu v t = + if not (Ty.equal (Var.ty v) t.ty) + then Ty.ill_typed "mu-term: var has type %a,@ body %a" + Ty.pp (Var.ty v) Ty.pp t.ty; + let ty = Ty.arrow (Var.ty v) t.ty in + mk_ (Bind (Fun,v,t)) ty + +let fun_l = List.fold_right fun_ +let fun_a = Array.fold_right fun_ +let forall_l = List.fold_right forall +let exists_l = List.fold_right exists + +let eq a b = + if not (Ty.equal a.ty b.ty) + then Ty.ill_typed "eq: `@[%a@]` and `@[%a@]` do not have the same type" + pp_term a pp_term b; + mk_ (Binop (Eq,a,b)) Ty.prop + +let check_prop_ t = + if not (Ty.equal t.ty Ty.prop) + then Ty.ill_typed "expected prop, got `@[%a : %a@]`" pp_term t Ty.pp t.ty + +let binop op a b = mk_ (Binop (op, a, b)) Ty.prop +let binop_prop op a b = + check_prop_ a; check_prop_ b; + binop op a b + +let and_ = binop_prop And +let or_ = binop_prop Or +let imply = binop_prop Imply + +let and_l = function + | [] -> true_ + | [f] -> f + | a :: l -> List.fold_left and_ a l + +let or_l = function + | [] -> false_ + | [f] -> f + | a :: l -> List.fold_left or_ a l + +let not_ t = + check_prop_ t; + mk_ (Not t) Ty.prop + +(** {2 Environment} *) + +type env_entry = + | E_uninterpreted_ty + | E_uninterpreted_cst (* domain element *) + | E_const of Ty.t + | E_data of Ty.t ID.Map.t (* list of cstors *) + | E_cstor of Ty.t (* datatype it belongs to *) + | E_defined of Ty.t * term (* if defined *) + +type env = { + defs: env_entry ID.Map.t; +} +(** Environment with definitions and goals *) + +let env_empty = { + defs=ID.Map.empty; +} + +let add_def id def env = { defs=ID.Map.add id def env.defs} + +let env_add_statement env st = + match st with + | Data l -> + List.fold_left + (fun env {Ty.data_id; data_cstors} -> + let map = add_def data_id (E_data data_cstors) env in + ID.Map.fold + (fun c_id c_ty map -> add_def c_id (E_cstor c_ty) map) + data_cstors map) + env l + | TyDecl id -> add_def id E_uninterpreted_ty env + | Decl (id,ty) -> add_def id (E_const ty) env + | Define l -> + List.fold_left + (fun map (id,ty,def) -> add_def id (E_defined (ty,def)) map) + env l + | Goal _ + | Assert _ -> env + +let env_of_statements seq = + Sequence.fold env_add_statement env_empty seq + +let env_find_def env id = + try Some (ID.Map.find id env.defs) + with Not_found -> None + +let env_add_def env id def = add_def id def env diff --git a/src/smt/Ast.mli b/src/smt/Ast.mli new file mode 100644 index 00000000..2a4f578c --- /dev/null +++ b/src/smt/Ast.mli @@ -0,0 +1,186 @@ + +(* This file is free software. See file "license" for more details. *) + +(** {1 Preprocessing AST} *) + +type 'a or_error = ('a, string) CCResult.t + +(** {2 Types} *) + +exception Error of string +exception Ill_typed of string + +module Var : sig + type 'ty t = private { + id: ID.t; + ty: 'ty; + } + + val make : ID.t -> 'ty -> 'ty t + val copy : 'a t -> 'a t + val id : _ t -> ID.t + val ty : 'a t -> 'a + + val equal : 'a t -> 'a t -> bool + val compare : 'a t -> 'a t -> int + val pp : _ t CCFormat.printer +end + +module Ty : sig + type t = + | Prop + | Const of ID.t + | Arrow of t * t + + val prop : t + val const : ID.t -> t + val arrow : t -> t -> t + val arrow_l : t list -> t -> t + + include Intf.EQ with type t := t + include Intf.ORD with type t := t + include Intf.HASH with type t := t + include Intf.PRINT with type t := t + + val unfold : t -> t list * t + (** [unfold ty] will get the list of arguments, and the return type + of any function. An atomic type is just a function with no arguments *) + + (** {2 Datatypes} *) + + (** Mutually recursive datatypes *) + type data = { + data_id: ID.t; + data_cstors: t ID.Map.t; + } + + module Map : CCMap.S with type key = t + + (** {2 Error Handling} *) + + val ill_typed : ('a, Format.formatter, unit, 'b) format4 -> 'a +end + +type var = Ty.t Var.t + +type binop = + | And + | Or + | Imply + | Eq + +type binder = + | Fun + | Forall + | Exists + | Mu + +type term = private { + term: term_cell; + ty: Ty.t; +} +and term_cell = + | Var of var + | Const of ID.t + | Unknown of var + | App of term * term list + | If of term * term * term + | Select of select * term + | Match of term * (var list * term) ID.Map.t + | Switch of term * term ID.Map.t (* switch on constants *) + | Bind of binder * var * term + | Let of var * term * term + | Not of term + | Binop of binop * term * term + | Asserting of term * term + | Undefined_value + | Bool of bool + +and select = { + select_name: ID.t lazy_t; + select_cstor: ID.t; + select_i: int; +} + +(* TODO: records? *) + +type definition = ID.t * Ty.t * term + +type statement = + | Data of Ty.data list + | TyDecl of ID.t (* new atomic cstor *) + | Decl of ID.t * Ty.t + | Define of definition list + | Assert of term + | Goal of var list * term + +(** {2 Constructors} *) + +val term_view : term -> term_cell +val ty : term -> Ty.t + +val var : var -> term +val const : ID.t -> Ty.t -> term +val unknown : var -> term +val app : term -> term list -> term +val app_a : term -> term array -> term +val select : select -> term -> Ty.t -> term +val if_ : term -> term -> term -> term +val match_ : term -> (var list * term) ID.Map.t -> term +val switch : term -> term ID.Map.t -> term +val let_ : var -> term -> term -> term +val bind : ty:Ty.t -> binder -> var -> term -> term +val fun_ : var -> term -> term +val fun_l : var list -> term -> term +val fun_a : var array -> term -> term +val forall : var -> term -> term +val forall_l : var list -> term -> term +val exists : var -> term -> term +val exists_l : var list -> term -> term +val mu : var -> term -> term +val eq : term -> term -> term +val not_ : term -> term +val binop : binop -> term -> term -> term +val and_ : term -> term -> term +val and_l : term list -> term +val or_ : term -> term -> term +val or_l : term list -> term +val imply : term -> term -> term +val true_ : term +val false_ : term +val undefined_value : Ty.t -> term +val asserting : term -> term -> term + +val unfold_fun : term -> var list * term + +(** {2 Printing} *) + +val pp_ty : Ty.t CCFormat.printer +val pp_term : term CCFormat.printer +val pp_statement : statement CCFormat.printer + +(** {2 Environment} *) + +type env_entry = + | E_uninterpreted_ty + | E_uninterpreted_cst (* domain element *) + | E_const of Ty.t + | E_data of Ty.t ID.Map.t (* list of cstors *) + | E_cstor of Ty.t + | E_defined of Ty.t * term (* if defined *) + +type env = { + defs: env_entry ID.Map.t; +} +(** Environment with definitions and goals *) + +val env_empty : env + +val env_add_statement : env -> statement -> env + +val env_of_statements: statement Sequence.t -> env + +val env_find_def : env -> ID.t -> env_entry option + +val env_add_def : env -> ID.t -> env_entry -> env + diff --git a/src/smt/Clause.ml b/src/smt/Clause.ml new file mode 100644 index 00000000..79464d8e --- /dev/null +++ b/src/smt/Clause.ml @@ -0,0 +1,30 @@ + +open Solver_types + +type t = Lit.t list + +let lits c = c + +let pp out c = match lits c with + | [] -> Fmt.string out "false" + | [lit] -> Lit.pp out lit + | l -> + Format.fprintf out "[@[%a@]]" + (Util.pp_list ~sep:"; " Lit.pp) l + +(* canonical form: sorted list *) +let make = + fun l -> CCList.sort_uniq ~cmp:Lit.compare l + +let equal_ c1 c2 = CCList.equal Lit.equal (lits c1) (lits c2) +let hash_ c = Hash.list Lit.hash (lits c) + +module Tbl = CCHashtbl.Make(struct + type t_ = t + type t = t_ + let equal = equal_ + let hash = hash_ + end) + +let iter f c = List.iter f (lits c) +let to_seq c = Sequence.of_list (lits c) diff --git a/src/smt/Clause.mli b/src/smt/Clause.mli new file mode 100644 index 00000000..c2601141 --- /dev/null +++ b/src/smt/Clause.mli @@ -0,0 +1,12 @@ + +open Solver_types + +type t = Lit.t list + +val make : Lit.t list -> t +val lits : t -> Lit.t list +val iter : (Lit.t -> unit) -> t -> unit +val to_seq : t -> Lit.t Sequence.t +val pp : t Fmt.printer + +module Tbl : CCHashtbl.S with type key = t diff --git a/src/smt/Config.ml b/src/smt/Config.ml new file mode 100644 index 00000000..0e9b59ea --- /dev/null +++ b/src/smt/Config.ml @@ -0,0 +1,11 @@ + +(** {1 Configuration} *) + +type 'a sequence = ('a -> unit) -> unit + +module Key = Het_map.Key + +type pair = Het_map.pair = + | Pair : 'a Key.t * 'a -> pair + +include Het_map.Map diff --git a/src/smt/Config.mli b/src/smt/Config.mli new file mode 100644 index 00000000..2b8380c0 --- /dev/null +++ b/src/smt/Config.mli @@ -0,0 +1,48 @@ + +(** {1 Configuration} *) + +type 'a sequence = ('a -> unit) -> unit + +module Key : sig + type 'a t + + val create : unit -> 'a t + + val equal : 'a t -> 'a t -> bool + (** Compare two keys that have compatible types *) +end + +type t + +val empty : t + +val mem : _ Key.t -> t -> bool + +val add : 'a Key.t -> 'a -> t -> t + +val length : t -> int + +val cardinal : t -> int + +val find : 'a Key.t -> t -> 'a option + +val find_exn : 'a Key.t -> t -> 'a +(** @raise Not_found if the key is not in the table *) + +type pair = + | Pair : 'a Key.t * 'a -> pair + +val iter : (pair -> unit) -> t -> unit + +val to_seq : t -> pair sequence + +val of_seq : pair sequence -> t + +val add_seq : t -> pair sequence -> t + +val add_list : t -> pair list -> t + +val of_list : pair list -> t + +val to_list : t -> pair list + diff --git a/src/smt/Congruence_closure.ml b/src/smt/Congruence_closure.ml index 94019621..bcf5d956 100644 --- a/src/smt/Congruence_closure.ml +++ b/src/smt/Congruence_closure.ml @@ -17,13 +17,26 @@ module Sig_tbl = CCHashtbl.Make(Signature) type merge_op = node * node * explanation (* a merge operation to perform *) -type actions = - | Propagate of Lit.t * explanation list - | Split of Lit.t list * explanation list - | Merge of node * node (* merge these two classes *) +type actions = { + on_backtrack:(unit -> unit) -> unit; + (** Register a callback to be invoked upon backtracking below the current level *) + + at_lvl_0:unit -> bool; + (** Are we currently at backtracking level 0? *) + + on_merge:repr -> repr -> explanation -> unit; + (** Call this when two classes are merged *) + + raise_conflict: 'a. Explanation.t Bag.t -> 'a; + (** Report a conflict *) + + propagate: Lit.t -> Explanation.t Bag.t -> unit; + (** Propagate a literal *) +} type t = { tst: Term.state; + acts: actions; tbl: node Term.Tbl.t; (* internalization [term -> node] *) signatures_tbl : repr Sig_tbl.t; @@ -34,18 +47,10 @@ type t = { The critical property is that all members of an equivalence class that have the same "shape" (including head symbol) have the same signature *) - on_backtrack: (unit -> unit) -> unit; - (* register a function to be called when we backtrack *) - at_lvl_0: unit -> bool; - (* currently at level 0? *) - on_merge: (repr -> repr -> explanation -> unit) list; - (* callbacks to call when we merge classes *) pending: node Vec.t; (* nodes to check, maybe their new signature is in {!signatures_tbl} *) combine: merge_op Vec.t; (* pairs of terms to merge *) - mutable actions : actions list; - (* some boolean propagations/splits to make. *) mutable ps_lits: Lit.Set.t; (* proof state *) ps_queue: (node*node) Vec.t; @@ -79,8 +84,8 @@ let rec find_rec cc (n:node) : repr = let root = find_rec cc old_root in (* path compression *) if (root :> node) != old_root then ( - if not (cc.at_lvl_0 ()) then ( - cc.on_backtrack (fun () -> n.n_root <- old_root); + if not (cc.acts.at_lvl_0 ()) then ( + cc.acts.on_backtrack (fun () -> n.n_root <- old_root); ); n.n_root <- (root :> node); ); @@ -144,8 +149,8 @@ let add_signature cc (t:term) (r:repr): unit = match signature cc t with (* add, but only if not present already *) begin match Sig_tbl.get cc.signatures_tbl s with | None -> - if not (cc.at_lvl_0 ()) then ( - cc.on_backtrack + if not (cc.acts.at_lvl_0 ()) then ( + cc.acts.on_backtrack (fun () -> Sig_tbl.remove cc.signatures_tbl s); ); Sig_tbl.add cc.signatures_tbl s r; @@ -167,19 +172,11 @@ let push_combine cc t u e : unit = Equiv_class.pp t Equiv_class.pp u Explanation.pp e); Vec.push cc.combine (t,u,e) -let push_split cc (lits:lit list) (expl:explanation list): unit = - Log.debugf 5 - (fun k->k "(@[push_split@ (@[%a@])@ expl: (@[%a@])@])" - (Util.pp_list Lit.pp) lits (Util.pp_list Explanation.pp) expl); - let l = Split (lits, expl) in - cc.actions <- l :: cc.actions - -let push_propagation cc (lit:lit) (expl:explanation list): unit = +let push_propagation cc (lit:lit) (expl:explanation Bag.t): unit = Log.debugf 5 (fun k->k "(@[push_propagate@ %a@ expl: (@[%a@])@])" - Lit.pp lit (Util.pp_list Explanation.pp) expl); - let l = Propagate (lit,expl) in - cc.actions <- l :: cc.actions + Lit.pp lit (Util.pp_seq Explanation.pp) @@ Bag.to_seq expl); + cc.acts.propagate lit expl let[@inline] union cc (a:node) (b:node) (e:explanation): unit = if not (same_class cc a b) then ( @@ -189,10 +186,10 @@ let[@inline] union cc (a:node) (b:node) (e:explanation): unit = (* re-root the explanation tree of the equivalence class of [n] so that it points to [n]. postcondition: [n.n_expl = None] *) -let rec reroot_expl cc (n:node): unit = +let rec reroot_expl (cc:t) (n:node): unit = let old_expl = n.n_expl in - if not (cc.at_lvl_0 ()) then ( - cc.on_backtrack (fun () -> n.n_expl <- old_expl); + if not (cc.acts.at_lvl_0 ()) then ( + cc.acts.on_backtrack (fun () -> n.n_expl <- old_expl); ); begin match old_expl with | E_none -> () (* already root *) @@ -202,19 +199,8 @@ let rec reroot_expl cc (n:node): unit = n.n_expl <- E_none; end -(* TODO: - - move what follows into {!Theory}. - - also, obtain merges of CC via callbacks / [pop_merges] afterwards? - *) - -exception Exn_unsat of explanation Bag.t - -let unsat (e:explanation Bag.t): _ = raise (Exn_unsat e) - -type result = - | Sat of actions list - | Unsat of explanation Bag.t - (* list of direct explanations to the conflict. *) +let[@inline] raise_conflict (cc:t) (e:explanation Bag.t): _ = + cc.acts.raise_conflict e let[@inline] all_classes cc : repr Sequence.t = Term.Tbl.values cc.tbl @@ -222,7 +208,7 @@ let[@inline] all_classes cc : repr Sequence.t = (* main CC algo: add terms from [pending] to the signature table, check for collisions *) -let rec update_pending (cc:t): result = +let rec update_pending (cc:t): unit = (* step 2 deal with pending (parent) terms whose equiv class might have changed *) while not (Vec.is_empty cc.pending) do @@ -240,11 +226,7 @@ let rec update_pending (cc:t): result = eval_pending cc; *) done; - if is_done cc then ( - let actions = cc.actions in - cc.actions <- []; - Sat actions - ) else ( + if not (is_done cc) then ( update_combine cc (* repeat *) ) @@ -285,11 +267,12 @@ and update_combine cc = Term.pp t_a Term.pp t_b (Util.pp_list @@ Util.pp_pair Equiv_class.pp Term.pp) l); List.iter (fun (u1,u2) -> push_combine cc u1 (add cc u2) e_ab) l - | Solve_fail {expl} -> + | Solve_fail {expl} -> Log.debugf 5 (fun k->k "(@[solve-fail@ (@[= %a %a@])@ :expl %a@])" Term.pp t_a Term.pp t_b Explanation.pp expl); - raise (Exn_unsat (Bag.return expl)) + + raise_conflict cc (Bag.return expl) end | _ -> assert false ); @@ -310,7 +293,7 @@ and update_combine cc = let r_into = (r_into :> node) in let rb_old_class = r_into.n_class in let rb_old_parents = r_into.n_parents in - cc.on_backtrack + cc.acts.on_backtrack (fun () -> r_from.n_root <- r_from; r_into.n_class <- rb_old_class; @@ -323,8 +306,8 @@ and update_combine cc = begin reroot_expl cc a; assert (a.n_expl = E_none); - if not (cc.at_lvl_0 ()) then ( - cc.on_backtrack (fun () -> a.n_expl <- E_none); + if not (cc.acts.at_lvl_0 ()) then ( + cc.acts.on_backtrack (fun () -> a.n_expl <- E_none); ); a.n_expl <- E_some {next=b; expl=e_ab}; end; @@ -341,10 +324,7 @@ and update_combine cc = and notify_merge cc (ra:repr) ~into:(rb:repr) (e:explanation): unit = assert (is_root_ (ra:>node)); assert (is_root_ (rb:>node)); - List.iter - (fun f -> f ra rb e) - cc.on_merge; - () + cc.acts.on_merge ra rb e (* FIXME: callback? @@ -371,8 +351,8 @@ and add_new_term cc (t:term) : node = (* how to add a subterm *) let add_to_parents_of_sub_node (sub:node) : unit = let old_parents = sub.n_parents in - if not @@ cc.at_lvl_0 () then ( - cc.on_backtrack (fun () -> sub.n_parents <- old_parents); + if not @@ cc.acts.at_lvl_0 () then ( + cc.acts.on_backtrack (fun () -> sub.n_parents <- old_parents); ); sub.n_parents <- Bag.cons n sub.n_parents; push_pending cc sub @@ -395,8 +375,8 @@ and add_new_term cc (t:term) : node = | Custom {view;tc} -> tc.tc_t_sub view add_sub_t end; (* remove term when we backtrack *) - if not (cc.at_lvl_0 ()) then ( - cc.on_backtrack (fun () -> Term.Tbl.remove cc.tbl t); + if not (cc.acts.at_lvl_0 ()) then ( + cc.acts.on_backtrack (fun () -> Term.Tbl.remove cc.tbl t); ); (* add term to the table *) Term.Tbl.add cc.tbl t n; @@ -430,19 +410,16 @@ let assert_lit cc lit : unit = match Lit.view lit with push_combine cc n rhs (E_lit lit); () -let create ?(size=2048) ~on_backtrack ~at_lvl_0 ~on_merge (tst:Term.state) : t = - assert (at_lvl_0 ()); +let create ?(size=2048) ~actions (tst:Term.state) : t = + assert (actions.at_lvl_0 ()); let nd = Equiv_class.dummy in let rec cc = { tst; + acts=actions; tbl = Term.Tbl.create size; - on_merge; signatures_tbl = Sig_tbl.create size; - on_backtrack; - at_lvl_0; pending=Vec.make_empty Equiv_class.dummy; combine= Vec.make_empty (nd,nd,E_reduce_eq(nd,nd)); - actions=[]; ps_lits=Lit.Set.empty; ps_queue=Vec.make_empty (nd,nd); true_ = lazy (add cc (Term.true_ tst)); @@ -557,24 +534,20 @@ let explain_loop (cc : t) : Lit.Set.t = done; cc.ps_lits -let explain_unfold cc (l:explanation list): Lit.Set.t = +let explain_unfold cc (seq:explanation Sequence.t): Lit.Set.t = Log.debugf 5 (fun k->k "(@[explain_confict@ (@[%a@])@])" - (Util.pp_list Explanation.pp) l); + (Util.pp_seq Explanation.pp) seq); ps_clear cc; - List.iter (decompose_explain cc) l; + Sequence.iter (decompose_explain cc) seq; explain_loop cc -let check_ cc = - try update_pending cc - with Exn_unsat e -> - Unsat e - (* check satisfiability, update congruence closure *) -let check (cc:t) : result = +let check (cc:t) : unit = Log.debug 5 "(cc.check)"; - check_ cc + update_pending cc -let final_check cc : result = +let final_check cc : unit = Log.debug 5 "(CC.final_check)"; - check_ cc + update_pending cc + diff --git a/src/smt/Congruence_closure.mli b/src/smt/Congruence_closure.mli index 90a4f736..79b3abe7 100644 --- a/src/smt/Congruence_closure.mli +++ b/src/smt/Congruence_closure.mli @@ -1,6 +1,5 @@ (** {2 Congruence Closure} *) -open CDCL open Solver_types type t @@ -12,16 +11,30 @@ type node = Equiv_class.t type repr = Equiv_class.t (** Node that is currently a representative *) +type actions = { + on_backtrack:(unit -> unit) -> unit; + (** Register a callback to be invoked upon backtracking below the current level *) + + at_lvl_0:unit -> bool; + (** Are we currently at backtracking level 0? *) + + on_merge:repr -> repr -> explanation -> unit; + (** Call this when two classes are merged *) + + raise_conflict: 'a. Explanation.t Bag.t -> 'a; + (** Report a conflict *) + + propagate: Lit.t -> Explanation.t Bag.t -> unit; + (** Propagate a literal *) +} + val create : ?size:int -> - on_backtrack:((unit -> unit) -> unit) -> - at_lvl_0:(unit -> bool) -> - on_merge:(repr -> repr -> explanation -> unit) list -> + actions:actions -> Term.state -> t (** Create a new congruence closure. - @param on_backtrack used to register undo actions - @param on_merge callbacks called when two equiv classes are merged + @param acts the actions available to the congruence closure *) val find : t -> node -> repr @@ -47,20 +60,13 @@ val add : t -> term -> node val add_seq : t -> term Sequence.t -> unit (** Add a sequence of terms to the congruence closure *) -type actions = - | Propagate of Lit.t * explanation list - | Split of Lit.t list * explanation list - | Merge of node * node (* merge these two classes *) +val all_classes : t -> repr Sequence.t +(** All current classes *) -type result = - | Sat of actions list - | Unsat of explanation Bag.t - (* list of direct explanations to the conflict. *) +val check : t -> unit -val check : t -> result +val final_check : t -> unit -val final_check : t -> result - -val explain_unfold: t -> explanation list -> Lit.Set.t +val explain_unfold: t -> explanation Sequence.t -> Lit.Set.t (** Unfold those explanations into a complete set of literals implying them *) diff --git a/src/smt/Cst.ml b/src/smt/Cst.ml index 5f1f3856..12b0c669 100644 --- a/src/smt/Cst.ml +++ b/src/smt/Cst.ml @@ -1,5 +1,4 @@ -open CDCL open Solver_types type t = cst diff --git a/src/smt/Cst.mli b/src/smt/Cst.mli index b0e234f3..2e68beb6 100644 --- a/src/smt/Cst.mli +++ b/src/smt/Cst.mli @@ -1,5 +1,4 @@ -open CDCL open Solver_types type t = cst diff --git a/src/smt/Het_map.ml b/src/smt/Het_map.ml new file mode 100644 index 00000000..183a296f --- /dev/null +++ b/src/smt/Het_map.ml @@ -0,0 +1,191 @@ + +(* This file is free software, part of containers. See file "license" for more details. *) + +(** {1 Associative containers with Heterogenerous Values} *) + +(*$R + let k1 : int Key.t = Key.create() in + let k2 : int Key.t = Key.create() in + let k3 : string Key.t = Key.create() in + let k4 : float Key.t = Key.create() in + + let tbl = Tbl.create () in + + Tbl.add tbl k1 1; + Tbl.add tbl k2 2; + Tbl.add tbl k3 "k3"; + + assert_equal (Some 1) (Tbl.find tbl k1); + assert_equal (Some 2) (Tbl.find tbl k2); + assert_equal (Some "k3") (Tbl.find tbl k3); + assert_equal None (Tbl.find tbl k4); + assert_equal 3 (Tbl.length tbl); + + Tbl.add tbl k1 10; + assert_equal (Some 10) (Tbl.find tbl k1); + assert_equal 3 (Tbl.length tbl); + assert_equal None (Tbl.find tbl k4); + + Tbl.add tbl k4 0.0; + assert_equal (Some 0.0) (Tbl.find tbl k4); + + () + + +*) + +type 'a sequence = ('a -> unit) -> unit +type 'a gen = unit -> 'a option + +module type KEY_IMPL = sig + type t + exception Store of t + val id : int +end + +module Key = struct + type 'a t = (module KEY_IMPL with type t = 'a) + + let _n = ref 0 + + let create (type k) () = + incr _n; + let id = !_n in + let module K = struct + type t = k + let id = id + exception Store of k + end in + (module K : KEY_IMPL with type t = k) + + let id (type k) (module K : KEY_IMPL with type t = k) = K.id + + let equal + : type a b. a t -> b t -> bool + = fun (module K1) (module K2) -> K1.id = K2.id +end + +type pair = + | Pair : 'a Key.t * 'a -> pair + +type exn_pair = + | E_pair : 'a Key.t * exn -> exn_pair + +let pair_of_e_pair (E_pair (k,e)) = + let module K = (val k) in + match e with + | K.Store v -> Pair (k,v) + | _ -> assert false + +module Tbl = struct + module M = Hashtbl.Make(struct + type t = int + let equal (i:int) j = i=j + let hash (i:int) = Hashtbl.hash i + end) + + type t = exn_pair M.t + + let create ?(size=16) () = M.create size + + let mem t k = M.mem t (Key.id k) + + let find_exn (type a) t (k : a Key.t) : a = + let module K = (val k) in + let E_pair (_, v) = M.find t K.id in + match v with + | K.Store v -> v + | _ -> assert false + + let find t k = + try Some (find_exn t k) + with Not_found -> None + + let add_pair_ t p = + let Pair (k,v) = p in + let module K = (val k) in + let p = E_pair (k, K.Store v) in + M.replace t K.id p + + let add t k v = add_pair_ t (Pair (k,v)) + + let length t = M.length t + + let iter f t = M.iter (fun _ pair -> f (pair_of_e_pair pair)) t + + let to_seq t yield = iter yield t + + let to_list t = M.fold (fun _ p l -> pair_of_e_pair p::l) t [] + + let add_list t l = List.iter (add_pair_ t) l + + let add_seq t seq = seq (add_pair_ t) + + let of_list l = + let t = create() in + add_list t l; + t + + let of_seq seq = + let t = create() in + add_seq t seq; + t +end + +module Map = struct + module M = Map.Make(struct + type t = int + let compare (i:int) j = Pervasives.compare i j + end) + + type t = exn_pair M.t + + let empty = M.empty + + let mem k t = M.mem (Key.id k) t + + let find_exn (type a) (k : a Key.t) t : a = + let module K = (val k) in + let E_pair (_, e) = M.find K.id t in + match e with + | K.Store v -> v + | _ -> assert false + + let find k t = + try Some (find_exn k t) + with Not_found -> None + + let add_e_pair_ p t = + let E_pair ((module K),_) = p in + M.add K.id p t + + let add_pair_ p t = + let Pair ((module K) as k,v) = p in + let p = E_pair (k, K.Store v) in + M.add K.id p t + + let add (type a) (k : a Key.t) v t = + let module K = (val k) in + add_e_pair_ (E_pair (k, K.Store v)) t + + let cardinal t = M.cardinal t + + let length = cardinal + + let iter f t = M.iter (fun _ p -> f (pair_of_e_pair p)) t + + let to_seq t yield = iter yield t + + let to_list t = M.fold (fun _ p l -> pair_of_e_pair p::l) t [] + + let add_list t l = List.fold_right add_pair_ l t + + let add_seq t seq = + let t = ref t in + seq (fun pair -> t := add_pair_ pair !t); + !t + + let of_list l = add_list empty l + + let of_seq seq = add_seq empty seq +end diff --git a/src/smt/Het_map.mli b/src/smt/Het_map.mli new file mode 100644 index 00000000..4b876906 --- /dev/null +++ b/src/smt/Het_map.mli @@ -0,0 +1,85 @@ + +(* This file is free software, part of containers. See file "license" for more details. *) + +(** {1 Associative containers with Heterogenerous Values} *) + +type 'a sequence = ('a -> unit) -> unit +type 'a gen = unit -> 'a option + +module Key : sig + type 'a t + + val create : unit -> 'a t + + val equal : 'a t -> 'a t -> bool + (** Compare two keys that have compatible types *) +end + +type pair = + | Pair : 'a Key.t * 'a -> pair + +(** {2 Imperative table indexed by {!Key}} *) +module Tbl : sig + type t + + val create : ?size:int -> unit -> t + + val mem : t -> _ Key.t -> bool + + val add : t -> 'a Key.t -> 'a -> unit + + val length : t -> int + + val find : t -> 'a Key.t -> 'a option + + val find_exn : t -> 'a Key.t -> 'a + (** @raise Not_found if the key is not in the table *) + + val iter : (pair -> unit) -> t -> unit + + val to_seq : t -> pair sequence + + val of_seq : pair sequence -> t + + val add_seq : t -> pair sequence -> unit + + val add_list : t -> pair list -> unit + + val of_list : pair list -> t + + val to_list : t -> pair list +end + +(** {2 Immutable map} *) +module Map : sig + type t + + val empty : t + + val mem : _ Key.t -> t -> bool + + val add : 'a Key.t -> 'a -> t -> t + + val length : t -> int + + val cardinal : t -> int + + val find : 'a Key.t -> t -> 'a option + + val find_exn : 'a Key.t -> t -> 'a + (** @raise Not_found if the key is not in the table *) + + val iter : (pair -> unit) -> t -> unit + + val to_seq : t -> pair sequence + + val of_seq : pair sequence -> t + + val add_seq : t -> pair sequence -> t + + val add_list : t -> pair list -> t + + val of_list : pair list -> t + + val to_list : t -> pair list +end diff --git a/src/smt/Lit.ml b/src/smt/Lit.ml index 1930112d..964c455b 100644 --- a/src/smt/Lit.ml +++ b/src/smt/Lit.ml @@ -1,5 +1,4 @@ -open CDCL open Solver_types type t = lit diff --git a/src/smt/Lit.mli b/src/smt/Lit.mli index ba178daf..193b7ee3 100644 --- a/src/smt/Lit.mli +++ b/src/smt/Lit.mli @@ -1,6 +1,5 @@ (** {2 Literals} *) -open CDCL open Solver_types type t = lit diff --git a/src/smt/Model.ml b/src/smt/Model.ml new file mode 100644 index 00000000..b6e42f7d --- /dev/null +++ b/src/smt/Model.ml @@ -0,0 +1,370 @@ + +(* This file is free software. See file "license" for more details. *) + +(** {1 Model} *) + +open CDCL + +module A = Ast + +type term = A.term +type ty = A.Ty.t +type domain = ID.t list + +type t = { + env: A.env; + (* environment, defining symbols *) + domains: domain A.Ty.Map.t; + (* uninterpreted type -> its domain *) + consts: term ID.Map.t; + (* constant -> its value *) +} + +let make ~env ~consts ~domains = + (* also add domains to [env] *) + let env = + A.Ty.Map.to_seq domains + |> Sequence.flat_map_l (fun (ty,l) -> List.map (CCPair.make ty) l) + |> Sequence.fold + (fun env (_,cst) -> A.env_add_def env cst A.E_uninterpreted_cst) + env + in + {env; consts; domains} + +type entry = + | E_ty of ty * domain + | E_const of ID.t * term + +let pp out (m:t) = + let pp_cst_name out c = ID.pp_name out c in + let pp_ty = A.Ty.pp in + let pp_term = A.pp_term in + let pp_entry out = function + | E_ty (ty,l) -> + Format.fprintf out "(@[<1>type@ %a@ (@[%a@])@])" + pp_ty ty (Util.pp_list pp_cst_name) l + | E_const (c,t) -> + Format.fprintf out "(@[<1>val@ %a@ %a@])" + ID.pp_name c pp_term t + in + let es = + CCList.append + (A.Ty.Map.to_list m.domains |> List.map (fun (ty,dom) -> E_ty (ty,dom))) + (ID.Map.to_list m.consts |> List.map (fun (c,t) -> E_const (c,t))) + in + Format.fprintf out "(@[%a@])" (Util.pp_list pp_entry) es + +exception Bad_model of t * term * term +exception Error of string + +let () = Printexc.register_printer + (function + | Error msg -> Some ("internal error: " ^ msg) + | Bad_model (m,t,t') -> + let msg = CCFormat.sprintf + "@[Bad model:@ goal `@[%a@]`@ evaluates to `@[%a@]`,@ \ + not true,@ in model @[%a@]@." + A.pp_term t A.pp_term t' pp m + in + Some msg + | _ -> None) + +let errorf msg = CCFormat.ksprintf msg ~f:(fun s -> raise (Error s)) + +module VarMap = CCMap.Make(struct + type t = A.Ty.t A.Var.t + let compare = A.Var.compare + end) + +(* var -> term in normal form *) +type subst = A.term lazy_t VarMap.t + +let empty_subst : subst = VarMap.empty + +let rename_var subst v = + let v' = A.Var.copy v in + VarMap.add v (Lazy.from_val (A.var v')) subst, v' + +let rename_vars = CCList.fold_map rename_var + +let pp_subst out (s:subst) = + let pp_pair out (v,lazy t) = + Format.fprintf out "@[<2>%a@ @<1>→ %a@]" A.Var.pp v A.pp_term t + in + Format.fprintf out "[@[%a@]]" + CCFormat.(list ~sep:(return ",@ ") pp_pair) (VarMap.to_list s |> List.rev) + +let rec as_cstor_app env t = match A.term_view t with + | A.Const id -> + begin match A.env_find_def env id with + | Some (A.E_cstor ty) -> Some (id, ty, []) + | _ -> None + end + | A.App (f, l) -> + CCOpt.map (fun (id,ty,l') -> id,ty,l'@l) (as_cstor_app env f) + | _ -> None + +let as_domain_elt env t = match A.term_view t with + | A.Const id -> + begin match A.env_find_def env id with + | Some A.E_uninterpreted_cst -> Some id + | _ -> None + end + | _ -> None + +let pp_stack out (l:term list) : unit = + let ppt out t = Format.fprintf out "(@[%a@ :ty %a@])" A.pp_term t A.Ty.pp t.A.ty in + CCFormat.(within "[" "]" (hvbox (list ppt))) out l + +let apply_subst (subst:subst) t = + let rec aux subst t = match A.term_view t with + | A.Var v -> + begin match VarMap.get v subst with + | None -> t + | Some (lazy t') -> t' + end + | A.Undefined_value + | A.Bool _ | A.Const _ | A.Unknown _ -> t + | A.Select (sel, t) -> A.select sel (aux subst t) t.A.ty + | A.App (f,l) -> A.app (aux subst f) (List.map (aux subst) l) + | A.If (a,b,c) -> A.if_ (aux subst a) (aux subst b) (aux subst c) + | A.Match (u,m) -> + A.match_ (aux subst u) + (ID.Map.map + (fun (vars,rhs) -> + let subst, vars = rename_vars subst vars in + vars, aux subst rhs) m) + | A.Switch (u,m) -> + A.switch (aux subst u) (ID.Map.map (aux subst) m) + | A.Let (x,t,u) -> + let subst', x' = rename_var subst x in + A.let_ x' (aux subst t) (aux subst' u) + | A.Bind (A.Mu, _,_) -> assert false + | A.Bind (b, x,body) -> + let subst', x' = rename_var subst x in + A.bind ~ty:(A.ty t) b x' (aux subst' body) + | A.Not f -> A.not_ (aux subst f) + | A.Binop (op,a,b) -> A.binop op (aux subst a)(aux subst b) + | A.Asserting (t,g) -> + A.asserting (aux subst t)(aux subst g) + in + if VarMap.is_empty subst then t else aux subst t + +(* Weak Head Normal Form. + @param m the model + @param st the "stack trace" (terms around currently being evaluated) + @param t the term to eval *) +let rec eval_whnf (m:t) (st:term list) (subst:subst) (t:term): term = + Log.debugf 5 + (fun k->k "%s@[<2>eval_whnf `@[%a@]`@ in @[%a@]@]" + (String.make (List.length st) ' ') (* indent *) + A.pp_term t pp_subst subst); + let st = t :: st in + try + eval_whnf_rec m st subst t + with A.Ill_typed msg -> + errorf "@[<2>Model:@ internal type error `%s`@ in %a@]" msg pp_stack st +and eval_whnf_rec m st subst t = match A.term_view t with + | A.Undefined_value | A.Bool _ | A.Unknown _ -> t + | A.Var v -> + begin match VarMap.get v subst with + | None -> t + | Some (lazy t') -> + eval_whnf m st empty_subst t' + end + | A.Const c -> + begin match A.env_find_def m.env c with + | Some (A.E_defined (_, t')) -> eval_whnf m st empty_subst t' + | _ -> + begin match ID.Map.get c m.consts with + | None -> t + | Some {A.term=A.Const c';_} when (ID.equal c c') -> t (* trivial cycle *) + | Some t' -> eval_whnf m st empty_subst t' + end + end + | A.App (f,l) -> eval_whnf_app m st subst subst f l + | A.If (a,b,c) -> + let a = eval_whnf m st subst a in + begin match A.term_view a with + | A.Bool true -> eval_whnf m st subst b + | A.Bool false -> eval_whnf m st subst c + | _ -> + let b = apply_subst subst b in + let c = apply_subst subst c in + A.if_ a b c + end + | A.Bind (A.Mu,v,body) -> + let subst' = VarMap.add v (lazy t) subst in + eval_whnf m st subst' body + | A.Let (x,t,u) -> + let t = lazy (eval_whnf m st subst t) in + let subst' = VarMap.add x t subst in + eval_whnf m st subst' u + | A.Bind (A.Fun,_,_) -> apply_subst subst t + | A.Bind ((A.Forall | A.Exists) as b,v,body) -> + let ty = A.Var.ty v in + let dom = + try A.Ty.Map.find ty m.domains + with Not_found -> + errorf "@[<2>could not find type %a in model@ stack %a@]" + A.Ty.pp ty pp_stack st + in + (* expand into and/or over the domain *) + let t' = + let l = + List.map + (fun c_dom -> + let subst' = VarMap.add v (lazy (A.const c_dom ty)) subst in + eval_whnf m st subst' body) + dom + in + begin match b with + | A.Forall -> A.and_l l + | A.Exists -> A.or_l l + | _ -> assert false + end + in + eval_whnf m st subst t' + | A.Select (sel, u) -> + let u = eval_whnf m st subst u in + let t' = A.select sel u t.A.ty in + begin match as_cstor_app m.env u with + | None -> t' + | Some (cstor, _, args) -> + if ID.equal cstor sel.A.select_cstor then ( + (* cstors match, take the argument *) + assert (List.length args > sel.A.select_i); + let new_t = List.nth args sel.A.select_i in + eval_whnf m st subst new_t + ) else ( + A.undefined_value t.A.ty + ) + end + | A.Match (u, branches) -> + let u = eval_whnf m st subst u in + begin match as_cstor_app m.env u with + | None -> + let branches = + ID.Map.map + (fun (vars,rhs) -> + let subst, vars = rename_vars subst vars in + vars, apply_subst subst rhs) + branches + in + A.match_ u branches + | Some (c, _, cstor_args) -> + match ID.Map.get c branches with + | None -> assert false + | Some (vars, rhs) -> + assert (List.length vars = List.length cstor_args); + let subst' = + List.fold_left2 + (fun s v arg -> + let arg' = lazy (apply_subst subst arg) in + VarMap.add v arg' s) + subst vars cstor_args + in + eval_whnf m st subst' rhs + end + | A.Switch (u, map) -> + let u = eval_whnf m st subst u in + begin match as_domain_elt m.env u with + | None -> + let map = ID.Map.map (apply_subst subst) map in + A.switch u map + | Some cst -> + begin match ID.Map.get cst map with + | Some rhs -> eval_whnf m st subst rhs + | None -> + let map = ID.Map.map (apply_subst subst) map in + A.switch u map + end + end + | A.Not f -> + let f = eval_whnf m st subst f in + begin match A.term_view f with + | A.Bool true -> A.false_ + | A.Bool false -> A.true_ + | _ -> A.not_ f + end + | A.Asserting (u, g) -> + let g' = eval_whnf m st subst g in + begin match A.term_view g' with + | A.Bool true -> eval_whnf m st subst u + | A.Bool false -> + A.undefined_value u.A.ty (* assertion failed, uncharted territory! *) + | _ -> A.asserting u g' + end + | A.Binop (op, a, b) -> + let a = eval_whnf m st subst a in + let b = eval_whnf m st subst b in + begin match op with + | A.And -> + begin match A.term_view a, A.term_view b with + | A.Bool true, A.Bool true -> A.true_ + | A.Bool false, _ + | _, A.Bool false -> A.false_ + | _ -> A.and_ a b + end + | A.Or -> + begin match A.term_view a, A.term_view b with + | A.Bool true, _ + | _, A.Bool true -> A.true_ + | A.Bool false, A.Bool false -> A.false_ + | _ -> A.or_ a b + end + | A.Imply -> + begin match A.term_view a, A.term_view b with + | _, A.Bool true + | A.Bool false, _ -> A.true_ + | A.Bool true, A.Bool false -> A.false_ + | _ -> A.imply a b + end + | A.Eq -> + begin match A.term_view a, A.term_view b with + | A.Bool true, A.Bool true + | A.Bool false, A.Bool false -> A.true_ + | A.Bool true, A.Bool false + | A.Bool false, A.Bool true -> A.false_ + | A.Var v1, A.Var v2 when A.Var.equal v1 v2 -> A.true_ + | A.Const id1, A.Const id2 when ID.equal id1 id2 -> A.true_ + | _ -> + begin match as_cstor_app m.env a, as_cstor_app m.env b with + | Some (c1,_,l1), Some (c2,_,l2) -> + if ID.equal c1 c2 then ( + assert (List.length l1 = List.length l2); + eval_whnf m st subst (A.and_l (List.map2 A.eq l1 l2)) + ) else A.false_ + | _ -> + begin match as_domain_elt m.env a, as_domain_elt m.env b with + | Some c1, Some c2 -> + (* domain elements: they are all distinct *) + if ID.equal c1 c2 + then A.true_ + else A.false_ + | _ -> + A.eq a b + end + end + end + end +(* beta-reduce [f l] while [f] is a function,constant or variable *) +and eval_whnf_app m st subst_f subst_l f l = match A.term_view f, l with + | A.Bind (A.Fun,v, body), arg :: tail -> + let subst_f = VarMap.add v (lazy (apply_subst subst_l arg)) subst_f in + eval_whnf_app m st subst_f subst_l body tail + | _ -> eval_whnf_app' m st subst_f subst_l f l +(* evaluate [f] and try to beta-reduce if [eval_whnf m f] is a function *) +and eval_whnf_app' m st subst_f subst_l f l = + let f' = eval_whnf m st subst_f f in + begin match A.term_view f', l with + | A.Bind (A.Fun,_,_), _::_ -> + eval_whnf_app m st subst_l subst_l f' l (* beta-reduce again *) + | _ -> + (* blocked *) + let l = List.map (apply_subst subst_l) l in + A.app f' l + end + +(* eval term [t] under model [m] *) +let eval (m:t) (t:term) = eval_whnf m [] empty_subst t diff --git a/src/smt/Model.mli b/src/smt/Model.mli new file mode 100644 index 00000000..52b1f7ba --- /dev/null +++ b/src/smt/Model.mli @@ -0,0 +1,29 @@ + +(* This file is free software. See file "license" for more details. *) + +(** {1 Model} *) + +type term = Ast.term +type ty = Ast.Ty.t +type domain = ID.t list + +type t = private { + env: Ast.env; + (* environment, defining symbols *) + domains: domain Ast.Ty.Map.t; + (* uninterpreted type -> its domain *) + consts: term ID.Map.t; + (* constant -> its value *) +} + +val make : + env:Ast.env -> + consts:term ID.Map.t -> + domains:domain Ast.Ty.Map.t -> + t + +val pp : t CCFormat.printer + +val eval : t -> term -> term + +exception Bad_model of t * term * term diff --git a/src/smt/Solver.ml b/src/smt/Solver.ml new file mode 100644 index 00000000..bc11187d --- /dev/null +++ b/src/smt/Solver.ml @@ -0,0 +1,840 @@ + +(* This file is free software. See file "license" for more details. *) + +(** {1 Main Solver} *) + +open Solver_types + +type term = Term.t +type cst = Cst.t +type ty = Ty.t +type ty_def = Solver_types.ty_def + +type ty_cell = Solver_types.ty_cell = + | Prop + | Atomic of ID.t * ty_def + | Arrow of ty * ty + +let get_time : unit -> float = Sys.time + +(** {2 The Main Solver} *) + +type level = int + +module Sat = CDCL.Make(Theory_combine) + +(* main solver state *) +type t = { + solver: Sat.t; + stat: Stat.t; + config: Config.t +} + +let th_combine (self:t) : Theory_combine.t = + Sat.theory self.solver + +let create ?size ?(config=Config.empty) ~theories () : t = + let self = { + solver=Sat.create ?size (); + stat=Stat.create (); + config; + } in + (* now add the theories *) + Theory_combine.add_theory_l (th_combine self) theories; + self + +(** {2 Sat Solver} *) + +let print_progress (st:t) : unit = + Printf.printf "\r[%.2f] expanded %d | clauses %d | lemmas %d%!" + (get_time()) + st.stat.Stat.num_cst_expanded + st.stat.Stat.num_clause_push + st.stat.Stat.num_clause_tautology + +let flush_progress (): unit = + Printf.printf "\r%-80d\r%!" 0 + +(** {2 Toplevel Goals} + + List of toplevel goals to satisfy. Mainly used for checking purpose +*) + +module Top_goals: sig + val push : term -> unit + val to_seq : term Sequence.t + val check: unit -> unit +end = struct + (* list of terms to fully evaluate *) + let toplevel_goals_ : term list ref = ref [] + + (* add [t] to the set of terms that must be evaluated *) + let push (t:term): unit = + toplevel_goals_ := t :: !toplevel_goals_; + () + + let to_seq k = List.iter k !toplevel_goals_ + + (* FIXME + (* check that this term fully evaluates to [true] *) + let is_true_ (t:term): bool = match CC.normal_form t with + | None -> false + | Some (NF_bool b) -> b + | Some (NF_cstor _) -> assert false (* not a bool *) + + let check () = + if not (List.for_all is_true_ !toplevel_goals_) + then ( + if Config.progress then flush_progress(); + Log.debugf 1 + (fun k-> + let pp_lit out t = + let nf = CC.normal_form t in + Format.fprintf out "(@[term: %a@ nf: %a@])" + Term.pp t (Fmt.opt pp_term_nf) nf + in + k "(@[Top_goals.check@ (@[%a@])@])" + (Util.pp_list pp_lit) !toplevel_goals_); + assert false; + ) + *) + + let check () : unit = () +end + +(** {2 Conversion} *) + +(* list of constants we are interested in *) +let model_support_ : Cst.t list ref = ref [] + +let model_env_ : Ast.env ref = ref Ast.env_empty + +let add_cst_support_ (c:cst): unit = + CCList.Ref.push model_support_ c + +let add_ty_support_ (_ty:Ty.t): unit = () + +(* FIXME: do this in another module, perhaps? +module Conv : sig + val add_statement : Ast.statement -> unit + val add_statement_l : Ast.statement list -> unit + val ty_to_ast: Ty.t -> Ast.Ty.t + val term_to_ast: term -> Ast.term +end = struct + (* for converting Ast.Ty into Ty *) + let ty_tbl_ : Ty.t lazy_t ID.Tbl.t = ID.Tbl.create 16 + + (* for converting constants *) + let decl_ty_ : cst lazy_t ID.Tbl.t = ID.Tbl.create 16 + + (* environment for variables *) + type conv_env = { + let_bound: (term * int) ID.Map.t; + (* let-bound variables, to be replaced. int=depth at binding position *) + bound: (int * Ty.t) ID.Map.t; + (* set of bound variables. int=depth at binding position *) + depth: int; + } + + let empty_env : conv_env = + {let_bound=ID.Map.empty; bound=ID.Map.empty; depth=0} + + let rec conv_ty (ty:Ast.Ty.t): Ty.t = match ty with + | Ast.Ty.Prop -> Ty.prop + | Ast.Ty.Const id -> + begin try ID.Tbl.find ty_tbl_ id |> Lazy.force + with Not_found -> Util.errorf "type %a not in ty_tbl" ID.pp id + end + | Ast.Ty.Arrow (a,b) -> Ty.arrow (conv_ty a) (conv_ty b) + + let add_bound env v = + let ty = Ast.Var.ty v |> conv_ty in + { env with + depth=env.depth+1; + bound=ID.Map.add (Ast.Var.id v) (env.depth,ty) env.bound; } + + (* add [v := t] to bindings. Depth is not incremented + (there will be no binders) *) + let add_let_bound env v t = + { env with + let_bound=ID.Map.add (Ast.Var.id v) (t,env.depth) env.let_bound } + + let find_env env v = + let id = Ast.Var.id v in + ID.Map.get id env.let_bound, ID.Map.get id env.bound + + let rec conv_term_rec + (env: conv_env) + (t:Ast.term): term = match Ast.term_view t with + | Ast.Bool true -> Term.true_ + | Ast.Bool false -> Term.false_ + | Ast.Unknown _ -> assert false + | Ast.Const id -> + begin + try ID.Tbl.find decl_ty_ id |> Lazy.force |> Term.const + with Not_found -> + errorf "could not find constant `%a`" ID.pp id + end + | Ast.App (f, l) -> + begin match Ast.term_view f with + | Ast.Const id -> + let f = + try ID.Tbl.find decl_ty_ id |> Lazy.force + with Not_found -> + errorf "could not find constant `%a`" ID.pp id + in + let l = List.map (conv_term_rec env) l in + if List.length l = fst (Ty.unfold_n (Cst.ty f)) + then Term.app_cst f (IArray.of_list l) (* fully applied *) + else Term.app (Term.const f) l + | _ -> + let f = conv_term_rec env f in + let l = List.map (conv_term_rec env) l in + Term.app f l + end + | Ast.Var v -> + (* look whether [v] must be replaced by some term *) + begin match AstVarMap.get v env.subst with + | Some t -> t + | None -> + (* lookup as bound variable *) + begin match CCList.find_idx (Ast.Var.equal v) env.bound with + | None -> errorf "could not find var `%a`" Ast.Var.pp v + | Some (i,_) -> + let ty = Ast.Var.ty v |> conv_ty in + Term.db (DB.make i ty) + end + end + | Ast.Bind (Ast.Fun,v,body) -> + let body = conv_term_rec {env with bound=v::env.bound} body in + let ty = Ast.Var.ty v |> conv_ty in + Term.fun_ ty body + | Ast.Bind ((Ast.Forall | Ast.Exists),_, _) -> + errorf "quantifiers not supported" + | Ast.Bind (Ast.Mu,v,body) -> + let env' = add_bound env v in + let body = conv_term_rec env' body in + Term.mu body + | Ast.Select _ -> assert false (* TODO *) + | Ast.Match (u,m) -> + let any_rhs_depends_vars = ref false in (* some RHS depends on matched arg? *) + let m = + ID.Map.map + (fun (vars,rhs) -> + let n_vars = List.length vars in + let env', tys = + CCList.fold_map + (fun env v -> add_bound env v, Ast.Var.ty v |> conv_ty) + env vars + in + let rhs = conv_term_rec env' rhs in + let depends_on_vars = + Term.to_seq_depth rhs + |> Sequence.exists + (fun (t,k) -> match t.term_cell with + | DB db -> + DB.level db < n_vars + k (* [k]: number of intermediate binders *) + | _ -> false) + in + if depends_on_vars then any_rhs_depends_vars := true; + tys, rhs) + m + in + (* optim: check whether all branches return the same term, that + does not depend on matched variables *) + (* TODO: do the closedness check during conversion, above *) + let rhs_l = + ID.Map.values m + |> Sequence.map snd + |> Sequence.sort_uniq ~cmp:Term.compare + |> Sequence.to_rev_list + in + begin match rhs_l with + | [x] when not (!any_rhs_depends_vars) -> + (* every branch yields the same [x], which does not depend + on the argument: remove the match and return [x] instead *) + x + | _ -> + let u = conv_term_rec env u in + Term.match_ u m + end + | Ast.Switch _ -> + errorf "cannot convert switch %a" Ast.pp_term t + | Ast.Let (v,t,u) -> + (* substitute on the fly *) + let t = conv_term_rec env t in + let env' = add_let_bound env v t in + conv_term_rec env' u + | Ast.If (a,b,c) -> + let b = conv_term_rec env b in + let c = conv_term_rec env c in + (* optim: [if _ b b --> b] *) + if Term.equal b c + then b + else Term.if_ (conv_term_rec env a) b c + | Ast.Not t -> Term.not_ (conv_term_rec env t) + | Ast.Binop (op,a,b) -> + let a = conv_term_rec env a in + let b = conv_term_rec env b in + begin match op with + | Ast.And -> Term.and_ a b + | Ast.Or -> Term.or_ a b + | Ast.Imply -> Term.imply a b + | Ast.Eq -> Term.eq a b + end + | Ast.Undefined_value -> + Term.undefined_value (conv_ty t.Ast.ty) Undef_absolute + | Ast.Asserting (t, g) -> + (* [t asserting g] becomes [if g t fail] *) + let t = conv_term_rec env t in + let g = conv_term_rec env g in + Term.if_ g t (Term.undefined_value t.term_ty Undef_absolute) + + let add_statement st = + Log.debugf 2 + (fun k->k "(@[add_statement@ @[%a@]@])" Ast.pp_statement st); + model_env_ := Ast.env_add_statement !model_env_ st; + begin match st with + | Ast.Assert t -> + let t = conv_term_rec empty_env t in + Top_goals.push t; + push_clause (Clause.make [Lit.atom t]) + | Ast.Goal (vars, t) -> + (* skolemize *) + let env, consts = + CCList.fold_map + (fun env v -> + let ty = Ast.Var.ty v |> conv_ty in + let c = Cst.make_undef (Ast.Var.id v) ty in + {env with subst=AstVarMap.add v (Term.const c) env.subst}, c) + empty_env + vars + in + (* model should contain values of [consts] *) + List.iter add_cst_support_ consts; + let t = conv_term_rec env t in + Top_goals.push t; + push_clause (Clause.make [Lit.atom t]) + | Ast.TyDecl id -> + let ty = Ty.atomic id Uninterpreted ~card:(Lazy.from_val Infinite) in + add_ty_support_ ty; + ID.Tbl.add ty_tbl_ id (Lazy.from_val ty) + | Ast.Decl (id, ty) -> + assert (not (ID.Tbl.mem decl_ty_ id)); + let ty = conv_ty ty in + let cst = Cst.make_undef id ty in + add_cst_support_ cst; (* need it in model *) + ID.Tbl.add decl_ty_ id (Lazy.from_val cst) + | Ast.Data l -> + (* the datatypes in [l]. Used for computing cardinalities *) + let in_same_block : ID.Set.t = + List.map (fun {Ast.Ty.data_id; _} -> data_id) l |> ID.Set.of_list + in + (* declare the type, and all the constructors *) + List.iter + (fun {Ast.Ty.data_id; data_cstors} -> + let ty = lazy ( + let card_ : ty_card ref = ref Finite in + let cstors = lazy ( + data_cstors + |> ID.Map.map + (fun c -> + let c_id = c.Ast.Ty.cstor_id in + let ty_c = conv_ty c.Ast.Ty.cstor_ty in + let ty_args, ty_ret = Ty.unfold ty_c in + (* add cardinality of [c] to the cardinality of [data_id]. + (product of cardinalities of args) *) + let cstor_card = + ty_args + |> List.map + (fun ty_arg -> match ty_arg.ty_cell with + | Atomic (id, _) when ID.Set.mem id in_same_block -> + Infinite + | _ -> Lazy.force ty_arg.ty_card) + |> Ty_card.product + in + card_ := Ty_card.( !card_ + cstor_card ); + let rec cst = lazy ( + Cst.make_cstor c_id ty_c cstor + ) and cstor = lazy ( + let cstor_proj = lazy ( + let n = ref 0 in + List.map2 + (fun id ty_arg -> + let ty_proj = Ty.arrow ty_ret ty_arg in + let i = !n in + incr n; + Cst.make_proj id ty_proj cstor i) + c.Ast.Ty.cstor_proj ty_args + |> IArray.of_list + ) in + let cstor_test = lazy ( + let ty_test = Ty.arrow ty_ret Ty.prop in + Cst.make_tester c.Ast.Ty.cstor_test ty_test cstor + ) in + { cstor_ty=ty_c; cstor_cst=Lazy.force cst; + cstor_args=IArray.of_list ty_args; + cstor_proj; cstor_test; cstor_card; } + ) in + ID.Tbl.add decl_ty_ c_id cst; (* declare *) + Lazy.force cstor) + ) + in + let data = { data_cstors=cstors; } in + let card = lazy ( + ignore (Lazy.force cstors); + let r = !card_ in + Log.debugf 5 + (fun k->k "(@[card_of@ %a@ %a@])" ID.pp data_id Ty_card.pp r); + r + ) in + Ty.atomic data_id (Data data) ~card + ) in + ID.Tbl.add ty_tbl_ data_id ty; + ) + l; + (* force evaluation *) + List.iter + (fun {Ast.Ty.data_id; _} -> + let lazy ty = ID.Tbl.find ty_tbl_ data_id in + ignore (Lazy.force ty.ty_card); + begin match ty.ty_cell with + | Atomic (_, Data {data_cstors=lazy _; _}) -> () + | _ -> assert false + end) + l + | Ast.Define (k,l) -> + (* declare the mutually recursive functions *) + List.iter + (fun (id,ty,rhs) -> + let ty = conv_ty ty in + let rhs = lazy (conv_term_rec empty_env rhs) in + let k = match k with + | Ast.Recursive -> Cst_recursive + | Ast.Non_recursive -> Cst_non_recursive + in + let cst = lazy ( + Cst.make_defined id ty rhs k + ) in + ID.Tbl.add decl_ty_ id cst) + l; + (* force thunks *) + List.iter + (fun (id,_,_) -> ignore (ID.Tbl.find decl_ty_ id |> Lazy.force)) + l + end + + let add_statement_l = List.iter add_statement + + module A = Ast + + let rec ty_to_ast (t:Ty.t): A.Ty.t = match t.ty_cell with + | Prop -> A.Ty.Prop + | Atomic (id,_) -> A.Ty.const id + | Arrow (a,b) -> A.Ty.arrow (ty_to_ast a) (ty_to_ast b) + + let fresh_var = + let n = ref 0 in + fun ty -> + let id = ID.makef "x%d" !n in + incr n; + A.Var.make id (ty_to_ast ty) + + let with_var ty env ~f = + let v = fresh_var ty in + let env = DB_env.push (A.var v) env in + f v env + + let term_to_ast (t:term): Ast.term = + let rec aux env t = match t.term_cell with + | True -> A.true_ + | False -> A.false_ + | DB d -> + begin match DB_env.get d env with + | Some t' -> t' + | None -> errorf "cannot find DB %a in env" Term.pp t + end + | App_cst (f, args) when IArray.is_empty args -> + A.const f.cst_id (ty_to_ast t.term_ty) + | App_cst (f, args) -> + let f = A.const f.cst_id (ty_to_ast (Cst.ty f)) in + let args = IArray.map (aux env) args in + A.app f (IArray.to_list args) + | App_ho (f,l) -> A.app (aux env f) (List.map (aux env) l) + | Fun (ty,bod) -> + with_var ty env + ~f:(fun v env -> A.fun_ v (aux env bod)) + | Mu _ -> assert false + | If (a,b,c) -> A.if_ (aux env a)(aux env b) (aux env c) + | Case (u,m) -> + let u = aux env u in + let m = + ID.Map.mapi + (fun _c_id _rhs -> + assert false (* TODO: fetch cstor; bind variables; convert rhs *) + (* + with_vars tys env ~f:(fun vars env -> vars, aux env rhs) + *) + ) + m + in + A.match_ u m + | Builtin b -> + begin match b with + | B_not t -> A.not_ (aux env t) + | B_and (a,b) -> A.and_ (aux env a) (aux env b) + | B_or (a,b) -> A.or_ (aux env a) (aux env b) + | B_eq (a,b) -> A.eq (aux env a) (aux env b) + | B_imply (a,b) -> A.imply (aux env a) (aux env b) + end + in aux DB_env.empty t +end + *) + +(** {2 Result} *) + +type unknown = + | U_timeout + | U_max_depth + | U_incomplete + +type model = Model.t +let pp_model = Model.pp + +type res = + | Sat of model + | Unsat (* TODO: proof *) + | Unknown of unknown + +(* FIXME: repair this and output a nice model. +module Model_build : sig + val make: unit -> model + + val check : model -> unit +end = struct + module ValueListMap = CCMap.Make(struct + type t = Term.t list (* normal forms *) + let compare = CCList.compare Term.compare + end) + + type doms = { + dom_of_ty: ID.t list Ty.Tbl.t; (* uninterpreted type -> domain elements *) + dom_of_class: term Term.Tbl.t; (* representative -> normal form *) + dom_of_cst: term Cst.Tbl.t; (* cst -> its normal form *) + dom_of_fun: term ValueListMap.t Cst.Tbl.t; (* function -> args -> normal form *) + dom_traversed: unit Term.Tbl.t; (* avoid cycles *) + } + + let create_doms() : doms = + { dom_of_ty=Ty.Tbl.create 32; + dom_of_class = Term.Tbl.create 32; + dom_of_cst=Cst.Tbl.create 32; + dom_of_fun=Cst.Tbl.create 32; + dom_traversed=Term.Tbl.create 128; + } + + (* pick a term belonging to this type. + we just generate a new constant, as picking true/a constructor might + refine the partial model into an unsatisfiable state. *) + let pick_default ~prefix (doms:doms)(ty:Ty.t) : term = + (* introduce a fresh constant for this equivalence class *) + let elts = Ty.Tbl.get_or ~default:[] doms.dom_of_ty ty in + let cst = ID.makef "%s%s_%d" prefix (Ty.mangle ty) (List.length elts) in + let nf = Term.const (Cst.make_undef cst ty) in + Ty.Tbl.replace doms.dom_of_ty ty (cst::elts); + nf + + (* follow "normal form" pointers deeply in the term *) + let deref_deep (doms:doms) (t:term) : term = + let rec aux t = + let repr = (CC.find t :> term) in + (* if not already done, traverse all parents to update the functions' + models *) + if not (Term.Tbl.mem doms.dom_traversed repr) then ( + Term.Tbl.add doms.dom_traversed repr (); + Bag.to_seq repr.term_parents |> Sequence.iter aux_ignore; + ); + (* find a normal form *) + let nf: term = + begin match CC.normal_form t with + | Some (NF_bool true) -> Term.true_ + | Some (NF_bool false) -> Term.false_ + | Some (NF_cstor (cstor, args)) -> + (* cstor applied to sub-normal forms *) + Term.app_cst cstor.cstor_cst (IArray.map aux args) + | None -> + let repr = (CC.find t :> term) in + begin match Term.Tbl.get doms.dom_of_class repr with + | Some u -> u + | None when Ty.is_uninterpreted t.term_ty -> + let nf = pick_default ~prefix:"$" doms t.term_ty in + Term.Tbl.add doms.dom_of_class repr nf; + nf + | None -> + let nf = pick_default ~prefix:"?" doms t.term_ty in + Term.Tbl.add doms.dom_of_class repr nf; + nf + end + end + in + (* update other tables *) + begin match t.term_cell with + | True | False -> assert false (* should have normal forms *) + | Fun _ | DB _ | Mu _ + -> () + | Builtin b -> ignore (Term.map_builtin aux b) + | If (a,b,c) -> aux_ignore a; aux_ignore b; aux_ignore c + | App_ho (f, l) -> aux_ignore f; List.iter aux_ignore l + | Case (t, m) -> aux_ignore t; ID.Map.iter (fun _ rhs -> aux_ignore rhs) m + | App_cst (f, a) when IArray.is_empty a -> + (* remember [f := c] *) + Cst.Tbl.replace doms.dom_of_cst f nf + | App_cst (f, a) -> + (* remember [f a := c] *) + let a_values = IArray.map aux a |> IArray.to_list in + let map = + Cst.Tbl.get_or ~or_:ValueListMap.empty doms.dom_of_fun f + in + Cst.Tbl.replace doms.dom_of_fun f (ValueListMap.add a_values nf map) + end; + nf + and aux_ignore t = + ignore (aux t) + in + aux t + + (* TODO: maybe we really need a notion of "Undefined" that is + also not a domain element (i.e. equality not defined on it) + + some syntax for it *) + + (* build the model of a function *) + let model_of_fun (doms:doms) (c:cst): Ast.term = + let ty_args, ty_ret = Ty.unfold (Cst.ty c) in + assert (ty_args <> []); + let vars = + List.mapi + (fun i ty -> Ast.Var.make (ID.makef "x_%d" i) (Conv.ty_to_ast ty)) + ty_args + in + let default = match ty_ret.ty_cell with + | Prop -> Ast.true_ (* should be safe: we would have split it otherwise *) + | _ -> + (* TODO: what about other finites types? *) + pick_default ~prefix:"?" doms ty_ret |> Conv.term_to_ast + in + let cases = + Cst.Tbl.get_or ~or_:ValueListMap.empty doms.dom_of_fun c + |> ValueListMap.to_list + |> List.map + (fun (args,rhs) -> + assert (List.length ty_args = List.length vars); + let tests = + List.map2 + (fun v arg -> Ast.eq (Ast.var v) (Conv.term_to_ast arg)) + vars args + in + Ast.and_l tests, Conv.term_to_ast rhs) + in + (* decision tree for the body *) + let body = + List.fold_left + (fun else_ (test, then_) -> Ast.if_ test then_ else_) + default cases + in + Ast.fun_l vars body + + let make () : model = + let env = !model_env_ in + let doms = create_doms () in + (* compute values of meta variables *) + let consts = + !model_support_ + |> Sequence.of_list + |> Sequence.filter_map + (fun c -> + if Ty.is_arrow (Cst.ty c) then None + else + (* find normal form of [c] *) + let t = Term.const c in + let t = deref_deep doms t |> Conv.term_to_ast in + Some (c.cst_id, t)) + |> ID.Map.of_seq + in + (* now compute functions (the previous "deref_deep" have updated their use cases) *) + let consts = + !model_support_ + |> Sequence.of_list + |> Sequence.filter_map + (fun c -> + if Ty.is_arrow (Cst.ty c) + then ( + let t = model_of_fun doms c in + Some (c.cst_id, t) + ) else None) + |> ID.Map.add_seq consts + in + (* now we can convert domains *) + let domains = + Ty.Tbl.to_seq doms.dom_of_ty + |> Sequence.filter_map + (fun (ty,dom) -> + if Ty.is_uninterpreted ty + then Some (Conv.ty_to_ast ty, List.rev dom) + else None) + |> Ast.Ty.Map.of_seq + (* and update env: add every domain element to it *) + and env = + Ty.Tbl.to_seq doms.dom_of_ty + |> Sequence.flat_map_l (fun (_,dom) -> dom) + |> Sequence.fold + (fun env id -> Ast.env_add_def env id Ast.E_uninterpreted_cst) + env + in + Model.make ~env ~consts ~domains + + let check m = + Log.debugf 1 (fun k->k "checking model…"); + Log.debugf 5 (fun k->k "(@[<1>candidate model: %a@])" Model.pp m); + let goals = + Top_goals.to_seq + |> Sequence.map Conv.term_to_ast + |> Sequence.to_list + in + Model.check m ~goals +end + *) + +(** {2 Main} *) + +let[@inline] clause_of_mclause (c:Sat.clause): Clause.t = + Sat.Clause.atoms_l c |> Clause.make + +(* convert unsat-core *) +let clauses_of_unsat_core (core:Sat.clause list): Clause.t Sequence.t = + Sequence.of_list core + |> Sequence.map clause_of_mclause + +(* print all terms reachable from watched literals *) +let pp_term_graph _out (_:t) = + () + +let pp_stats out (s:t) : unit = + Format.fprintf out + "(@[stats@ \ + :num_expanded %d@ \ + :num_uty_expanded %d@ \ + :num_clause_push %d@ \ + :num_clause_tautology %d@ \ + :num_propagations %d@ \ + :num_unif %d@ \ + @])" + s.stat.Stat.num_cst_expanded + s.stat.Stat.num_uty_expanded + s.stat.Stat.num_clause_push + s.stat.Stat.num_clause_tautology + s.stat.Stat.num_propagations + s.stat.Stat.num_unif + +let do_on_exit ~on_exit = + List.iter (fun f->f()) on_exit; + () + +let add_statement_l (_:t) _ = () +(* FIXME + Conv.add_statement_l + *) + +(* TODO: move this into submodule *) +let pp_proof out p = + let pp_step_res out p = + let {Sat.Proof.conclusion; _ } = Sat.Proof.expand p in + let conclusion = clause_of_mclause conclusion in + Clause.pp out conclusion + in + let pp_step out = function + | Sat.Proof.Lemma _ -> Format.fprintf out "(@[<1>lemma@ ()@])" + | Sat.Proof.Resolution (p1, p2, _) -> + Format.fprintf out "(@[<1>resolution@ %a@ %a@])" + pp_step_res p1 pp_step_res p2 + | _ -> Fmt.string out "" + in + Format.fprintf out "(@["; + Sat.Proof.fold + (fun () {Sat.Proof.conclusion; step } -> + let conclusion = clause_of_mclause conclusion in + Format.fprintf out "(@[step@ %a@ @[<1>from:@ %a@]@])@," + Clause.pp conclusion pp_step step) + () p; + Format.fprintf out "@])"; + () + +(* +type unsat_core = Sat.clause list + *) + +(* TODO: main loop with iterative deepening of the unrolling limit + (not the value depth limit) *) +let solve ?on_exit:(_=[]) ?check:(_=true) (_self:t) : res = + Unknown U_incomplete + +(* FIXME +(* TODO: max_depth should actually correspond to the maximum depth + of un-expanded terms (expand in body of t --> depth = depth(t)+1), + so it corresponds to unfolding call graph to some depth *) + +let solve ?(on_exit=[]) ?(check=true) () = + let n_iter = ref 0 in + let rec check_cc (): res = + assert (Backtrack.at_level_0 ()); + if !n_iter > Config.max_depth then Unknown U_max_depth (* exceeded limit *) + else begin match CC.check () with + | CC.Unsat _ -> Unsat (* TODO proof *) + | CC.Sat lemmas -> + add_cc_lemmas lemmas; + check_solver() + end + + and check_solver (): res = + (* assume all literals [expanded t] are false *) + let assumptions = + Terms_to_expand.to_seq + |> Sequence.map (fun {Terms_to_expand.lit; _} -> Lit.neg lit) + |> Sequence.to_rev_list + in + incr n_iter; + Log.debugf 2 + (fun k->k + "(@[<1>@{solve@}@ @[:with-assumptions@ (@[%a@])@ n_iter: %d]@])" + (Util.pp_list Lit.pp) assumptions !n_iter); + begin match M.solve ~assumptions() with + | M.Sat _ -> + Log.debugf 1 (fun k->k "@{** found SAT@}"); + do_on_exit ~on_exit; + let m = Model_build.make () in + if check then Model_build.check m; + Sat m + | M.Unsat us -> + let p = us.SI.get_proof () in + Log.debugf 4 (fun k->k "proof: @[%a@]@." pp_proof p); + let core = p |> M.unsat_core in + (* check if unsat because of assumptions *) + expand_next core + end + + (* pick a term to expand, or UNSAT *) + and expand_next (core:unsat_core) = + begin match find_to_expand core with + | None -> Unsat (* TODO proof *) + | Some to_expand -> + let t = to_expand.Terms_to_expand.term in + Log.debugf 2 (fun k->k "(@[<1>@{expand_next@}@ :term %a@])" Term.pp t); + CC.expand_term t; + Terms_to_expand.remove t; + Clause.push_new (Clause.make [to_expand.Terms_to_expand.lit]); + Backtrack.backtrack_to_level_0 (); + check_cc () (* recurse *) + end + in + check_cc() + + *) diff --git a/src/smt/Solver.mli b/src/smt/Solver.mli new file mode 100644 index 00000000..d94460a6 --- /dev/null +++ b/src/smt/Solver.mli @@ -0,0 +1,57 @@ + +(* This file is free software. See file "license" for more details. *) + +(** {1 Solver} + + The solving algorithm, based on MCSat *) + +open CDCL + +type term +type cst +type ty = Solver_types.ty (** types *) +type ty_def = Solver_types.ty_def + +type ty_cell = Solver_types.ty_cell = + | Prop + | Atomic of ID.t * ty_def + | Arrow of ty * ty + +(** {2 Result} *) + +type model = Model.t + +type unknown = + | U_timeout + | U_max_depth + | U_incomplete + +type res = + | Sat of Model.t + | Unsat (* TODO: proof *) + | Unknown of unknown + +(** {2 Main} *) + +type t +(** Solver state *) + +val create : + ?size:[`Big | `Tiny | `Small] -> + ?config:Config.t -> + theories:Theory.t list -> + unit -> t + +val add_statement_l : t -> Ast.statement list -> unit + +val solve : + ?on_exit:(unit -> unit) list -> + ?check:bool -> + t -> + res +(** [solve s] checks the satisfiability of the statement added so far to [s] + @param check if true, the model is checked before returning + @param on_exit functions to be run before this returns *) + +val pp_term_graph: t CCFormat.printer +val pp_stats : t CCFormat.printer diff --git a/src/smt/Stat.ml b/src/smt/Stat.ml new file mode 100644 index 00000000..26854a46 --- /dev/null +++ b/src/smt/Stat.ml @@ -0,0 +1,18 @@ + +type t = { + mutable num_cst_expanded : int; + mutable num_uty_expanded : int; + mutable num_clause_push : int; + mutable num_clause_tautology : int; + mutable num_propagations : int; + mutable num_unif : int; +} + +let create () : t = { + num_cst_expanded = 0; + num_uty_expanded = 0; + num_clause_push = 0; + num_clause_tautology = 0; + num_propagations = 0; + num_unif = 0; +} diff --git a/src/smt/Term.ml b/src/smt/Term.ml index 465b2469..f4fbac4c 100644 --- a/src/smt/Term.ml +++ b/src/smt/Term.ml @@ -114,6 +114,9 @@ let fold_map_builtin let acc, b = f acc b in acc, B_imply (a, b) +let[@inline] is_true t = match t.term_cell with True -> true | _ -> false +let is_false t = match t.term_cell with Builtin (B_not u) -> is_true u | _ -> false + let[@inline] is_const t = match t.term_cell with | App_cst (_, a) -> IArray.is_empty a | _ -> false diff --git a/src/smt/Term.mli b/src/smt/Term.mli index 6ae651c3..45765a31 100644 --- a/src/smt/Term.mli +++ b/src/smt/Term.mli @@ -49,8 +49,9 @@ val pp : t Fmt.printer (** {6 Views} *) +val is_true : t -> bool +val is_false : t -> bool val is_const : t -> bool - val is_custom : t -> bool val is_semantic : t -> bool diff --git a/src/smt/Term_cell.ml b/src/smt/Term_cell.ml index 3e25b5c1..f4349a36 100644 --- a/src/smt/Term_cell.ml +++ b/src/smt/Term_cell.ml @@ -1,5 +1,4 @@ -open CDCL open Solver_types (* TODO: normalization of {!term_cell} for use in signatures? *) diff --git a/src/smt/Term_cell.mli b/src/smt/Term_cell.mli index 71d3c2a8..9d5f17a6 100644 --- a/src/smt/Term_cell.mli +++ b/src/smt/Term_cell.mli @@ -1,5 +1,4 @@ -open CDCL open Solver_types type t = term term_cell diff --git a/src/smt/Theory.ml b/src/smt/Theory.ml new file mode 100644 index 00000000..02f76de5 --- /dev/null +++ b/src/smt/Theory.ml @@ -0,0 +1,60 @@ + +open Solver_types + +(** Runtime state of a theory, with all the operations it provides *) +type state = { + on_merge: Equiv_class.t -> Equiv_class.t -> Explanation.t -> unit; + (** Called when two classes are merged *) + + on_assert: Lit.t -> unit; + (** Called when a literal becomes true *) + + final_check: Lit.t Sequence.t -> unit; + (** Final check, must be complete (i.e. must raise a conflict + if the set of literals is not satisfiable) *) +} + +(** Unsatisfiable conjunction. + Will be turned into a set of literals, whose negation becomes a + conflict clause *) +type conflict = Explanation.t Bag.t + +(** Actions available to a theory during its lifetime *) +type actions = { + on_backtrack: (unit -> unit) -> unit; + (** Register an action to do when we backtrack *) + + at_lvl_0: unit -> bool; + (** Are we at level 0 of backtracking? *) + + raise_conflict: 'a. conflict -> 'a; + (** Give a conflict clause to the solver *) + + propagate_eq: Term.t -> Term.t -> Explanation.t -> unit; + (** Propagate an equality [t = u] because [e] *) + + propagate: Lit.t -> Explanation.t Bag.t -> unit; + (** Propagate a boolean using a unit clause. + [expl => lit] must be a theory lemma, that is, a T-tautology *) + + case_split: Clause.t -> unit; + (** Force the solver to case split on this clause. + The clause will be removed upon backtracking. *) + + add_axiom: Clause.t -> unit; + (** Add a persistent axiom to the SAT solver. This will not + be backtracked *) + + find: Term.t -> Equiv_class.t; + (** Find representative of this term *) + + all_classes: Equiv_class.t Sequence.t; + (** All current equivalence classes + (caution: linear in the number of terms existing in the solver) *) +} + +type t = { + name: string; + make: Term.state -> actions -> state; +} + diff --git a/src/smt/Theory_combine.ml b/src/smt/Theory_combine.ml new file mode 100644 index 00000000..fe4aa1a6 --- /dev/null +++ b/src/smt/Theory_combine.ml @@ -0,0 +1,245 @@ + +(** {1 Main theory} *) + +(** Combine the congruence closure with a number of plugins *) + +open Solver_types + +module Proof = struct + type t = Proof + let default = Proof +end + +module Form = Lit + +type formula = Lit.t +type proof = Proof.t + +type conflict = Explanation.t Bag.t + +(* raise upon conflict *) +exception Exn_conflict of conflict + +type t = { + cdcl_acts: (formula,proof) CDCL.actions; + (** actions provided by the SAT solver *) + tst: Term.state; + (** state for managing terms *) + cc: Congruence_closure.t lazy_t; + (** congruence closure *) + mutable theories : Theory.state list; + (** Set of theories *) + lemma_q : Clause.t Queue.t; + (** list of clauses that have been newly generated, waiting + to be propagated to the core solver. + invariant: those clauses must be tautologies *) + split_q : Clause.t Queue.t; + (** Local clauses to be added to the core solver, that will + be removed on backtrack *) + mutable conflict: conflict option; + (** current conflict, if any *) +} + +let[@inline] cc t = Lazy.force t.cc +let[@inline] tst t = t.tst +let[@inline] theories (self:t) : Theory.state Sequence.t = + fun k -> List.iter k self.theories + +(** {2 Interface with the SAT solver} *) + +(* handle a literal assumed by the SAT solver *) +let assume_lit (self:t) (lit:Lit.t) : unit = + CDCL.Log.debugf 2 + (fun k->k "(@[<1>@{theory_combine.assume_lit@}@ @[%a@]@])" Lit.pp lit); + (* check consistency first *) + begin match Lit.view lit with + | Lit_fresh _ -> () + | Lit_expanded _ + | Lit_atom {term_cell=True; _} -> () + | Lit_atom t when Term.is_false t -> assert false + | Lit_atom _ -> + (* transmit to CC and theories *) + Congruence_closure.assert_lit (cc self) lit; + theories self (fun th -> th.Theory.on_assert lit); + end + +(* push clauses from {!lemma_queue} into the slice *) +let push_new_clauses_into_cdcl (self:t) : unit = + let CDCL.Actions r = self.cdcl_acts in + (* persistent lemmas *) + while not (Queue.is_empty self.lemma_q) do + let c = Queue.pop self.lemma_q in + CDCL.Log.debugf 5 (fun k->k "(@[<2>push_lemma@ %a@])" Clause.pp c); + r.push c Proof.default + done; + (* local splits *) + while not (Queue.is_empty self.split_q) do + let c = Queue.pop self.split_q in + CDCL.Log.debugf 5 (fun k->k "(@[<2>split_on@ %a@])" Clause.pp c); + r.push_local c Proof.default + done + +(* return result to the SAT solver *) +let cdcl_return_res (self:t) : _ CDCL.res = + begin match self.conflict with + | None -> + push_new_clauses_into_cdcl self; + CDCL.Sat + | Some c -> + let lit_set = + Bag.to_seq c + |> Congruence_closure.explain_unfold (cc self) + in + let conflict_clause = + Lit.Set.to_list lit_set + |> List.map Lit.neg + |> Clause.make + in + CDCL.Log.debugf 3 + (fun k->k "(@[<1>conflict@ clause: %a@])" + Clause.pp conflict_clause); + CDCL.Unsat (Clause.lits conflict_clause, Proof.default) + end + +let[@inline] check (self:t) : unit = + Congruence_closure.check (cc self) + +(* propagation from the bool solver *) +let assume_real (self:t) (slice:_ CDCL.slice_actions) = + (* TODO if Config.progress then print_progress(); *) + let CDCL.Slice_acts slice = slice in + begin + try + slice.slice_iter (assume_lit self); + (* now check satisfiability *) + check self; + with Exn_conflict c -> + assert (CCOpt.is_none self.conflict); + self.conflict <- Some c; + end; + cdcl_return_res self + +(* propagation from the bool solver *) +let assume (self:t) (slice:_ CDCL.slice_actions) = + match self.conflict with + | None -> assume_real self slice + | Some _ -> + (* already in conflict! *) + cdcl_return_res self + +(* perform final check of the model *) +let if_sat (self:t) (slice:_) : _ CDCL.res = + Congruence_closure.final_check (cc self); + (* all formulas in the SAT solver's trail *) + let forms = + let CDCL.Slice_acts r = slice in + r.slice_iter + in + (* final check for each theory *) + theories self + (fun th -> th.Theory.final_check forms); + cdcl_return_res self + +(** {2 Various helpers} *) + +(* forward propagations from CC or theories directly to the SMT core *) +let act_propagate (self:t) f guard : unit = + let CDCL.Actions r = self.cdcl_acts in + let guard = + Bag.to_seq guard + |> Congruence_closure.explain_unfold (cc self) + |> Lit.Set.to_list + in + CDCL.Log.debugf 2 + (fun k->k "(@[@{propagate@}@ %a@ :guard %a@])" + Lit.pp f Clause.pp guard); + r.propagate f guard Proof.default + +(** {2 Interface to Congruence Closure} *) + +let act_raise_conflict e = raise (Exn_conflict e) + +(* when CC decided to merge [r1] and [r2], notify theories *) +let on_merge_from_cc (self:t) r1 r2 e : unit = + theories self + (fun th -> th.Theory.on_merge r1 r2 e) + +let mk_cc_actions (self:t) : Congruence_closure.actions = + let CDCL.Actions r = self.cdcl_acts in + { + Congruence_closure. + on_backtrack = r.on_backtrack; + at_lvl_0 = r.at_level_0; + on_merge = on_merge_from_cc self; + raise_conflict = act_raise_conflict; + propagate = act_propagate self; + } + +(** {2 Main} *) + +(* create a new theory combination *) +let create (cdcl_acts:_ CDCL.actions) : t = + CDCL.Log.debug 5 "theory_combine.create"; + let rec self = { + cdcl_acts; + tst=Term.create ~size:1024 (); + cc = lazy ( + (* lazily tie the knot *) + let actions = mk_cc_actions self in + Congruence_closure.create ~size:1024 ~actions self.tst; + ); + theories = []; + lemma_q = Queue.create(); + split_q = Queue.create(); + conflict = None; + } in + ignore @@ Lazy.force @@ self.cc; + self + +(** {2 Interface to individual theories} *) + +let act_all_classes self = Congruence_closure.all_classes (cc self) + +let act_propagate_eq self t u guard = + let r_t = Congruence_closure.add (cc self) t in + let r_u = Congruence_closure.add (cc self) u in + Congruence_closure.union (cc self) r_t r_u guard + +let act_find self t = + Congruence_closure.add (cc self) t + |> Congruence_closure.find (cc self) + +let act_case_split self (c:Clause.t) = + CDCL.Log.debugf 2 (fun k->k "(@[<1>add_split@ @[%a@]@])" Clause.pp c); + Queue.push c self.split_q + +(* push one clause into [M], in the current level (not a lemma but + an axiom) *) +let act_add_axiom self (c:Clause.t): unit = + CDCL.Log.debugf 2 (fun k->k "(@[<1>add_axiom@ @[%a@]@])" Clause.pp c); + (* TODO incr stat_num_clause_push; *) + Queue.push c self.lemma_q + +let mk_theory_actions (self:t) : Theory.actions = + let CDCL.Actions r = self.cdcl_acts in + { + Theory. + on_backtrack = r.on_backtrack; + at_lvl_0 = r.at_level_0; + raise_conflict = act_raise_conflict; + propagate = act_propagate self; + all_classes = act_all_classes self; + propagate_eq = act_propagate_eq self; + case_split = act_case_split self; + add_axiom = act_add_axiom self; + find = act_find self; + } + +let add_theory (self:t) (th:Theory.t) : unit = + CDCL.Log.debugf 2 + (fun k->k "(@[theory_combine.add_th@ :name %S@])" th.Theory.name); + let th_s = th.Theory.make self.tst (mk_theory_actions self) in + self.theories <- th_s :: self.theories + +let add_theory_l self = List.iter (add_theory self) + diff --git a/src/smt/Theory_combine.mli b/src/smt/Theory_combine.mli new file mode 100644 index 00000000..25ccdd30 --- /dev/null +++ b/src/smt/Theory_combine.mli @@ -0,0 +1,21 @@ + +(** {1 Main theory} *) + +(** Combine the congruence closure with a number of plugins *) + +module Proof : sig + type t = Proof +end + +include CDCL.Theory_intf.S + with type formula = Lit.t + and type proof = Proof.t + +val cc : t -> Congruence_closure.t +val tst : t -> Term.state +val theories : t -> Theory.state Sequence.t + +val add_theory : t -> Theory.t -> unit +(** How to add new theories *) + +val add_theory_l : t -> Theory.t list -> unit diff --git a/src/smt/Util.ml b/src/smt/Util.ml index 83f814b2..56843058 100644 --- a/src/smt/Util.ml +++ b/src/smt/Util.ml @@ -12,6 +12,9 @@ let pp_sep sep out () = Format.fprintf out "%s@," sep let pp_list ?(sep=" ") pp out l = Fmt.list ~sep:(pp_sep sep) pp out l +let pp_seq ?(sep=" ") pp out l = + Fmt.seq ~sep:(pp_sep sep) pp out l + let pp_pair ?(sep=" ") pp1 pp2 out t = Fmt.pair ~sep:(pp_sep sep) pp1 pp2 out t diff --git a/src/smt/Util.mli b/src/smt/Util.mli index f39267f0..165082ed 100644 --- a/src/smt/Util.mli +++ b/src/smt/Util.mli @@ -7,6 +7,8 @@ type 'a printer = 'a CCFormat.printer val pp_list : ?sep:string -> 'a printer -> 'a list printer +val pp_seq : ?sep:string -> 'a printer -> 'a Sequence.t printer + val pp_array : ?sep:string -> 'a printer -> 'a array printer val pp_pair : ?sep:string -> 'a printer -> 'b printer -> ('a * 'b) printer