diff --git a/src/arith/lra/Sidekick_arith_lra.ml b/src/arith/lra/Sidekick_arith_lra.ml index 85946fb4..a7971d28 100644 --- a/src/arith/lra/Sidekick_arith_lra.ml +++ b/src/arith/lra/Sidekick_arith_lra.ml @@ -38,6 +38,9 @@ module type ARG = sig val mk_lra : S.T.Term.state -> term lra_view -> term (** Make a term from the given theory view *) + val has_ty_real : term -> bool + (** Does this term have the type [Real] *) + module Gensym : sig type t @@ -81,6 +84,7 @@ module Make(A : ARG) : S with module A = A = struct gensym: A.Gensym.t; neq_encoded: unit T.Tbl.t; (* if [a != b] asserted and not in this table, add clause [a = b \/ ab] *) + needs_th_combination: LE.t T.Tbl.t; (* terms that require theory combination *) mutable t_defs: (T.t * LE.t) list; (* term definitions *) pred_defs: (pred * LE.t * LE.t * T.t * T.t) T.Tbl.t; (* predicate definitions *) } @@ -90,6 +94,7 @@ module Make(A : ARG) : S with module A = A = struct simps=T.Tbl.create 128; gensym=A.Gensym.create tst; neq_encoded=T.Tbl.create 16; + needs_th_combination=T.Tbl.create 8; t_defs=[]; pred_defs=T.Tbl.create 16; } @@ -170,7 +175,7 @@ module Make(A : ARG) : S with module A = A = struct *) (* preprocess linear expressions away *) - let preproc_lra self si ~recurse ~mk_lit:_ ~add_clause:_ (t:T.t) : T.t option = + let preproc_lra (self:state) si ~recurse ~mk_lit:_ ~add_clause:_ (t:T.t) : T.t option = Log.debugf 50 (fun k->k "lra.preprocess %a" T.pp t); let _tst = SI.tst si in match A.view_as_lra t with @@ -184,11 +189,17 @@ module Make(A : ARG) : S with module A = A = struct Some proxy | LRA_op _ | LRA_mult _ -> let le = as_linexp ~f:recurse t in + (* TODO: reuse proxy if present? *) let proxy = fresh_term self ~pre:"_e_lra_" (T.ty t) in self.t_defs <- (proxy, le) :: self.t_defs; + T.Tbl.add self.needs_th_combination t le; Log.debugf 5 (fun k->k"@[lra.preprocess.step %a@ :into %a@ :def %a@]" T.pp t T.pp proxy LE.pp le); Some proxy + | LRA_other t when A.has_ty_real t -> + let le = LE.var t in + T.Tbl.replace self.needs_th_combination t le; + None | LRA_const _ | LRA_other _ -> None (* ensure that [a != b] triggers the clause @@ -269,8 +280,12 @@ module Make(A : ARG) : S with module A = A = struct end; Log.debug 5 "lra: call arith solver"; begin match FM_A.solve fm with - | FM_A.Sat -> + | FM_A.Sat model -> Log.debug 5 "lra: solver returns SAT"; + Log.debugf 50 + (fun k->k "(@[LRA.needs-th-combination:@ %a@])" + (Util.pp_iter @@ Fmt.within "`" "`" T.pp) (T.Tbl.keys self.needs_th_combination)); + Log.debugf 30 (fun k->k "(@[LRA.model@ %a@])" FM_A.pp_model model); () (* TODO: get a model + model combination *) | FM_A.Unsat lits -> (* we tagged assertions with their lit, so the certificate being an diff --git a/src/arith/lra/fourier_motzkin.ml b/src/arith/lra/fourier_motzkin.ml index 24511862..8a01995d 100644 --- a/src/arith/lra/fourier_motzkin.ml +++ b/src/arith/lra/fourier_motzkin.ml @@ -61,8 +61,6 @@ module type S = sig val find : term -> t -> Q.t option val mem : term -> t -> bool -(* val map : (term -> term) -> t -> t *) - module Infix : sig val (+) : t -> t -> t val (-) : t -> t -> t @@ -86,8 +84,13 @@ module type S = sig val assert_c : t -> Constr.t -> unit + type model + + val get_model : model -> term -> Q.t + val pp_model : model Fmt.printer + type res = - | Sat + | Sat of model | Unsat of A.tag list val solve : t -> res @@ -237,8 +240,20 @@ module Make(A : ARG) occ_neg: Constr.t list; } + type pre_model_strict = Strict | NonStrict + type pre_model_constr = + | PM_eq of LE.t + | PM_bounds of { + lower: (pre_model_strict * LE.t) list; + upper: (pre_model_strict * LE.t) list; + } + + type pre_model = pre_model_constr lazy_t T_map.t + type model = Q.t T_map.t lazy_t + type system = { empties: Constr.t list; (* no variables, check first *) + pre_model: pre_model; (* for model construction *) idx: c_for_var T_map.t; (* map [t] to [cft] where [cft] are normalized constraints whose maximum term is [t], with positive sign for [cft.occ_pos] @@ -250,7 +265,7 @@ module Make(A : ARG) mutable sys: system; } - let empty_sys : system = {empties=[]; idx=T_map.empty} + let empty_sys : system = {empties=[]; pre_model=T_map.empty; idx=T_map.empty} let empty_c_for_v : c_for_var = { occ_pos=[]; occ_neg=[]; occ_eq=[] } @@ -294,7 +309,8 @@ module Make(A : ARG) let pp_system out (self:system) : unit = let pp_idxkv out (t,{occ_eq; occ_pos; occ_neg}) = - Fmt.fprintf out "(@[for-var %a@ :occ-eq %a@ :occ-pos %a@ :occ-neg %a@])" + Fmt.fprintf out + "(@[for-var %a@ @[:occ-eq %a@]@ @[:occ-pos %a@]@ @[:occ-neg %a@]@])" T.pp t (Fmt.Dump.list Constr.pp) occ_eq (Fmt.Dump.list Constr.pp) occ_pos @@ -305,8 +321,82 @@ module Make(A : ARG) (Util.pp_iter pp_idxkv) (T_map.to_iter self.idx) (* TODO: be able to provide a model for SAT *) + let build_model_ (self:pre_model) : _ T_map.t = + let l = T_map.to_iter self |> Iter.to_rev_list in + + (* how to evaluate a linexpr in the model *) + let eval_le (mv:Q.t T_map.t) (le:LE.t) : Q.t = + let find x = try T_map.find x mv with Not_found -> Q.zero in + T_map.to_iter le.LE.le + |> Iter.fold + (fun sum (t,coeff) -> Q.(sum + coeff * find t)) + le.LE.const + in + let or_strict s1 s2 = match s1, s2 with + | Strict, _ | _, Strict -> Strict + | NonStrict, NonStrict -> NonStrict + in + let max_pair (s1,q1)(s2,q2) = + if Q.equal q1 q2 then or_strict s1 s2, q1 + else if Q.gt q1 q2 then s1,q1 + else s2,q2 + and min_pair (s1,q1)(s2,q2) = + if Q.equal q1 q2 then or_strict s1 s2, q1 + else if Q.lt q1 q2 then s1,q1 + else s2,q2 + in + let m = + List.fold_left + begin fun m (v,cs_v) -> + (* update [v] using its constraints [cs_v]. + [m] is the model to update *) + let val_v = + match cs_v with + | lazy (PM_eq le) -> eval_le m le + | lazy (PM_bounds {lower; upper}) -> + let lower = List.map (fun (s,le) -> s, eval_le m le) lower in + let upper = List.map (fun (s,le) -> s, eval_le m le) upper in + let strict_low, lower = match lower with + | [] -> NonStrict, Q.minus_inf + | x :: l -> List.fold_left max_pair x l + and strict_up, upper = match upper with + | [] -> NonStrict, Q.inf + | x :: l -> List.fold_left min_pair x l + in + if Q.is_real lower && Q.is_real upper then ( + if Q.equal lower upper then ( + assert (strict_low=NonStrict && strict_up=NonStrict); (* unsat otherwise *) + lower + ) else ( + Q.((lower + upper) / of_int 2) (* middle *) + ) + ) else if Q.is_real lower then ( + if strict_low=Strict then Q.(lower + one) else lower + ) else if Q.is_real upper then ( + if strict_up=Strict then Q.(upper - one) else upper + ) else ( + Q.zero (* no bounds *) + ) + in + T_map.add v val_v m + end + T_map.empty l + in + m + + let get_model (m:model) (v:T.t) : Q.t = + let lazy m = m in + try T_map.find v m + with Not_found -> Q.zero + + let pp_model out (m:model) : unit = + let lazy m = m in + let pp_pair out (v,q) = Fmt.fprintf out "(@[%a@ %a@])" T.pp v Q.pp_print q in + Fmt.fprintf out "(@[model@ %a@])" + (Util.pp_iter pp_pair) (T_map.to_iter m) + type res = - | Sat + | Sat of model | Unsat of A.tag list (* replace [x] with [by] inside [le] *) @@ -324,6 +414,23 @@ module Make(A : ARG) } in Constr.normalize c + (* given an ineq constraint on [v], canonize it wrt [v] + (set the coeff of [v] to 1) + and return whether it's strict or not *) + let premod_of_constr (v:T.t) (c:Constr.t) : pre_model_strict * LE.t = + let strict = + match c.Constr.pred with + | Pred.Leq -> NonStrict | Pred.Lt -> Strict + | _ -> assert false + in + let coeff = + try LE.find_exn v c.Constr.le + with Not_found -> assert false + in + let le = LE.remove v c.Constr.le in + let le = LE.( Q.(one / coeff) * le) in + strict, le + let rec solve_ (self:system) : res = Log.debugf 50 (fun k->k "(@[FM.solve-rec@ :sys %a@])" pp_system self); @@ -334,7 +441,9 @@ module Make(A : ARG) | exception Not_found -> (* need to process biggest variable first *) match T_map.max_binding_opt self.idx with - | None -> Sat + | None -> + let m = lazy (build_model_ self.pre_model) in + Sat m | Some (v, {occ_eq=c0 :: ceq'; occ_pos; occ_neg}) -> (* at least one equality constraint, use it as a substitution *) @@ -362,6 +471,13 @@ module Make(A : ARG) |> Iter.map (subst_constr v ~tag:c0.Constr.tag ~by:rhs) |> Iter.fold add_sys sys in + + let new_sys = + (* update pre-model, keeping only [v := rhs] *) + let pre_model = T_map.add v (Lazy.from_val (PM_eq rhs)) self.pre_model in + {new_sys with pre_model} + in + solve_ new_sys | Some (v, {occ_eq=[]; occ_pos=l_pos; occ_neg=l_neg}) -> @@ -373,10 +489,6 @@ module Make(A : ARG) (* remove [v] *) let sys = {self with idx=T_map.remove v self.idx} in - (* TODO: store all lower bound constraints for [v], so we can use - their max to build the model once we have values for lower - variables *) - let new_sys = Iter.product (Iter.of_list l_pos) (Iter.of_list l_neg) |> Iter.map @@ -402,6 +514,19 @@ module Make(A : ARG) c) |> Iter.fold add_sys sys in + + let new_sys = + let pre_model = + let pm_c = lazy ( + let lower = List.rev_map (premod_of_constr v) l_neg in + let upper = List.rev_map (premod_of_constr v) l_pos in + PM_bounds {lower; upper} + ) in + T_map.add v pm_c self.pre_model + in + {new_sys with pre_model} + in + solve_ new_sys end diff --git a/src/smtlib/Process.ml b/src/smtlib/Process.ml index 3ae64420..305bacde 100644 --- a/src/smtlib/Process.ml +++ b/src/smtlib/Process.ml @@ -313,6 +313,8 @@ module Th_lra = Sidekick_arith_lra.Make(struct | T.Eq (a,b) when Ty.equal (T.ty a) Ty.real -> LRA_pred (Eq, a, b) | _ -> LRA_other t + let has_ty_real t = Ty.equal (T.ty t) Ty.real + module Gensym = struct type t = { tst: T.state;