feat(model): proper model construction for CC + fun interpretation

This commit is contained in:
Simon Cruanes 2018-06-11 20:57:19 -05:00
parent f3c02ebd58
commit 080cde778e
15 changed files with 255 additions and 35 deletions

View file

@ -123,10 +123,11 @@ let main () =
(* process statements *)
let res =
try
let hyps = Vec.make_empty [] in
E.fold_l
(fun () ->
Process.process_stmt
~gc:!gc ~restarts:!restarts ~pp_cnf:!p_cnf
~hyps ~gc:!gc ~restarts:!restarts ~pp_cnf:!p_cnf
~time:!time_limit ~memory:!size_limit
?dot_proof ~pp_model:!p_model ~check:!check ~progress:!p_progress
solver)

View file

@ -152,7 +152,9 @@ module type S = sig
val compare : t -> t -> int
val equal : t -> t -> bool
val get_formula : t -> formula
val is_true : t -> bool
val dummy : t
val make : solver -> formula -> t
val pp : t printer

View file

@ -587,26 +587,65 @@ let final_check cc : unit =
(* model: map each uninterpreted equiv class to some ID *)
let mk_model (cc:t) (m:Model.t) : Model.t =
(* populate [repr -> value] table *)
let tbl = Equiv_class.Tbl.create 32 in
let t_tbl = Equiv_class.Tbl.create 32 in
(* type -> default value *)
let ty_tbl = Ty.Tbl.create 8 in
Term.Tbl.values cc.tbl
(fun r ->
if is_root_ r then (
let v = match Model.eval m r.n_term with
let t = r.n_term in
let v = match Model.eval m t with
| Some v -> v
| None ->
Value.mk_elt
(ID.makef "v_%d" @@ Term.id r.n_term)
(ID.makef "v_%d" @@ Term.id t)
(Term.ty r.n_term)
in
Equiv_class.Tbl.add tbl r v
if not @@ Ty.Tbl.mem ty_tbl (Term.ty t) then (
Ty.Tbl.add ty_tbl (Term.ty t) v; (* also give a value to this type *)
);
Equiv_class.Tbl.add t_tbl r v
));
(* now map every uninterpreted term to its representative's value *)
Term.Tbl.to_seq cc.tbl
|> Sequence.fold
(fun m (t,r) ->
if Model.mem t m then m
else (
let v = Equiv_class.Tbl.find tbl r in
Model.add t v m
))
m
(* now map every uninterpreted term to its representative's value, and
create function tables *)
let m, funs =
Term.Tbl.to_seq cc.tbl
|> Sequence.fold
(fun (m,funs) (t,r) ->
let r = find cc r in (* get representative *)
match Term.view t with
| _ when Model.mem t m -> m, funs
| App_cst (c, args) ->
if Model.mem t m then m, funs
else if Cst.is_undefined c && IArray.length args > 0 then (
(* update signature of [c] *)
let ty = Term.ty t in
let v = Equiv_class.Tbl.find t_tbl r in
let args =
args
|> IArray.map (fun t -> Equiv_class.Tbl.find t_tbl @@ find_tn cc t)
|> IArray.to_list
in
let ty, l = Cst.Map.get_or c funs ~default:(ty,[]) in
m, Cst.Map.add c (ty, (args,v)::l) funs
) else (
let v = Equiv_class.Tbl.find t_tbl r in
Model.add t v m, funs
)
| _ ->
let v = Equiv_class.Tbl.find t_tbl r in
Model.add t v m, funs)
(m,Cst.Map.empty)
in
(* get or make a default value for this type *)
let get_ty_default (ty:Ty.t) : Value.t =
Ty.Tbl.get_or_add ty_tbl ~k:ty
~f:(fun ty -> Value.mk_elt (ID.makef "ty_%d" @@ Ty.id ty) ty)
in
let funs =
Cst.Map.map
(fun (ty,l) ->
Model.Fun_interpretation.make ~default:(get_ty_default ty) l)
funs
in
Model.add_funs funs m

View file

@ -12,6 +12,8 @@ let as_undefined (c:t) = match view c with
| Cst_undef ty -> Some (c,ty)
| Cst_def _ -> None
let[@inline] is_undefined c = match view c with Cst_undef _ -> true | _ -> false
let as_undefined_exn (c:t) = match as_undefined c with
| Some tup -> tup
| None -> assert false

View file

@ -11,6 +11,7 @@ val compare : t -> t -> int
val hash : t -> int
val as_undefined : t -> (t * Ty.Fun.t) option
val as_undefined_exn : t -> t * Ty.Fun.t
val is_undefined : t -> bool
val mk_undef : ID.t -> Ty.Fun.t -> t
val mk_undef_const : ID.t -> Ty.t -> t

View file

@ -5,11 +5,58 @@
open Solver_types
module Val_map = struct
module M = CCIntMap
module Key = struct
type t = Value.t list
let equal = CCList.equal Value.equal
let hash = Hash.list Value.hash
end
type key = Key.t
type 'a t = (key * 'a) list M.t
let empty = M.empty
let is_empty m = M.cardinal m = 0
let cardinal = M.cardinal
let find k m =
try Some (CCList.assoc ~eq:Key.equal k @@ M.find_exn (Key.hash k) m)
with Not_found -> None
let add k v m =
let h = Key.hash k in
let l = M.find h m |> CCOpt.get_or ~default:[] in
let l = CCList.Assoc.set ~eq:Key.equal k v l in
M.add h l m
let to_seq m yield = M.iter (fun _ l -> List.iter yield l) m
end
module Fun_interpretation = struct
type t = {
cases: Value.t Val_map.t;
default: Value.t;
}
let default fi = fi.default
let cases_list fi = Val_map.to_seq fi.cases |> Sequence.to_rev_list
let make ~default l : t =
let m = List.fold_left (fun m (k,v) -> Val_map.add k v m) Val_map.empty l in
{ cases=m; default }
end
type t = {
values: Value.t Term.Map.t;
funs: Fun_interpretation.t Cst.Map.t;
}
let empty : t = {values=Term.Map.empty}
let empty : t = {
values=Term.Map.empty;
funs=Cst.Map.empty;
}
let[@inline] mem t m = Term.Map.mem t m.values
let[@inline] find t m = Term.Map.get t m.values
@ -23,7 +70,13 @@ let add t v m : t =
);
m
| exception Not_found ->
{values=Term.Map.add t v m.values}
{m with values=Term.Map.add t v m.values}
let add_fun c v m : t =
match Cst.Map.find c m.funs with
| _ -> Error.errorf "@[Model: function %a already has an interpretation@]" Cst.pp c
| exception Not_found ->
{m with funs=Cst.Map.add c v m.funs}
(* merge two models *)
let merge m1 m2 : t =
@ -36,17 +89,36 @@ let merge m1 m2 : t =
Error.errorf "@[Model: incompatible values for term %a@ :previous %a@ :new %a@]"
Term.pp t Value.pp v1 Value.pp v2
))
and funs =
Cst.Map.merge_safe m1.funs m2.funs
~f:(fun c o -> match o with
| `Left v | `Right v -> Some v
| `Both _ ->
Error.errorf "cannot merge the two interpretations of function %a" Cst.pp c)
in
{values}
{values; funs}
let pp out (m:t) =
let pp_tv out (t,v) = Fmt.fprintf out "(@[%a@ -> %a@])" Term.pp t Value.pp v in
Fmt.fprintf out "(@[model@ %a@])"
(Fmt.seq ~sep:Fmt.(return "@ ") pp_tv) (Term.Map.to_seq m.values)
let add_funs fs m : t = merge {values=Term.Map.empty; funs=fs} m
let pp out {values; funs} =
let module FI = Fun_interpretation in
let pp_tv out (t,v) = Fmt.fprintf out "(@[%a@ %a@])" Term.pp t Value.pp v in
let pp_fun_entry out (vals,ret) =
Format.fprintf out "(@[%a@ %a@])" (Fmt.Dump.list Value.pp) vals Value.pp ret
in
let pp_fun out (c, fi: Cst.t * FI.t) =
Format.fprintf out "(@[<hov>%a :default %a@ %a@])"
Cst.pp c Value.pp fi.FI.default
(Fmt.list ~sep:(Fmt.return "@ ") pp_fun_entry) (FI.cases_list fi)
in
Fmt.fprintf out "(@[model@ @[:terms (@[<hv>%a@])@]@ @[:funs (@[<hv>%a@])@]@])"
(Fmt.seq ~sep:Fmt.(return "@ ") pp_tv) (Term.Map.to_seq values)
(Fmt.seq ~sep:Fmt.(return "@ ") pp_fun) (Cst.Map.to_seq funs)
exception No_value
let eval (m:t) (t:Term.t) : Value.t option =
let module FI = Fun_interpretation in
let rec aux t = match Term.view t with
| Bool b -> Value.bool b
| If (a,b,c) ->
@ -64,8 +136,17 @@ let eval (m:t) (t:Term.t) : Value.t option =
let args = IArray.map aux args in
udef.eval args
| Cst_undef _ ->
Log.debugf 5 (fun k->k "(@[model.eval.undef@ %a@])" Term.pp t);
raise No_value (* no particular interpretation *)
begin match Cst.Map.find c m.funs with
| fi ->
let args = IArray.map aux args |> IArray.to_list in
begin match Val_map.find args fi.FI.cases with
| None -> fi.FI.default
| Some v -> v
end
| exception Not_found ->
Log.debugf 5 (fun k->k "(@[model.eval.undef@ %a@])" Term.pp t);
raise No_value (* no particular interpretation *)
end
end
in
try Some (aux t)

View file

@ -3,14 +3,44 @@
(** {1 Model} *)
module Val_map : sig
type key = Value.t list
type 'a t
val empty : 'a t
val is_empty : _ t -> bool
val cardinal : _ t -> int
val find : key -> 'a t -> 'a option
val add : key -> 'a -> 'a t -> 'a t
end
module Fun_interpretation : sig
type t = {
cases: Value.t Val_map.t;
default: Value.t;
}
val default : t -> Value.t
val cases_list : t -> (Value.t list * Value.t) list
val make :
default:Value.t ->
(Value.t list * Value.t) list ->
t
end
type t = {
values: Value.t Term.Map.t;
funs: Fun_interpretation.t Cst.Map.t;
}
val empty : t
val add : Term.t -> Value.t -> t -> t
val add_fun : Cst.t -> Fun_interpretation.t -> t -> t
val add_funs : Fun_interpretation.t Cst.Map.t -> t -> t
val mem : Term.t -> t -> bool
val find : Term.t -> t -> Value.t option

View file

@ -202,7 +202,9 @@ let[@inline] assume_eq self t u expl : unit =
let[@inline] assume_distinct self l ~neq lit : unit =
Congruence_closure.assert_distinct (cc self) l lit ~neq
let check_model (s:t) = Sat_solver.check_model s.solver
let check_model (s:t) : unit =
Log.debug 1 "(smt.solver.check-model)";
Sat_solver.check_model s.solver
(* TODO: main loop with iterative deepening of the unrolling limit
(not the value depth limit) *)

View file

@ -130,6 +130,7 @@ and value =
view: value_custom_view;
pp: value_custom_view Fmt.printer;
eq: value_custom_view -> value_custom_view -> bool;
hash: value_custom_view -> int;
} (** Custom value *)
and value_custom_view = ..
@ -192,6 +193,11 @@ let eq_value a b = match a, b with
| V_bool _, _ | V_element _, _ | V_custom _, _
-> false
let hash_value a = match a with
| V_bool a -> Hash.bool a
| V_element e -> ID.hash e.id
| V_custom x -> x.hash x.view
let pp_value out = function
| V_bool b -> Fmt.bool out b
| V_element e -> ID.pp out e.id

View file

@ -5,7 +5,8 @@ type t = ty
type view = Solver_types.ty_view
type def = Solver_types.ty_def
let view t = t.ty_view
let[@inline] id t = t.ty_id
let[@inline] view t = t.ty_view
let equal = eq_ty
let[@inline] compare a b = CCInt.compare a.ty_id b.ty_id

View file

@ -7,6 +7,7 @@ type t = Solver_types.ty
type view = Solver_types.ty_view
type def = Solver_types.ty_def
val id : t -> int
val view : t -> view
val prop : t

View file

@ -217,6 +217,47 @@ end
let conv_ty = Conv.conv_ty
let conv_term = Conv.conv_term
(* check SMT model *)
let check_smt_model (solver:Solver.Sat_solver.t) (hyps:_ Vec.t) (m:Model.t) : unit =
Log.debug 1 "(smt.check-smt-model)";
let open Solver_types in
let module S = Solver.Sat_solver in
let check_atom (lit:Lit.t) : bool option =
Log.debugf 5 (fun k->k "(@[smt.check-smt-model.atom@ %a@])" Lit.pp lit);
let a = S.Atom.make solver lit in
let is_true = S.Atom.is_true a in
let is_false = S.Atom.is_true (S.Atom.neg a) in
let sat_value = if is_true then Some true else if is_false then Some false else None in
begin match Lit.as_atom lit with
| None -> assert false
| Some (t, sign) ->
match Model.eval m t with
| Some (V_bool b) ->
let b = if sign then b else not b in
if (is_true || is_false) && ((b && is_false) || (not b && is_true)) then (
Error.errorf "(@[check-model.error@ :atom %a@ :model-val %B@ :sat-val %B@])"
S.Atom.pp a b (if is_true then true else not is_false)
)
| Some v ->
Error.errorf "(@[check-model.error@ :atom %a@ :non-bool-value %a@])"
S.Atom.pp a Value.pp v
| None ->
if is_true || is_false then (
Error.errorf "(@[check-model.error@ :atom %a@ :no-smt-value@ :sat-val %B@])"
S.Atom.pp a is_true
);
end;
sat_value
in
let check_c c =
let bs = List.map check_atom c in
if List.for_all (function Some true -> false | _ -> true) bs then (
Error.errorf "(@[check-model.error.none-true@ :clause %a@ :vals %a@])"
(Fmt.Dump.list Lit.pp) c Fmt.(Dump.list @@ Dump.option bool) bs
);
in
Vec.iter check_c hyps
(* call the solver to check-sat *)
let solve
?gc:_
@ -225,11 +266,13 @@ let solve
?(pp_model=false)
?(check=false)
?time:_ ?memory:_ ?progress:_
~assumptions s : unit =
?hyps
~assumptions
s : unit =
let t1 = Sys.time() in
let res =
Solver.solve ~assumptions s
(* ?gc ?restarts ?time ?memory ?progress s *)
(* ?gc ?restarts ?time ?memory ?progress *)
in
let t2 = Sys.time () in
begin match res with
@ -237,7 +280,10 @@ let solve
if pp_model then (
Format.printf "(@[<hv1>model@ %a@])@." Model.pp m
);
if check then Solver.check_model s;
if check then (
Solver.check_model s;
CCOpt.iter (fun h -> check_smt_model (Solver.solver s) h m) hyps;
);
let t3 = Sys.time () -. t2 in
Format.printf "Sat (%.3f/%.3f/%.3f)@." t1 (t2-.t1) t3;
| Solver.Unsat p ->
@ -271,6 +317,7 @@ let mk_iatom =
(* process a single statement *)
let process_stmt
?hyps
?gc ?restarts ?(pp_cnf=false) ?dot_proof ?pp_model ?check
?time ?memory ?progress
(solver:Solver.t)
@ -301,8 +348,10 @@ let process_stmt
Log.debug 1 "exit";
raise Exit
| A.CheckSat ->
solve ?gc ?restarts ?dot_proof ?check ?pp_model ?time ?memory ?progress
solver ~assumptions:[];
solve
?gc ?restarts ?dot_proof ?check ?pp_model ?time ?memory ?progress
~assumptions:[] ?hyps
solver;
E.return()
| A.TyDecl (id,n) ->
decl_sort id n;
@ -318,13 +367,13 @@ let process_stmt
if pp_cnf then (
Format.printf "(@[<hv1>assert@ %a@])@." Term.pp t
);
(* TODO
hyps := clauses @ !hyps;
*)
Solver.assume solver (IArray.singleton (Lit.atom t));
let atom = Lit.atom t in
CCOpt.iter (fun h -> Vec.push h [atom]) hyps;
Solver.assume solver (IArray.singleton atom);
E.return()
| A.Assert_bool l ->
let c = List.rev_map (mk_iatom tst) l in
CCOpt.iter (fun h -> Vec.push h c) hyps;
Solver.assume solver (IArray.of_list c);
E.return ()
| A.Goal (_, _) ->

View file

@ -12,6 +12,7 @@ val conv_ty : Ast.Ty.t -> Ty.t
val conv_term : Term.state -> Ast.term -> Term.t
val process_stmt :
?hyps:Lit.t list Vec.t ->
?gc:bool ->
?restarts:bool ->
?pp_cnf:bool ->

View file

@ -146,6 +146,8 @@ let append a b =
grow_to_at_least a (size a + size b);
iter (push a) b
let append_l v l = List.iter (push v) l
let fold f acc t =
let rec _fold f acc t i =
if i=t.sz

View file

@ -66,6 +66,8 @@ val push : 'a t -> 'a -> unit
val append : 'a t -> 'a t -> unit
(** [append v1 v2] pushes all elements of [v2] into [v1] *)
val append_l : 'a t -> 'a list -> unit
val last : 'a t -> 'a
(** Last element, or
@raise Invalid_argument if the vector is empty *)