mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-06 03:05:31 -05:00
feat(model): proper model construction for CC + fun interpretation
This commit is contained in:
parent
f3c02ebd58
commit
080cde778e
15 changed files with 255 additions and 35 deletions
|
|
@ -123,10 +123,11 @@ let main () =
|
|||
(* process statements *)
|
||||
let res =
|
||||
try
|
||||
let hyps = Vec.make_empty [] in
|
||||
E.fold_l
|
||||
(fun () ->
|
||||
Process.process_stmt
|
||||
~gc:!gc ~restarts:!restarts ~pp_cnf:!p_cnf
|
||||
~hyps ~gc:!gc ~restarts:!restarts ~pp_cnf:!p_cnf
|
||||
~time:!time_limit ~memory:!size_limit
|
||||
?dot_proof ~pp_model:!p_model ~check:!check ~progress:!p_progress
|
||||
solver)
|
||||
|
|
|
|||
|
|
@ -152,7 +152,9 @@ module type S = sig
|
|||
val compare : t -> t -> int
|
||||
val equal : t -> t -> bool
|
||||
val get_formula : t -> formula
|
||||
val is_true : t -> bool
|
||||
|
||||
val dummy : t
|
||||
val make : solver -> formula -> t
|
||||
|
||||
val pp : t printer
|
||||
|
|
|
|||
|
|
@ -587,26 +587,65 @@ let final_check cc : unit =
|
|||
(* model: map each uninterpreted equiv class to some ID *)
|
||||
let mk_model (cc:t) (m:Model.t) : Model.t =
|
||||
(* populate [repr -> value] table *)
|
||||
let tbl = Equiv_class.Tbl.create 32 in
|
||||
let t_tbl = Equiv_class.Tbl.create 32 in
|
||||
(* type -> default value *)
|
||||
let ty_tbl = Ty.Tbl.create 8 in
|
||||
Term.Tbl.values cc.tbl
|
||||
(fun r ->
|
||||
if is_root_ r then (
|
||||
let v = match Model.eval m r.n_term with
|
||||
let t = r.n_term in
|
||||
let v = match Model.eval m t with
|
||||
| Some v -> v
|
||||
| None ->
|
||||
Value.mk_elt
|
||||
(ID.makef "v_%d" @@ Term.id r.n_term)
|
||||
(ID.makef "v_%d" @@ Term.id t)
|
||||
(Term.ty r.n_term)
|
||||
in
|
||||
Equiv_class.Tbl.add tbl r v
|
||||
if not @@ Ty.Tbl.mem ty_tbl (Term.ty t) then (
|
||||
Ty.Tbl.add ty_tbl (Term.ty t) v; (* also give a value to this type *)
|
||||
);
|
||||
Equiv_class.Tbl.add t_tbl r v
|
||||
));
|
||||
(* now map every uninterpreted term to its representative's value *)
|
||||
Term.Tbl.to_seq cc.tbl
|
||||
|> Sequence.fold
|
||||
(fun m (t,r) ->
|
||||
if Model.mem t m then m
|
||||
else (
|
||||
let v = Equiv_class.Tbl.find tbl r in
|
||||
Model.add t v m
|
||||
))
|
||||
m
|
||||
(* now map every uninterpreted term to its representative's value, and
|
||||
create function tables *)
|
||||
let m, funs =
|
||||
Term.Tbl.to_seq cc.tbl
|
||||
|> Sequence.fold
|
||||
(fun (m,funs) (t,r) ->
|
||||
let r = find cc r in (* get representative *)
|
||||
match Term.view t with
|
||||
| _ when Model.mem t m -> m, funs
|
||||
| App_cst (c, args) ->
|
||||
if Model.mem t m then m, funs
|
||||
else if Cst.is_undefined c && IArray.length args > 0 then (
|
||||
(* update signature of [c] *)
|
||||
let ty = Term.ty t in
|
||||
let v = Equiv_class.Tbl.find t_tbl r in
|
||||
let args =
|
||||
args
|
||||
|> IArray.map (fun t -> Equiv_class.Tbl.find t_tbl @@ find_tn cc t)
|
||||
|> IArray.to_list
|
||||
in
|
||||
let ty, l = Cst.Map.get_or c funs ~default:(ty,[]) in
|
||||
m, Cst.Map.add c (ty, (args,v)::l) funs
|
||||
) else (
|
||||
let v = Equiv_class.Tbl.find t_tbl r in
|
||||
Model.add t v m, funs
|
||||
)
|
||||
| _ ->
|
||||
let v = Equiv_class.Tbl.find t_tbl r in
|
||||
Model.add t v m, funs)
|
||||
(m,Cst.Map.empty)
|
||||
in
|
||||
(* get or make a default value for this type *)
|
||||
let get_ty_default (ty:Ty.t) : Value.t =
|
||||
Ty.Tbl.get_or_add ty_tbl ~k:ty
|
||||
~f:(fun ty -> Value.mk_elt (ID.makef "ty_%d" @@ Ty.id ty) ty)
|
||||
in
|
||||
let funs =
|
||||
Cst.Map.map
|
||||
(fun (ty,l) ->
|
||||
Model.Fun_interpretation.make ~default:(get_ty_default ty) l)
|
||||
funs
|
||||
in
|
||||
Model.add_funs funs m
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ let as_undefined (c:t) = match view c with
|
|||
| Cst_undef ty -> Some (c,ty)
|
||||
| Cst_def _ -> None
|
||||
|
||||
let[@inline] is_undefined c = match view c with Cst_undef _ -> true | _ -> false
|
||||
|
||||
let as_undefined_exn (c:t) = match as_undefined c with
|
||||
| Some tup -> tup
|
||||
| None -> assert false
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ val compare : t -> t -> int
|
|||
val hash : t -> int
|
||||
val as_undefined : t -> (t * Ty.Fun.t) option
|
||||
val as_undefined_exn : t -> t * Ty.Fun.t
|
||||
val is_undefined : t -> bool
|
||||
|
||||
val mk_undef : ID.t -> Ty.Fun.t -> t
|
||||
val mk_undef_const : ID.t -> Ty.t -> t
|
||||
|
|
|
|||
|
|
@ -5,11 +5,58 @@
|
|||
|
||||
open Solver_types
|
||||
|
||||
module Val_map = struct
|
||||
module M = CCIntMap
|
||||
module Key = struct
|
||||
type t = Value.t list
|
||||
let equal = CCList.equal Value.equal
|
||||
let hash = Hash.list Value.hash
|
||||
end
|
||||
type key = Key.t
|
||||
|
||||
type 'a t = (key * 'a) list M.t
|
||||
|
||||
let empty = M.empty
|
||||
let is_empty m = M.cardinal m = 0
|
||||
|
||||
let cardinal = M.cardinal
|
||||
|
||||
let find k m =
|
||||
try Some (CCList.assoc ~eq:Key.equal k @@ M.find_exn (Key.hash k) m)
|
||||
with Not_found -> None
|
||||
|
||||
let add k v m =
|
||||
let h = Key.hash k in
|
||||
let l = M.find h m |> CCOpt.get_or ~default:[] in
|
||||
let l = CCList.Assoc.set ~eq:Key.equal k v l in
|
||||
M.add h l m
|
||||
|
||||
let to_seq m yield = M.iter (fun _ l -> List.iter yield l) m
|
||||
end
|
||||
|
||||
module Fun_interpretation = struct
|
||||
type t = {
|
||||
cases: Value.t Val_map.t;
|
||||
default: Value.t;
|
||||
}
|
||||
|
||||
let default fi = fi.default
|
||||
let cases_list fi = Val_map.to_seq fi.cases |> Sequence.to_rev_list
|
||||
|
||||
let make ~default l : t =
|
||||
let m = List.fold_left (fun m (k,v) -> Val_map.add k v m) Val_map.empty l in
|
||||
{ cases=m; default }
|
||||
end
|
||||
|
||||
type t = {
|
||||
values: Value.t Term.Map.t;
|
||||
funs: Fun_interpretation.t Cst.Map.t;
|
||||
}
|
||||
|
||||
let empty : t = {values=Term.Map.empty}
|
||||
let empty : t = {
|
||||
values=Term.Map.empty;
|
||||
funs=Cst.Map.empty;
|
||||
}
|
||||
|
||||
let[@inline] mem t m = Term.Map.mem t m.values
|
||||
let[@inline] find t m = Term.Map.get t m.values
|
||||
|
|
@ -23,7 +70,13 @@ let add t v m : t =
|
|||
);
|
||||
m
|
||||
| exception Not_found ->
|
||||
{values=Term.Map.add t v m.values}
|
||||
{m with values=Term.Map.add t v m.values}
|
||||
|
||||
let add_fun c v m : t =
|
||||
match Cst.Map.find c m.funs with
|
||||
| _ -> Error.errorf "@[Model: function %a already has an interpretation@]" Cst.pp c
|
||||
| exception Not_found ->
|
||||
{m with funs=Cst.Map.add c v m.funs}
|
||||
|
||||
(* merge two models *)
|
||||
let merge m1 m2 : t =
|
||||
|
|
@ -36,17 +89,36 @@ let merge m1 m2 : t =
|
|||
Error.errorf "@[Model: incompatible values for term %a@ :previous %a@ :new %a@]"
|
||||
Term.pp t Value.pp v1 Value.pp v2
|
||||
))
|
||||
and funs =
|
||||
Cst.Map.merge_safe m1.funs m2.funs
|
||||
~f:(fun c o -> match o with
|
||||
| `Left v | `Right v -> Some v
|
||||
| `Both _ ->
|
||||
Error.errorf "cannot merge the two interpretations of function %a" Cst.pp c)
|
||||
in
|
||||
{values}
|
||||
{values; funs}
|
||||
|
||||
let pp out (m:t) =
|
||||
let pp_tv out (t,v) = Fmt.fprintf out "(@[%a@ -> %a@])" Term.pp t Value.pp v in
|
||||
Fmt.fprintf out "(@[model@ %a@])"
|
||||
(Fmt.seq ~sep:Fmt.(return "@ ") pp_tv) (Term.Map.to_seq m.values)
|
||||
let add_funs fs m : t = merge {values=Term.Map.empty; funs=fs} m
|
||||
|
||||
let pp out {values; funs} =
|
||||
let module FI = Fun_interpretation in
|
||||
let pp_tv out (t,v) = Fmt.fprintf out "(@[%a@ %a@])" Term.pp t Value.pp v in
|
||||
let pp_fun_entry out (vals,ret) =
|
||||
Format.fprintf out "(@[%a@ %a@])" (Fmt.Dump.list Value.pp) vals Value.pp ret
|
||||
in
|
||||
let pp_fun out (c, fi: Cst.t * FI.t) =
|
||||
Format.fprintf out "(@[<hov>%a :default %a@ %a@])"
|
||||
Cst.pp c Value.pp fi.FI.default
|
||||
(Fmt.list ~sep:(Fmt.return "@ ") pp_fun_entry) (FI.cases_list fi)
|
||||
in
|
||||
Fmt.fprintf out "(@[model@ @[:terms (@[<hv>%a@])@]@ @[:funs (@[<hv>%a@])@]@])"
|
||||
(Fmt.seq ~sep:Fmt.(return "@ ") pp_tv) (Term.Map.to_seq values)
|
||||
(Fmt.seq ~sep:Fmt.(return "@ ") pp_fun) (Cst.Map.to_seq funs)
|
||||
|
||||
exception No_value
|
||||
|
||||
let eval (m:t) (t:Term.t) : Value.t option =
|
||||
let module FI = Fun_interpretation in
|
||||
let rec aux t = match Term.view t with
|
||||
| Bool b -> Value.bool b
|
||||
| If (a,b,c) ->
|
||||
|
|
@ -64,8 +136,17 @@ let eval (m:t) (t:Term.t) : Value.t option =
|
|||
let args = IArray.map aux args in
|
||||
udef.eval args
|
||||
| Cst_undef _ ->
|
||||
Log.debugf 5 (fun k->k "(@[model.eval.undef@ %a@])" Term.pp t);
|
||||
raise No_value (* no particular interpretation *)
|
||||
begin match Cst.Map.find c m.funs with
|
||||
| fi ->
|
||||
let args = IArray.map aux args |> IArray.to_list in
|
||||
begin match Val_map.find args fi.FI.cases with
|
||||
| None -> fi.FI.default
|
||||
| Some v -> v
|
||||
end
|
||||
| exception Not_found ->
|
||||
Log.debugf 5 (fun k->k "(@[model.eval.undef@ %a@])" Term.pp t);
|
||||
raise No_value (* no particular interpretation *)
|
||||
end
|
||||
end
|
||||
in
|
||||
try Some (aux t)
|
||||
|
|
|
|||
|
|
@ -3,14 +3,44 @@
|
|||
|
||||
(** {1 Model} *)
|
||||
|
||||
module Val_map : sig
|
||||
type key = Value.t list
|
||||
type 'a t
|
||||
val empty : 'a t
|
||||
val is_empty : _ t -> bool
|
||||
val cardinal : _ t -> int
|
||||
val find : key -> 'a t -> 'a option
|
||||
val add : key -> 'a -> 'a t -> 'a t
|
||||
end
|
||||
|
||||
module Fun_interpretation : sig
|
||||
type t = {
|
||||
cases: Value.t Val_map.t;
|
||||
default: Value.t;
|
||||
}
|
||||
|
||||
val default : t -> Value.t
|
||||
val cases_list : t -> (Value.t list * Value.t) list
|
||||
|
||||
val make :
|
||||
default:Value.t ->
|
||||
(Value.t list * Value.t) list ->
|
||||
t
|
||||
end
|
||||
|
||||
type t = {
|
||||
values: Value.t Term.Map.t;
|
||||
funs: Fun_interpretation.t Cst.Map.t;
|
||||
}
|
||||
|
||||
val empty : t
|
||||
|
||||
val add : Term.t -> Value.t -> t -> t
|
||||
|
||||
val add_fun : Cst.t -> Fun_interpretation.t -> t -> t
|
||||
|
||||
val add_funs : Fun_interpretation.t Cst.Map.t -> t -> t
|
||||
|
||||
val mem : Term.t -> t -> bool
|
||||
|
||||
val find : Term.t -> t -> Value.t option
|
||||
|
|
|
|||
|
|
@ -202,7 +202,9 @@ let[@inline] assume_eq self t u expl : unit =
|
|||
let[@inline] assume_distinct self l ~neq lit : unit =
|
||||
Congruence_closure.assert_distinct (cc self) l lit ~neq
|
||||
|
||||
let check_model (s:t) = Sat_solver.check_model s.solver
|
||||
let check_model (s:t) : unit =
|
||||
Log.debug 1 "(smt.solver.check-model)";
|
||||
Sat_solver.check_model s.solver
|
||||
|
||||
(* TODO: main loop with iterative deepening of the unrolling limit
|
||||
(not the value depth limit) *)
|
||||
|
|
|
|||
|
|
@ -130,6 +130,7 @@ and value =
|
|||
view: value_custom_view;
|
||||
pp: value_custom_view Fmt.printer;
|
||||
eq: value_custom_view -> value_custom_view -> bool;
|
||||
hash: value_custom_view -> int;
|
||||
} (** Custom value *)
|
||||
|
||||
and value_custom_view = ..
|
||||
|
|
@ -192,6 +193,11 @@ let eq_value a b = match a, b with
|
|||
| V_bool _, _ | V_element _, _ | V_custom _, _
|
||||
-> false
|
||||
|
||||
let hash_value a = match a with
|
||||
| V_bool a -> Hash.bool a
|
||||
| V_element e -> ID.hash e.id
|
||||
| V_custom x -> x.hash x.view
|
||||
|
||||
let pp_value out = function
|
||||
| V_bool b -> Fmt.bool out b
|
||||
| V_element e -> ID.pp out e.id
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@ type t = ty
|
|||
type view = Solver_types.ty_view
|
||||
type def = Solver_types.ty_def
|
||||
|
||||
let view t = t.ty_view
|
||||
let[@inline] id t = t.ty_id
|
||||
let[@inline] view t = t.ty_view
|
||||
|
||||
let equal = eq_ty
|
||||
let[@inline] compare a b = CCInt.compare a.ty_id b.ty_id
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ type t = Solver_types.ty
|
|||
type view = Solver_types.ty_view
|
||||
type def = Solver_types.ty_def
|
||||
|
||||
val id : t -> int
|
||||
val view : t -> view
|
||||
|
||||
val prop : t
|
||||
|
|
|
|||
|
|
@ -217,6 +217,47 @@ end
|
|||
let conv_ty = Conv.conv_ty
|
||||
let conv_term = Conv.conv_term
|
||||
|
||||
(* check SMT model *)
|
||||
let check_smt_model (solver:Solver.Sat_solver.t) (hyps:_ Vec.t) (m:Model.t) : unit =
|
||||
Log.debug 1 "(smt.check-smt-model)";
|
||||
let open Solver_types in
|
||||
let module S = Solver.Sat_solver in
|
||||
let check_atom (lit:Lit.t) : bool option =
|
||||
Log.debugf 5 (fun k->k "(@[smt.check-smt-model.atom@ %a@])" Lit.pp lit);
|
||||
let a = S.Atom.make solver lit in
|
||||
let is_true = S.Atom.is_true a in
|
||||
let is_false = S.Atom.is_true (S.Atom.neg a) in
|
||||
let sat_value = if is_true then Some true else if is_false then Some false else None in
|
||||
begin match Lit.as_atom lit with
|
||||
| None -> assert false
|
||||
| Some (t, sign) ->
|
||||
match Model.eval m t with
|
||||
| Some (V_bool b) ->
|
||||
let b = if sign then b else not b in
|
||||
if (is_true || is_false) && ((b && is_false) || (not b && is_true)) then (
|
||||
Error.errorf "(@[check-model.error@ :atom %a@ :model-val %B@ :sat-val %B@])"
|
||||
S.Atom.pp a b (if is_true then true else not is_false)
|
||||
)
|
||||
| Some v ->
|
||||
Error.errorf "(@[check-model.error@ :atom %a@ :non-bool-value %a@])"
|
||||
S.Atom.pp a Value.pp v
|
||||
| None ->
|
||||
if is_true || is_false then (
|
||||
Error.errorf "(@[check-model.error@ :atom %a@ :no-smt-value@ :sat-val %B@])"
|
||||
S.Atom.pp a is_true
|
||||
);
|
||||
end;
|
||||
sat_value
|
||||
in
|
||||
let check_c c =
|
||||
let bs = List.map check_atom c in
|
||||
if List.for_all (function Some true -> false | _ -> true) bs then (
|
||||
Error.errorf "(@[check-model.error.none-true@ :clause %a@ :vals %a@])"
|
||||
(Fmt.Dump.list Lit.pp) c Fmt.(Dump.list @@ Dump.option bool) bs
|
||||
);
|
||||
in
|
||||
Vec.iter check_c hyps
|
||||
|
||||
(* call the solver to check-sat *)
|
||||
let solve
|
||||
?gc:_
|
||||
|
|
@ -225,11 +266,13 @@ let solve
|
|||
?(pp_model=false)
|
||||
?(check=false)
|
||||
?time:_ ?memory:_ ?progress:_
|
||||
~assumptions s : unit =
|
||||
?hyps
|
||||
~assumptions
|
||||
s : unit =
|
||||
let t1 = Sys.time() in
|
||||
let res =
|
||||
Solver.solve ~assumptions s
|
||||
(* ?gc ?restarts ?time ?memory ?progress s *)
|
||||
(* ?gc ?restarts ?time ?memory ?progress *)
|
||||
in
|
||||
let t2 = Sys.time () in
|
||||
begin match res with
|
||||
|
|
@ -237,7 +280,10 @@ let solve
|
|||
if pp_model then (
|
||||
Format.printf "(@[<hv1>model@ %a@])@." Model.pp m
|
||||
);
|
||||
if check then Solver.check_model s;
|
||||
if check then (
|
||||
Solver.check_model s;
|
||||
CCOpt.iter (fun h -> check_smt_model (Solver.solver s) h m) hyps;
|
||||
);
|
||||
let t3 = Sys.time () -. t2 in
|
||||
Format.printf "Sat (%.3f/%.3f/%.3f)@." t1 (t2-.t1) t3;
|
||||
| Solver.Unsat p ->
|
||||
|
|
@ -271,6 +317,7 @@ let mk_iatom =
|
|||
|
||||
(* process a single statement *)
|
||||
let process_stmt
|
||||
?hyps
|
||||
?gc ?restarts ?(pp_cnf=false) ?dot_proof ?pp_model ?check
|
||||
?time ?memory ?progress
|
||||
(solver:Solver.t)
|
||||
|
|
@ -301,8 +348,10 @@ let process_stmt
|
|||
Log.debug 1 "exit";
|
||||
raise Exit
|
||||
| A.CheckSat ->
|
||||
solve ?gc ?restarts ?dot_proof ?check ?pp_model ?time ?memory ?progress
|
||||
solver ~assumptions:[];
|
||||
solve
|
||||
?gc ?restarts ?dot_proof ?check ?pp_model ?time ?memory ?progress
|
||||
~assumptions:[] ?hyps
|
||||
solver;
|
||||
E.return()
|
||||
| A.TyDecl (id,n) ->
|
||||
decl_sort id n;
|
||||
|
|
@ -318,13 +367,13 @@ let process_stmt
|
|||
if pp_cnf then (
|
||||
Format.printf "(@[<hv1>assert@ %a@])@." Term.pp t
|
||||
);
|
||||
(* TODO
|
||||
hyps := clauses @ !hyps;
|
||||
*)
|
||||
Solver.assume solver (IArray.singleton (Lit.atom t));
|
||||
let atom = Lit.atom t in
|
||||
CCOpt.iter (fun h -> Vec.push h [atom]) hyps;
|
||||
Solver.assume solver (IArray.singleton atom);
|
||||
E.return()
|
||||
| A.Assert_bool l ->
|
||||
let c = List.rev_map (mk_iatom tst) l in
|
||||
CCOpt.iter (fun h -> Vec.push h c) hyps;
|
||||
Solver.assume solver (IArray.of_list c);
|
||||
E.return ()
|
||||
| A.Goal (_, _) ->
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ val conv_ty : Ast.Ty.t -> Ty.t
|
|||
val conv_term : Term.state -> Ast.term -> Term.t
|
||||
|
||||
val process_stmt :
|
||||
?hyps:Lit.t list Vec.t ->
|
||||
?gc:bool ->
|
||||
?restarts:bool ->
|
||||
?pp_cnf:bool ->
|
||||
|
|
|
|||
|
|
@ -146,6 +146,8 @@ let append a b =
|
|||
grow_to_at_least a (size a + size b);
|
||||
iter (push a) b
|
||||
|
||||
let append_l v l = List.iter (push v) l
|
||||
|
||||
let fold f acc t =
|
||||
let rec _fold f acc t i =
|
||||
if i=t.sz
|
||||
|
|
|
|||
|
|
@ -66,6 +66,8 @@ val push : 'a t -> 'a -> unit
|
|||
val append : 'a t -> 'a t -> unit
|
||||
(** [append v1 v2] pushes all elements of [v2] into [v1] *)
|
||||
|
||||
val append_l : 'a t -> 'a list -> unit
|
||||
|
||||
val last : 'a t -> 'a
|
||||
(** Last element, or
|
||||
@raise Invalid_argument if the vector is empty *)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue