mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-06 11:15:43 -05:00
feat(LRA): expose model after fourier-motzkin returns "SAT"
This commit is contained in:
parent
1b7d084a9c
commit
db1c50f7ed
3 changed files with 155 additions and 13 deletions
|
|
@ -38,6 +38,9 @@ module type ARG = sig
|
|||
val mk_lra : S.T.Term.state -> term lra_view -> term
|
||||
(** Make a term from the given theory view *)
|
||||
|
||||
val has_ty_real : term -> bool
|
||||
(** Does this term have the type [Real] *)
|
||||
|
||||
module Gensym : sig
|
||||
type t
|
||||
|
||||
|
|
@ -81,6 +84,7 @@ module Make(A : ARG) : S with module A = A = struct
|
|||
gensym: A.Gensym.t;
|
||||
neq_encoded: unit T.Tbl.t;
|
||||
(* if [a != b] asserted and not in this table, add clause [a = b \/ a<b \/ a>b] *)
|
||||
needs_th_combination: LE.t T.Tbl.t; (* terms that require theory combination *)
|
||||
mutable t_defs: (T.t * LE.t) list; (* term definitions *)
|
||||
pred_defs: (pred * LE.t * LE.t * T.t * T.t) T.Tbl.t; (* predicate definitions *)
|
||||
}
|
||||
|
|
@ -90,6 +94,7 @@ module Make(A : ARG) : S with module A = A = struct
|
|||
simps=T.Tbl.create 128;
|
||||
gensym=A.Gensym.create tst;
|
||||
neq_encoded=T.Tbl.create 16;
|
||||
needs_th_combination=T.Tbl.create 8;
|
||||
t_defs=[];
|
||||
pred_defs=T.Tbl.create 16;
|
||||
}
|
||||
|
|
@ -170,7 +175,7 @@ module Make(A : ARG) : S with module A = A = struct
|
|||
*)
|
||||
|
||||
(* preprocess linear expressions away *)
|
||||
let preproc_lra self si ~recurse ~mk_lit:_ ~add_clause:_ (t:T.t) : T.t option =
|
||||
let preproc_lra (self:state) si ~recurse ~mk_lit:_ ~add_clause:_ (t:T.t) : T.t option =
|
||||
Log.debugf 50 (fun k->k "lra.preprocess %a" T.pp t);
|
||||
let _tst = SI.tst si in
|
||||
match A.view_as_lra t with
|
||||
|
|
@ -184,11 +189,17 @@ module Make(A : ARG) : S with module A = A = struct
|
|||
Some proxy
|
||||
| LRA_op _ | LRA_mult _ ->
|
||||
let le = as_linexp ~f:recurse t in
|
||||
(* TODO: reuse proxy if present? *)
|
||||
let proxy = fresh_term self ~pre:"_e_lra_" (T.ty t) in
|
||||
self.t_defs <- (proxy, le) :: self.t_defs;
|
||||
T.Tbl.add self.needs_th_combination t le;
|
||||
Log.debugf 5 (fun k->k"@[<hv2>lra.preprocess.step %a@ :into %a@ :def %a@]"
|
||||
T.pp t T.pp proxy LE.pp le);
|
||||
Some proxy
|
||||
| LRA_other t when A.has_ty_real t ->
|
||||
let le = LE.var t in
|
||||
T.Tbl.replace self.needs_th_combination t le;
|
||||
None
|
||||
| LRA_const _ | LRA_other _ -> None
|
||||
|
||||
(* ensure that [a != b] triggers the clause
|
||||
|
|
@ -269,8 +280,12 @@ module Make(A : ARG) : S with module A = A = struct
|
|||
end;
|
||||
Log.debug 5 "lra: call arith solver";
|
||||
begin match FM_A.solve fm with
|
||||
| FM_A.Sat ->
|
||||
| FM_A.Sat model ->
|
||||
Log.debug 5 "lra: solver returns SAT";
|
||||
Log.debugf 50
|
||||
(fun k->k "(@[LRA.needs-th-combination:@ %a@])"
|
||||
(Util.pp_iter @@ Fmt.within "`" "`" T.pp) (T.Tbl.keys self.needs_th_combination));
|
||||
Log.debugf 30 (fun k->k "(@[LRA.model@ %a@])" FM_A.pp_model model);
|
||||
() (* TODO: get a model + model combination *)
|
||||
| FM_A.Unsat lits ->
|
||||
(* we tagged assertions with their lit, so the certificate being an
|
||||
|
|
|
|||
|
|
@ -61,8 +61,6 @@ module type S = sig
|
|||
val find : term -> t -> Q.t option
|
||||
val mem : term -> t -> bool
|
||||
|
||||
(* val map : (term -> term) -> t -> t *)
|
||||
|
||||
module Infix : sig
|
||||
val (+) : t -> t -> t
|
||||
val (-) : t -> t -> t
|
||||
|
|
@ -86,8 +84,13 @@ module type S = sig
|
|||
|
||||
val assert_c : t -> Constr.t -> unit
|
||||
|
||||
type model
|
||||
|
||||
val get_model : model -> term -> Q.t
|
||||
val pp_model : model Fmt.printer
|
||||
|
||||
type res =
|
||||
| Sat
|
||||
| Sat of model
|
||||
| Unsat of A.tag list
|
||||
|
||||
val solve : t -> res
|
||||
|
|
@ -237,8 +240,20 @@ module Make(A : ARG)
|
|||
occ_neg: Constr.t list;
|
||||
}
|
||||
|
||||
type pre_model_strict = Strict | NonStrict
|
||||
type pre_model_constr =
|
||||
| PM_eq of LE.t
|
||||
| PM_bounds of {
|
||||
lower: (pre_model_strict * LE.t) list;
|
||||
upper: (pre_model_strict * LE.t) list;
|
||||
}
|
||||
|
||||
type pre_model = pre_model_constr lazy_t T_map.t
|
||||
type model = Q.t T_map.t lazy_t
|
||||
|
||||
type system = {
|
||||
empties: Constr.t list; (* no variables, check first *)
|
||||
pre_model: pre_model; (* for model construction *)
|
||||
idx: c_for_var T_map.t;
|
||||
(* map [t] to [cft] where [cft] are normalized constraints whose
|
||||
maximum term is [t], with positive sign for [cft.occ_pos]
|
||||
|
|
@ -250,7 +265,7 @@ module Make(A : ARG)
|
|||
mutable sys: system;
|
||||
}
|
||||
|
||||
let empty_sys : system = {empties=[]; idx=T_map.empty}
|
||||
let empty_sys : system = {empties=[]; pre_model=T_map.empty; idx=T_map.empty}
|
||||
let empty_c_for_v : c_for_var =
|
||||
{ occ_pos=[]; occ_neg=[]; occ_eq=[] }
|
||||
|
||||
|
|
@ -294,7 +309,8 @@ module Make(A : ARG)
|
|||
|
||||
let pp_system out (self:system) : unit =
|
||||
let pp_idxkv out (t,{occ_eq; occ_pos; occ_neg}) =
|
||||
Fmt.fprintf out "(@[for-var %a@ :occ-eq %a@ :occ-pos %a@ :occ-neg %a@])"
|
||||
Fmt.fprintf out
|
||||
"(@[for-var %a@ @[:occ-eq %a@]@ @[:occ-pos %a@]@ @[:occ-neg %a@]@])"
|
||||
T.pp t
|
||||
(Fmt.Dump.list Constr.pp) occ_eq
|
||||
(Fmt.Dump.list Constr.pp) occ_pos
|
||||
|
|
@ -305,8 +321,82 @@ module Make(A : ARG)
|
|||
(Util.pp_iter pp_idxkv) (T_map.to_iter self.idx)
|
||||
|
||||
(* TODO: be able to provide a model for SAT *)
|
||||
let build_model_ (self:pre_model) : _ T_map.t =
|
||||
let l = T_map.to_iter self |> Iter.to_rev_list in
|
||||
|
||||
(* how to evaluate a linexpr in the model *)
|
||||
let eval_le (mv:Q.t T_map.t) (le:LE.t) : Q.t =
|
||||
let find x = try T_map.find x mv with Not_found -> Q.zero in
|
||||
T_map.to_iter le.LE.le
|
||||
|> Iter.fold
|
||||
(fun sum (t,coeff) -> Q.(sum + coeff * find t))
|
||||
le.LE.const
|
||||
in
|
||||
let or_strict s1 s2 = match s1, s2 with
|
||||
| Strict, _ | _, Strict -> Strict
|
||||
| NonStrict, NonStrict -> NonStrict
|
||||
in
|
||||
let max_pair (s1,q1)(s2,q2) =
|
||||
if Q.equal q1 q2 then or_strict s1 s2, q1
|
||||
else if Q.gt q1 q2 then s1,q1
|
||||
else s2,q2
|
||||
and min_pair (s1,q1)(s2,q2) =
|
||||
if Q.equal q1 q2 then or_strict s1 s2, q1
|
||||
else if Q.lt q1 q2 then s1,q1
|
||||
else s2,q2
|
||||
in
|
||||
let m =
|
||||
List.fold_left
|
||||
begin fun m (v,cs_v) ->
|
||||
(* update [v] using its constraints [cs_v].
|
||||
[m] is the model to update *)
|
||||
let val_v =
|
||||
match cs_v with
|
||||
| lazy (PM_eq le) -> eval_le m le
|
||||
| lazy (PM_bounds {lower; upper}) ->
|
||||
let lower = List.map (fun (s,le) -> s, eval_le m le) lower in
|
||||
let upper = List.map (fun (s,le) -> s, eval_le m le) upper in
|
||||
let strict_low, lower = match lower with
|
||||
| [] -> NonStrict, Q.minus_inf
|
||||
| x :: l -> List.fold_left max_pair x l
|
||||
and strict_up, upper = match upper with
|
||||
| [] -> NonStrict, Q.inf
|
||||
| x :: l -> List.fold_left min_pair x l
|
||||
in
|
||||
if Q.is_real lower && Q.is_real upper then (
|
||||
if Q.equal lower upper then (
|
||||
assert (strict_low=NonStrict && strict_up=NonStrict); (* unsat otherwise *)
|
||||
lower
|
||||
) else (
|
||||
Q.((lower + upper) / of_int 2) (* middle *)
|
||||
)
|
||||
) else if Q.is_real lower then (
|
||||
if strict_low=Strict then Q.(lower + one) else lower
|
||||
) else if Q.is_real upper then (
|
||||
if strict_up=Strict then Q.(upper - one) else upper
|
||||
) else (
|
||||
Q.zero (* no bounds *)
|
||||
)
|
||||
in
|
||||
T_map.add v val_v m
|
||||
end
|
||||
T_map.empty l
|
||||
in
|
||||
m
|
||||
|
||||
let get_model (m:model) (v:T.t) : Q.t =
|
||||
let lazy m = m in
|
||||
try T_map.find v m
|
||||
with Not_found -> Q.zero
|
||||
|
||||
let pp_model out (m:model) : unit =
|
||||
let lazy m = m in
|
||||
let pp_pair out (v,q) = Fmt.fprintf out "(@[%a@ %a@])" T.pp v Q.pp_print q in
|
||||
Fmt.fprintf out "(@[<hv1>model@ %a@])"
|
||||
(Util.pp_iter pp_pair) (T_map.to_iter m)
|
||||
|
||||
type res =
|
||||
| Sat
|
||||
| Sat of model
|
||||
| Unsat of A.tag list
|
||||
|
||||
(* replace [x] with [by] inside [le] *)
|
||||
|
|
@ -324,6 +414,23 @@ module Make(A : ARG)
|
|||
} in
|
||||
Constr.normalize c
|
||||
|
||||
(* given an ineq constraint on [v], canonize it wrt [v]
|
||||
(set the coeff of [v] to 1)
|
||||
and return whether it's strict or not *)
|
||||
let premod_of_constr (v:T.t) (c:Constr.t) : pre_model_strict * LE.t =
|
||||
let strict =
|
||||
match c.Constr.pred with
|
||||
| Pred.Leq -> NonStrict | Pred.Lt -> Strict
|
||||
| _ -> assert false
|
||||
in
|
||||
let coeff =
|
||||
try LE.find_exn v c.Constr.le
|
||||
with Not_found -> assert false
|
||||
in
|
||||
let le = LE.remove v c.Constr.le in
|
||||
let le = LE.( Q.(one / coeff) * le) in
|
||||
strict, le
|
||||
|
||||
let rec solve_ (self:system) : res =
|
||||
Log.debugf 50
|
||||
(fun k->k "(@[FM.solve-rec@ :sys %a@])" pp_system self);
|
||||
|
|
@ -334,7 +441,9 @@ module Make(A : ARG)
|
|||
| exception Not_found ->
|
||||
(* need to process biggest variable first *)
|
||||
match T_map.max_binding_opt self.idx with
|
||||
| None -> Sat
|
||||
| None ->
|
||||
let m = lazy (build_model_ self.pre_model) in
|
||||
Sat m
|
||||
| Some (v, {occ_eq=c0 :: ceq'; occ_pos; occ_neg}) ->
|
||||
(* at least one equality constraint, use it as a substitution *)
|
||||
|
||||
|
|
@ -362,6 +471,13 @@ module Make(A : ARG)
|
|||
|> Iter.map (subst_constr v ~tag:c0.Constr.tag ~by:rhs)
|
||||
|> Iter.fold add_sys sys
|
||||
in
|
||||
|
||||
let new_sys =
|
||||
(* update pre-model, keeping only [v := rhs] *)
|
||||
let pre_model = T_map.add v (Lazy.from_val (PM_eq rhs)) self.pre_model in
|
||||
{new_sys with pre_model}
|
||||
in
|
||||
|
||||
solve_ new_sys
|
||||
|
||||
| Some (v, {occ_eq=[]; occ_pos=l_pos; occ_neg=l_neg}) ->
|
||||
|
|
@ -373,10 +489,6 @@ module Make(A : ARG)
|
|||
(* remove [v] *)
|
||||
let sys = {self with idx=T_map.remove v self.idx} in
|
||||
|
||||
(* TODO: store all lower bound constraints for [v], so we can use
|
||||
their max to build the model once we have values for lower
|
||||
variables *)
|
||||
|
||||
let new_sys =
|
||||
Iter.product (Iter.of_list l_pos) (Iter.of_list l_neg)
|
||||
|> Iter.map
|
||||
|
|
@ -402,6 +514,19 @@ module Make(A : ARG)
|
|||
c)
|
||||
|> Iter.fold add_sys sys
|
||||
in
|
||||
|
||||
let new_sys =
|
||||
let pre_model =
|
||||
let pm_c = lazy (
|
||||
let lower = List.rev_map (premod_of_constr v) l_neg in
|
||||
let upper = List.rev_map (premod_of_constr v) l_pos in
|
||||
PM_bounds {lower; upper}
|
||||
) in
|
||||
T_map.add v pm_c self.pre_model
|
||||
in
|
||||
{new_sys with pre_model}
|
||||
in
|
||||
|
||||
solve_ new_sys
|
||||
end
|
||||
|
||||
|
|
|
|||
|
|
@ -313,6 +313,8 @@ module Th_lra = Sidekick_arith_lra.Make(struct
|
|||
| T.Eq (a,b) when Ty.equal (T.ty a) Ty.real -> LRA_pred (Eq, a, b)
|
||||
| _ -> LRA_other t
|
||||
|
||||
let has_ty_real t = Ty.equal (T.ty t) Ty.real
|
||||
|
||||
module Gensym = struct
|
||||
type t = {
|
||||
tst: T.state;
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue