diff --git a/src/main/main.ml b/src/main/main.ml index 7a0ad1c8..ada5dd9f 100644 --- a/src/main/main.ml +++ b/src/main/main.ml @@ -123,10 +123,11 @@ let main () = (* process statements *) let res = try + let hyps = Vec.make_empty [] in E.fold_l (fun () -> Process.process_stmt - ~gc:!gc ~restarts:!restarts ~pp_cnf:!p_cnf + ~hyps ~gc:!gc ~restarts:!restarts ~pp_cnf:!p_cnf ~time:!time_limit ~memory:!size_limit ?dot_proof ~pp_model:!p_model ~check:!check ~progress:!p_progress solver) diff --git a/src/sat/Solver_intf.ml b/src/sat/Solver_intf.ml index 26fea6fd..e41e17cc 100644 --- a/src/sat/Solver_intf.ml +++ b/src/sat/Solver_intf.ml @@ -152,7 +152,9 @@ module type S = sig val compare : t -> t -> int val equal : t -> t -> bool val get_formula : t -> formula + val is_true : t -> bool + val dummy : t val make : solver -> formula -> t val pp : t printer diff --git a/src/smt/Congruence_closure.ml b/src/smt/Congruence_closure.ml index fdd8b6ea..ed960bf3 100644 --- a/src/smt/Congruence_closure.ml +++ b/src/smt/Congruence_closure.ml @@ -587,26 +587,65 @@ let final_check cc : unit = (* model: map each uninterpreted equiv class to some ID *) let mk_model (cc:t) (m:Model.t) : Model.t = (* populate [repr -> value] table *) - let tbl = Equiv_class.Tbl.create 32 in + let t_tbl = Equiv_class.Tbl.create 32 in + (* type -> default value *) + let ty_tbl = Ty.Tbl.create 8 in Term.Tbl.values cc.tbl (fun r -> if is_root_ r then ( - let v = match Model.eval m r.n_term with + let t = r.n_term in + let v = match Model.eval m t with | Some v -> v | None -> Value.mk_elt - (ID.makef "v_%d" @@ Term.id r.n_term) + (ID.makef "v_%d" @@ Term.id t) (Term.ty r.n_term) in - Equiv_class.Tbl.add tbl r v + if not @@ Ty.Tbl.mem ty_tbl (Term.ty t) then ( + Ty.Tbl.add ty_tbl (Term.ty t) v; (* also give a value to this type *) + ); + Equiv_class.Tbl.add t_tbl r v )); - (* now map every uninterpreted term to its representative's value *) - Term.Tbl.to_seq cc.tbl - |> Sequence.fold - (fun m (t,r) -> - if Model.mem t m then m - else ( - let v = Equiv_class.Tbl.find tbl r in - Model.add t v m - )) - m + (* now map every uninterpreted term to its representative's value, and + create function tables *) + let m, funs = + Term.Tbl.to_seq cc.tbl + |> Sequence.fold + (fun (m,funs) (t,r) -> + let r = find cc r in (* get representative *) + match Term.view t with + | _ when Model.mem t m -> m, funs + | App_cst (c, args) -> + if Model.mem t m then m, funs + else if Cst.is_undefined c && IArray.length args > 0 then ( + (* update signature of [c] *) + let ty = Term.ty t in + let v = Equiv_class.Tbl.find t_tbl r in + let args = + args + |> IArray.map (fun t -> Equiv_class.Tbl.find t_tbl @@ find_tn cc t) + |> IArray.to_list + in + let ty, l = Cst.Map.get_or c funs ~default:(ty,[]) in + m, Cst.Map.add c (ty, (args,v)::l) funs + ) else ( + let v = Equiv_class.Tbl.find t_tbl r in + Model.add t v m, funs + ) + | _ -> + let v = Equiv_class.Tbl.find t_tbl r in + Model.add t v m, funs) + (m,Cst.Map.empty) + in + (* get or make a default value for this type *) + let get_ty_default (ty:Ty.t) : Value.t = + Ty.Tbl.get_or_add ty_tbl ~k:ty + ~f:(fun ty -> Value.mk_elt (ID.makef "ty_%d" @@ Ty.id ty) ty) + in + let funs = + Cst.Map.map + (fun (ty,l) -> + Model.Fun_interpretation.make ~default:(get_ty_default ty) l) + funs + in + Model.add_funs funs m diff --git a/src/smt/Cst.ml b/src/smt/Cst.ml index 44c410ce..ddd7fe71 100644 --- a/src/smt/Cst.ml +++ b/src/smt/Cst.ml @@ -12,6 +12,8 @@ let as_undefined (c:t) = match view c with | Cst_undef ty -> Some (c,ty) | Cst_def _ -> None +let[@inline] is_undefined c = match view c with Cst_undef _ -> true | _ -> false + let as_undefined_exn (c:t) = match as_undefined c with | Some tup -> tup | None -> assert false diff --git a/src/smt/Cst.mli b/src/smt/Cst.mli index 08d118eb..b6d98dc8 100644 --- a/src/smt/Cst.mli +++ b/src/smt/Cst.mli @@ -11,6 +11,7 @@ val compare : t -> t -> int val hash : t -> int val as_undefined : t -> (t * Ty.Fun.t) option val as_undefined_exn : t -> t * Ty.Fun.t +val is_undefined : t -> bool val mk_undef : ID.t -> Ty.Fun.t -> t val mk_undef_const : ID.t -> Ty.t -> t diff --git a/src/smt/Model.ml b/src/smt/Model.ml index 2407eaea..b3057484 100644 --- a/src/smt/Model.ml +++ b/src/smt/Model.ml @@ -5,11 +5,58 @@ open Solver_types +module Val_map = struct + module M = CCIntMap + module Key = struct + type t = Value.t list + let equal = CCList.equal Value.equal + let hash = Hash.list Value.hash + end + type key = Key.t + + type 'a t = (key * 'a) list M.t + + let empty = M.empty + let is_empty m = M.cardinal m = 0 + + let cardinal = M.cardinal + + let find k m = + try Some (CCList.assoc ~eq:Key.equal k @@ M.find_exn (Key.hash k) m) + with Not_found -> None + + let add k v m = + let h = Key.hash k in + let l = M.find h m |> CCOpt.get_or ~default:[] in + let l = CCList.Assoc.set ~eq:Key.equal k v l in + M.add h l m + + let to_seq m yield = M.iter (fun _ l -> List.iter yield l) m +end + +module Fun_interpretation = struct + type t = { + cases: Value.t Val_map.t; + default: Value.t; + } + + let default fi = fi.default + let cases_list fi = Val_map.to_seq fi.cases |> Sequence.to_rev_list + + let make ~default l : t = + let m = List.fold_left (fun m (k,v) -> Val_map.add k v m) Val_map.empty l in + { cases=m; default } +end + type t = { values: Value.t Term.Map.t; + funs: Fun_interpretation.t Cst.Map.t; } -let empty : t = {values=Term.Map.empty} +let empty : t = { + values=Term.Map.empty; + funs=Cst.Map.empty; +} let[@inline] mem t m = Term.Map.mem t m.values let[@inline] find t m = Term.Map.get t m.values @@ -23,7 +70,13 @@ let add t v m : t = ); m | exception Not_found -> - {values=Term.Map.add t v m.values} + {m with values=Term.Map.add t v m.values} + +let add_fun c v m : t = + match Cst.Map.find c m.funs with + | _ -> Error.errorf "@[Model: function %a already has an interpretation@]" Cst.pp c + | exception Not_found -> + {m with funs=Cst.Map.add c v m.funs} (* merge two models *) let merge m1 m2 : t = @@ -36,17 +89,36 @@ let merge m1 m2 : t = Error.errorf "@[Model: incompatible values for term %a@ :previous %a@ :new %a@]" Term.pp t Value.pp v1 Value.pp v2 )) + and funs = + Cst.Map.merge_safe m1.funs m2.funs + ~f:(fun c o -> match o with + | `Left v | `Right v -> Some v + | `Both _ -> + Error.errorf "cannot merge the two interpretations of function %a" Cst.pp c) in - {values} + {values; funs} -let pp out (m:t) = - let pp_tv out (t,v) = Fmt.fprintf out "(@[%a@ -> %a@])" Term.pp t Value.pp v in - Fmt.fprintf out "(@[model@ %a@])" - (Fmt.seq ~sep:Fmt.(return "@ ") pp_tv) (Term.Map.to_seq m.values) +let add_funs fs m : t = merge {values=Term.Map.empty; funs=fs} m + +let pp out {values; funs} = + let module FI = Fun_interpretation in + let pp_tv out (t,v) = Fmt.fprintf out "(@[%a@ %a@])" Term.pp t Value.pp v in + let pp_fun_entry out (vals,ret) = + Format.fprintf out "(@[%a@ %a@])" (Fmt.Dump.list Value.pp) vals Value.pp ret + in + let pp_fun out (c, fi: Cst.t * FI.t) = + Format.fprintf out "(@[%a :default %a@ %a@])" + Cst.pp c Value.pp fi.FI.default + (Fmt.list ~sep:(Fmt.return "@ ") pp_fun_entry) (FI.cases_list fi) + in + Fmt.fprintf out "(@[model@ @[:terms (@[%a@])@]@ @[:funs (@[%a@])@]@])" + (Fmt.seq ~sep:Fmt.(return "@ ") pp_tv) (Term.Map.to_seq values) + (Fmt.seq ~sep:Fmt.(return "@ ") pp_fun) (Cst.Map.to_seq funs) exception No_value let eval (m:t) (t:Term.t) : Value.t option = + let module FI = Fun_interpretation in let rec aux t = match Term.view t with | Bool b -> Value.bool b | If (a,b,c) -> @@ -64,8 +136,17 @@ let eval (m:t) (t:Term.t) : Value.t option = let args = IArray.map aux args in udef.eval args | Cst_undef _ -> - Log.debugf 5 (fun k->k "(@[model.eval.undef@ %a@])" Term.pp t); - raise No_value (* no particular interpretation *) + begin match Cst.Map.find c m.funs with + | fi -> + let args = IArray.map aux args |> IArray.to_list in + begin match Val_map.find args fi.FI.cases with + | None -> fi.FI.default + | Some v -> v + end + | exception Not_found -> + Log.debugf 5 (fun k->k "(@[model.eval.undef@ %a@])" Term.pp t); + raise No_value (* no particular interpretation *) + end end in try Some (aux t) diff --git a/src/smt/Model.mli b/src/smt/Model.mli index e0ac5ebc..c6ac4c04 100644 --- a/src/smt/Model.mli +++ b/src/smt/Model.mli @@ -3,14 +3,44 @@ (** {1 Model} *) +module Val_map : sig + type key = Value.t list + type 'a t + val empty : 'a t + val is_empty : _ t -> bool + val cardinal : _ t -> int + val find : key -> 'a t -> 'a option + val add : key -> 'a -> 'a t -> 'a t +end + +module Fun_interpretation : sig + type t = { + cases: Value.t Val_map.t; + default: Value.t; + } + + val default : t -> Value.t + val cases_list : t -> (Value.t list * Value.t) list + + val make : + default:Value.t -> + (Value.t list * Value.t) list -> + t +end + type t = { values: Value.t Term.Map.t; + funs: Fun_interpretation.t Cst.Map.t; } val empty : t val add : Term.t -> Value.t -> t -> t +val add_fun : Cst.t -> Fun_interpretation.t -> t -> t + +val add_funs : Fun_interpretation.t Cst.Map.t -> t -> t + val mem : Term.t -> t -> bool val find : Term.t -> t -> Value.t option diff --git a/src/smt/Solver.ml b/src/smt/Solver.ml index 3a14b70f..3f8ab8ab 100644 --- a/src/smt/Solver.ml +++ b/src/smt/Solver.ml @@ -202,7 +202,9 @@ let[@inline] assume_eq self t u expl : unit = let[@inline] assume_distinct self l ~neq lit : unit = Congruence_closure.assert_distinct (cc self) l lit ~neq -let check_model (s:t) = Sat_solver.check_model s.solver +let check_model (s:t) : unit = + Log.debug 1 "(smt.solver.check-model)"; + Sat_solver.check_model s.solver (* TODO: main loop with iterative deepening of the unrolling limit (not the value depth limit) *) diff --git a/src/smt/Solver_types.ml b/src/smt/Solver_types.ml index b9f4d4ed..6ea45fe3 100644 --- a/src/smt/Solver_types.ml +++ b/src/smt/Solver_types.ml @@ -130,6 +130,7 @@ and value = view: value_custom_view; pp: value_custom_view Fmt.printer; eq: value_custom_view -> value_custom_view -> bool; + hash: value_custom_view -> int; } (** Custom value *) and value_custom_view = .. @@ -192,6 +193,11 @@ let eq_value a b = match a, b with | V_bool _, _ | V_element _, _ | V_custom _, _ -> false +let hash_value a = match a with + | V_bool a -> Hash.bool a + | V_element e -> ID.hash e.id + | V_custom x -> x.hash x.view + let pp_value out = function | V_bool b -> Fmt.bool out b | V_element e -> ID.pp out e.id diff --git a/src/smt/Ty.ml b/src/smt/Ty.ml index 7b944667..f175044e 100644 --- a/src/smt/Ty.ml +++ b/src/smt/Ty.ml @@ -5,7 +5,8 @@ type t = ty type view = Solver_types.ty_view type def = Solver_types.ty_def -let view t = t.ty_view +let[@inline] id t = t.ty_id +let[@inline] view t = t.ty_view let equal = eq_ty let[@inline] compare a b = CCInt.compare a.ty_id b.ty_id diff --git a/src/smt/Ty.mli b/src/smt/Ty.mli index 6ee5413c..7976e1ed 100644 --- a/src/smt/Ty.mli +++ b/src/smt/Ty.mli @@ -7,6 +7,7 @@ type t = Solver_types.ty type view = Solver_types.ty_view type def = Solver_types.ty_def +val id : t -> int val view : t -> view val prop : t diff --git a/src/smtlib/Process.ml b/src/smtlib/Process.ml index c9749348..81616e74 100644 --- a/src/smtlib/Process.ml +++ b/src/smtlib/Process.ml @@ -217,6 +217,47 @@ end let conv_ty = Conv.conv_ty let conv_term = Conv.conv_term +(* check SMT model *) +let check_smt_model (solver:Solver.Sat_solver.t) (hyps:_ Vec.t) (m:Model.t) : unit = + Log.debug 1 "(smt.check-smt-model)"; + let open Solver_types in + let module S = Solver.Sat_solver in + let check_atom (lit:Lit.t) : bool option = + Log.debugf 5 (fun k->k "(@[smt.check-smt-model.atom@ %a@])" Lit.pp lit); + let a = S.Atom.make solver lit in + let is_true = S.Atom.is_true a in + let is_false = S.Atom.is_true (S.Atom.neg a) in + let sat_value = if is_true then Some true else if is_false then Some false else None in + begin match Lit.as_atom lit with + | None -> assert false + | Some (t, sign) -> + match Model.eval m t with + | Some (V_bool b) -> + let b = if sign then b else not b in + if (is_true || is_false) && ((b && is_false) || (not b && is_true)) then ( + Error.errorf "(@[check-model.error@ :atom %a@ :model-val %B@ :sat-val %B@])" + S.Atom.pp a b (if is_true then true else not is_false) + ) + | Some v -> + Error.errorf "(@[check-model.error@ :atom %a@ :non-bool-value %a@])" + S.Atom.pp a Value.pp v + | None -> + if is_true || is_false then ( + Error.errorf "(@[check-model.error@ :atom %a@ :no-smt-value@ :sat-val %B@])" + S.Atom.pp a is_true + ); + end; + sat_value + in + let check_c c = + let bs = List.map check_atom c in + if List.for_all (function Some true -> false | _ -> true) bs then ( + Error.errorf "(@[check-model.error.none-true@ :clause %a@ :vals %a@])" + (Fmt.Dump.list Lit.pp) c Fmt.(Dump.list @@ Dump.option bool) bs + ); + in + Vec.iter check_c hyps + (* call the solver to check-sat *) let solve ?gc:_ @@ -225,11 +266,13 @@ let solve ?(pp_model=false) ?(check=false) ?time:_ ?memory:_ ?progress:_ - ~assumptions s : unit = + ?hyps + ~assumptions + s : unit = let t1 = Sys.time() in let res = Solver.solve ~assumptions s - (* ?gc ?restarts ?time ?memory ?progress s *) + (* ?gc ?restarts ?time ?memory ?progress *) in let t2 = Sys.time () in begin match res with @@ -237,7 +280,10 @@ let solve if pp_model then ( Format.printf "(@[model@ %a@])@." Model.pp m ); - if check then Solver.check_model s; + if check then ( + Solver.check_model s; + CCOpt.iter (fun h -> check_smt_model (Solver.solver s) h m) hyps; + ); let t3 = Sys.time () -. t2 in Format.printf "Sat (%.3f/%.3f/%.3f)@." t1 (t2-.t1) t3; | Solver.Unsat p -> @@ -271,6 +317,7 @@ let mk_iatom = (* process a single statement *) let process_stmt + ?hyps ?gc ?restarts ?(pp_cnf=false) ?dot_proof ?pp_model ?check ?time ?memory ?progress (solver:Solver.t) @@ -301,8 +348,10 @@ let process_stmt Log.debug 1 "exit"; raise Exit | A.CheckSat -> - solve ?gc ?restarts ?dot_proof ?check ?pp_model ?time ?memory ?progress - solver ~assumptions:[]; + solve + ?gc ?restarts ?dot_proof ?check ?pp_model ?time ?memory ?progress + ~assumptions:[] ?hyps + solver; E.return() | A.TyDecl (id,n) -> decl_sort id n; @@ -318,13 +367,13 @@ let process_stmt if pp_cnf then ( Format.printf "(@[assert@ %a@])@." Term.pp t ); - (* TODO - hyps := clauses @ !hyps; - *) - Solver.assume solver (IArray.singleton (Lit.atom t)); + let atom = Lit.atom t in + CCOpt.iter (fun h -> Vec.push h [atom]) hyps; + Solver.assume solver (IArray.singleton atom); E.return() | A.Assert_bool l -> let c = List.rev_map (mk_iatom tst) l in + CCOpt.iter (fun h -> Vec.push h c) hyps; Solver.assume solver (IArray.of_list c); E.return () | A.Goal (_, _) -> diff --git a/src/smtlib/Process.mli b/src/smtlib/Process.mli index 978cc1b2..46c20333 100644 --- a/src/smtlib/Process.mli +++ b/src/smtlib/Process.mli @@ -12,6 +12,7 @@ val conv_ty : Ast.Ty.t -> Ty.t val conv_term : Term.state -> Ast.term -> Term.t val process_stmt : + ?hyps:Lit.t list Vec.t -> ?gc:bool -> ?restarts:bool -> ?pp_cnf:bool -> diff --git a/src/util/Vec.ml b/src/util/Vec.ml index 67106d0d..c7df6b56 100644 --- a/src/util/Vec.ml +++ b/src/util/Vec.ml @@ -146,6 +146,8 @@ let append a b = grow_to_at_least a (size a + size b); iter (push a) b +let append_l v l = List.iter (push v) l + let fold f acc t = let rec _fold f acc t i = if i=t.sz diff --git a/src/util/Vec.mli b/src/util/Vec.mli index 0d9b6cef..c7497020 100644 --- a/src/util/Vec.mli +++ b/src/util/Vec.mli @@ -66,6 +66,8 @@ val push : 'a t -> 'a -> unit val append : 'a t -> 'a t -> unit (** [append v1 v2] pushes all elements of [v2] into [v1] *) +val append_l : 'a t -> 'a list -> unit + val last : 'a t -> 'a (** Last element, or @raise Invalid_argument if the vector is empty *)