From cfbd352ca050043e0d1de084482a863941818553 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Tue, 16 Feb 2021 13:25:21 -0500 Subject: [PATCH] feat(lra): restore theory combination; improve preprocessing --- src/arith/lra/linear_expr.ml | 2 + src/arith/lra/linear_expr_intf.ml | 3 +- src/arith/lra/sidekick_arith_lra.ml | 307 +++++++++++++--------------- src/arith/lra/simplex2.ml | 4 +- src/smtlib/Process.ml | 4 +- 5 files changed, 153 insertions(+), 167 deletions(-) diff --git a/src/arith/lra/linear_expr.ml b/src/arith/lra/linear_expr.ml index 8691f6cc..ddfda71d 100644 --- a/src/arith/lra/linear_expr.ml +++ b/src/arith/lra/linear_expr.ml @@ -71,6 +71,8 @@ module Make(C : COEFF)(Var : VAR) = struct include Infix + let iter = Var_map.iter + let of_list l = List.fold_left (fun e (c,x) -> add c x e) empty l let to_list e = Var_map.bindings e |> List.rev_map CCPair.swap diff --git a/src/arith/lra/linear_expr_intf.ml b/src/arith/lra/linear_expr_intf.ml index 672818c9..cf0323a9 100644 --- a/src/arith/lra/linear_expr_intf.ml +++ b/src/arith/lra/linear_expr_intf.ml @@ -118,7 +118,6 @@ module type S = sig val add : C.t -> var -> t -> t (** [add n v t] adds the monome [n * v] to the combination [t]. *) - (** Infix operations on combinations This module defines usual operations on linear combinations, @@ -136,6 +135,8 @@ module type S = sig include module type of Infix (** Include the previous module. *) + val iter : (var -> C.t -> unit) -> t -> unit + val of_list : (C.t * var) list -> t val to_list : t -> (C.t * var) list diff --git a/src/arith/lra/sidekick_arith_lra.ml b/src/arith/lra/sidekick_arith_lra.ml index d019c687..715fcc38 100644 --- a/src/arith/lra/sidekick_arith_lra.ml +++ b/src/arith/lra/sidekick_arith_lra.ml @@ -49,8 +49,8 @@ module type ARG = sig val ty_lra : S.T.Term.state -> ty - val mk_and : S.T.Term.state -> term -> term -> term - val mk_or : S.T.Term.state -> term -> term -> term + val mk_eq : S.T.Term.state -> term -> term -> term + (** syntactic equality *) val has_ty_real : term -> bool (** Does this term have the type [Real] *) @@ -89,17 +89,14 @@ module Make(A : ARG) : S with module A = A = struct module Tag = struct type t = - | By_def | Lit of Lit.t | CC_eq of N.t * N.t let pp out = function - | By_def -> Fmt.string out "" | Lit l -> Fmt.fprintf out "(@[lit %a@])" Lit.pp l | CC_eq (n1,n2) -> Fmt.fprintf out "(@[cc-eq@ %a@ %a@])" N.pp n1 N.pp n2 let to_lits si = function - | By_def -> [] | Lit l -> [l] | CC_eq (n1,n2) -> SI.CC.explain_eq (SI.cc si) n1 n2 @@ -121,29 +118,28 @@ module Make(A : ARG) : S with module A = A = struct module LE = LE_.Expr module SimpSolver = Simplex2.Make(SimpVar) module LConstr = SimpSolver.Constraint + module Subst = SimpSolver.Subst + + module Comb_map = CCMap.Make(LE_.Comb) type state = { tst: T.state; simps: T.t T.Tbl.t; (* cache *) gensym: A.Gensym.t; - neq_encoded: unit T.Tbl.t; - (* if [a != b] asserted and not in this table, add clause [a = b \/ ab] *) - needs_th_combination: LE_.Comb.t T.Tbl.t; (* terms that require theory combination *) - t_defs: LE.t T.Tbl.t; (* term definitions *) - pred_defs: (pred * LE.t * LE.t * T.t * T.t) T.Tbl.t; (* predicate definitions *) + encoded_eqs: unit T.Tbl.t; (* [a=b] gets clause [a = b <=> (a >= b /\ a <= b)] *) + needs_th_combination: unit T.Tbl.t; (* terms that require theory combination *) + mutable encoded_le: T.t Comb_map.t; (* [le] -> var encoding [le] *) local_eqs: (N.t * N.t) Backtrack_stack.t; (* inferred by the congruence closure *) simplex: SimpSolver.t; } - (* TODO *) let create tst : state = { tst; simps=T.Tbl.create 128; gensym=A.Gensym.create tst; - neq_encoded=T.Tbl.create 16; + encoded_eqs=T.Tbl.create 8; needs_th_combination=T.Tbl.create 8; - t_defs=T.Tbl.create 8; - pred_defs=T.Tbl.create 16; + encoded_le=Comb_map.empty; local_eqs = Backtrack_stack.create(); simplex=SimpSolver.create (); } @@ -158,50 +154,6 @@ module Make(A : ARG) : S with module A = A = struct Backtrack_stack.pop_levels self.local_eqs n ~f:(fun _ -> ()); () - (* FIXME - let simplify (self:state) (simp:SI.Simplify.t) (t:T.t) : T.t option = - let tst = self.tst in - match A.view_as_bool t with - | B_bool _ -> None - | B_not u when is_true u -> Some (T.bool tst false) - | B_not u when is_false u -> Some (T.bool tst true) - | B_not _ -> None - | B_opaque_bool _ -> None - | B_and a -> - if IArray.exists is_false a then Some (T.bool tst false) - else if IArray.for_all is_true a then Some (T.bool tst true) - else None - | B_or a -> - if IArray.exists is_true a then Some (T.bool tst true) - else if IArray.for_all is_false a then Some (T.bool tst false) - else None - | B_imply (args, u) -> - (* turn into a disjunction *) - let u = - or_a tst @@ - IArray.append (IArray.map (not_ tst) args) (IArray.singleton u) - in - Some u - | B_ite (a,b,c) -> - (* directly simplify [a] so that maybe we never will simplify one - of the branches *) - let a = SI.Simplify.normalize simp a in - begin match A.view_as_bool a with - | B_bool true -> Some b - | B_bool false -> Some c - | _ -> - None - end - | B_equiv (a,b) when is_true a -> Some b - | B_equiv (a,b) when is_false a -> Some (not_ tst b) - | B_equiv (a,b) when is_true b -> Some a - | B_equiv (a,b) when is_false b -> Some (not_ tst a) - | B_equiv _ -> None - | B_eq (a,b) when T.equal a b -> Some (T.bool tst true) - | B_eq _ -> None - | B_atom _ -> None - *) - let fresh_term self ~pre ty = A.Gensym.fresh_term self.gensym ~pre ty let fresh_lit (self:state) ~mk_lit ~pre : Lit.t = let t = fresh_term ~pre self Ty.bool in @@ -232,26 +184,49 @@ module Make(A : ARG) : S with module A = A = struct let as_linexp_id = as_linexp ~f:CCFun.id - (* TODO: keep the linexps until they're asserted; - TODO: but use simplification in preprocess - *) + (* return a variable that is equal to [le_comb] in the simplex. *) + let var_encoding_comb ~pre self (le_comb:LE_.Comb.t) : T.t = + match LE_.Comb.as_singleton le_comb with + | Some (c, x) when Q.(c = one) -> x (* trivial linexp *) + | _ -> + match Comb_map.find le_comb self.encoded_le with + | x -> x (* already encoded that *) + | exception Not_found -> + (* new variable to represent [le_comb] *) + let proxy = fresh_term self ~pre (A.ty_lra self.tst) in + self.encoded_le <- Comb_map.add le_comb proxy self.encoded_le; + Log.debugf 50 + (fun k->k "(@[lra.encode-le@ %a@ :into-var %a@])" LE_.Comb.pp le_comb T.pp proxy); + + LE_.Comb.iter (fun v _ -> SimpSolver.add_var self.simplex v) le_comb; + SimpSolver.define self.simplex proxy (LE_.Comb.to_list le_comb); + proxy (* preprocess linear expressions away *) let preproc_lra (self:state) si ~recurse ~mk_lit ~add_clause (t:T.t) : T.t option = Log.debugf 50 (fun k->k "lra.preprocess %a" T.pp t); let tst = SI.tst si in - let mk_eq x q = - let t1 = A.mk_lra tst (LRA_simplex_pred (x, Leq, q)) in - let t2 = A.mk_lra tst (LRA_simplex_pred (x, Geq, q)) in - A.mk_and tst t1 t2 - and mk_neq x q = - let t1 = A.mk_lra tst (LRA_simplex_pred (x, Lt, q)) in - let t2 = A.mk_lra tst (LRA_simplex_pred (x, Gt, q)) in - A.mk_or tst t1 t2 - in - match A.view_as_lra t with + | LRA_pred ((Eq | Neq), t1, t2) -> + (* the equality side. *) + let t, _ = T.abs tst t in + if not (T.Tbl.mem self.encoded_eqs t) then ( + let u1 = A.mk_lra tst (LRA_pred (Leq, t1, t2)) in + let u2 = A.mk_lra tst (LRA_pred (Geq, t1, t2)) in + + T.Tbl.add self.encoded_eqs t (); + + (* encode [t <=> (u1 /\ u2)] *) + let lit_t = mk_lit t in + let lit_u1 = mk_lit u1 in + let lit_u2 = mk_lit u2 in + add_clause [SI.Lit.neg lit_t; lit_u1]; + add_clause [SI.Lit.neg lit_t; lit_u2]; + add_clause [SI.Lit.neg lit_u1; SI.Lit.neg lit_u2; lit_t]; + ); + None + | LRA_pred (pred, t1, t2) -> let l1 = as_linexp ~f:recurse t1 in let l2 = as_linexp ~f:recurse t2 in @@ -263,34 +238,23 @@ module Make(A : ARG) : S with module A = A = struct begin match LE_.Comb.as_singleton le_comb, pred with | None, _ -> (* non trivial linexp, give it a fresh name in the simplex *) - let proxy = fresh_term self ~pre:"_le" (T.ty t1) in - T.Tbl.replace self.needs_th_combination proxy le_comb; - - let le_comb = LE_.Comb.to_list le_comb in - List.iter (fun (_,v) -> SimpSolver.add_var self.simplex v) le_comb; - SimpSolver.define self.simplex proxy le_comb; + let proxy = var_encoding_comb self ~pre:"_le" le_comb in + T.Tbl.replace self.needs_th_combination proxy (); let new_t = match pred with - | Eq -> mk_eq proxy le_const - | Neq -> mk_neq proxy le_const + | Eq | Neq -> assert false (* unreachable *) | Leq -> A.mk_lra tst (LRA_simplex_pred (proxy, S_op.Leq, le_const)) | Lt -> A.mk_lra tst (LRA_simplex_pred (proxy, S_op.Lt, le_const)) | Geq -> A.mk_lra tst (LRA_simplex_pred (proxy, S_op.Geq, le_const)) | Gt -> A.mk_lra tst (LRA_simplex_pred (proxy, S_op.Gt, le_const)) in - Log.debugf 10 (fun k->k "lra.preprocess@ :%a@ :into %a" T.pp t T.pp new_t); + Log.debugf 10 (fun k->k "lra.preprocess:@ %a@ :into %a" T.pp t T.pp new_t); + + T.Tbl.add self.needs_th_combination new_t (); Some new_t - | Some (coeff, v), Eq -> - let q = Q.(le_const / coeff) in - Some (mk_eq v q) (* turn into [c.v <= const /\ … >= ..] *) - - | Some (coeff, v), Neq -> - let q = Q.(le_const / coeff) in - Some (mk_neq v q) (* turn into [c.v < const \/ … > ..] *) - | Some (coeff, v), pred -> (* [c . v <= const] becomes a direct simplex constraint [v <= const/c] *) let q = Q.div le_const coeff in @@ -307,21 +271,19 @@ module Make(A : ARG) : S with module A = A = struct let new_t = A.mk_lra tst (LRA_simplex_pred (v, op, q)) in Log.debugf 10 (fun k->k "lra.preprocess@ :%a@ :into %a" T.pp t T.pp new_t); + + T.Tbl.add self.needs_th_combination new_t (); Some new_t end | LRA_op _ | LRA_mult _ -> let le = as_linexp ~f:recurse t in let le_comb, le_const = LE.comb le, LE.const le in - let le_comb = LE_.Comb.to_list le_comb in - List.iter (fun (_,v) -> SimpSolver.add_var self.simplex v) le_comb; - - let proxy = fresh_term self ~pre:"_le" (T.ty t) in if Q.(le_const = zero) then ( (* if there is no constant, define [proxy] as [proxy := le_comb] and return [proxy] *) - SimpSolver.define self.simplex proxy le_comb; + let proxy = var_encoding_comb self ~pre:"_le" le_comb in Some proxy ) else ( (* a bit more complicated: we cannot just define [proxy := le_comb] @@ -329,9 +291,18 @@ module Make(A : ARG) : S with module A = A = struct Instead we assert [proxy - le_comb = le_const] using a secondary variable [proxy2 := le_comb - proxy] and asserting [proxy2 = -le_const] *) + let proxy = fresh_term self ~pre:"_le" (T.ty t) in let proxy2 = fresh_term self ~pre:"_le_diff" (T.ty t) in + + SimpSolver.add_var self.simplex proxy; + LE_.Comb.iter (fun v _ -> SimpSolver.add_var self.simplex v) le_comb; + SimpSolver.define self.simplex proxy2 - ((Q.minus_one, proxy) :: le_comb); + ((Q.minus_one, proxy) :: LE_.Comb.to_list le_comb); + + Log.debugf 50 + (fun k->k "(@[lra.encode-le.with-offset@ %a@ :var %a@ :diff-var %a@])" + LE_.Comb.pp le_comb T.pp proxy T.pp proxy2); add_clause [ mk_lit (A.mk_lra tst (LRA_simplex_pred (proxy2, Leq, Q.neg le_const))) @@ -344,8 +315,7 @@ module Make(A : ARG) : S with module A = A = struct ) | LRA_other t when A.has_ty_real t -> - let le = LE_.Comb.monomial1 t in - T.Tbl.replace self.needs_th_combination t le; + T.Tbl.replace self.needs_th_combination t (); None | LRA_const _ | LRA_simplex_pred _ | LRA_simplex_var _ | LRA_other _ -> None @@ -375,6 +345,81 @@ module Make(A : ARG) : S with module A = A = struct (* TODO: trivial propagations *) + let add_local_eq (self:state) si acts n1 n2 : unit = + Log.debugf 20 (fun k->k "(@[lra.add-local-eq@ %a@ %a@])" N.pp n1 N.pp n2); + let t1 = N.term n1 in + let t2 = N.term n2 in + let t1, t2 = if T.compare t1 t2 > 0 then t2, t1 else t1, t2 in + + let le = LE.(as_linexp_id t1 - as_linexp_id t2) in + let le_comb, le_const = LE.comb le, LE.const le in + let le_const = Q.neg le_const in + + let v = var_encoding_comb ~pre:"le_local_eq" self le_comb in + let lit = Tag.CC_eq (n1,n2) in + begin + try + let c1 = SimpSolver.Constraint.geq v le_const in + SimpSolver.add_constraint self.simplex c1 lit; + let c2 = SimpSolver.Constraint.leq v le_const in + SimpSolver.add_constraint self.simplex c2 lit; + with SimpSolver.E_unsat cert -> + fail_with_cert si acts cert + end; + () + + (* theory combination: add decisions [t=u] whenever [t] and [u] + have the same value in [subst] and both occur under function symbols *) + let do_th_combination (self:state) si acts (subst:Subst.t) : unit = + let n_th_comb = T.Tbl.keys self.needs_th_combination |> Iter.length in + if n_th_comb > 0 then ( + Log.debugf 5 + (fun k->k "(@[LRA.needs-th-combination@ :n-lits %d@])" n_th_comb); + ); + Log.debugf 50 + (fun k->k "(@[LRA.needs-th-combination@ :lits %a@])" + (Util.pp_iter @@ Fmt.within "`" "`" T.pp) (T.Tbl.keys self.needs_th_combination)); + + (* theory combination: for [t1,t2] terms in [self.needs_th_combination] + that have same value, but are not provably equal, push + decision [t1=t2] into the SAT solver. *) + begin + let by_val: T.t list Q_map.t = + T.Tbl.keys self.needs_th_combination + |> Iter.map (fun t -> Subst.eval subst t, t) + |> Iter.fold + (fun m (q,t) -> + let l = Q_map.get_or ~default:[] q m in + Q_map.add q (t::l) m) + Q_map.empty + in + Q_map.iter + (fun _q ts -> + begin match ts with + | [] | [_] -> () + | ts -> + (* several terms! see if they are already equal *) + CCList.diagonal ts + |> List.iter + (fun (t1,t2) -> + Log.debugf 50 + (fun k->k "(@[LRA.th-comb.check-pair[val=%a]@ %a@ %a@])" + Q.pp_print _q T.pp t1 T.pp t2); + (* FIXME: we need these equalities to be considered + by the congruence closure *) + if not (SI.cc_are_equal si t1 t2) then ( + Log.debug 50 "LRA.th-comb.must-decide-equal"; + let t = A.mk_eq (SI.tst si) t1 t2 in + let lit = SI.mk_lit si acts t in + SI.push_decision si acts lit + ) + ) + end) + by_val; + () + end; + () + (* partial checks is where we add literals from the trail to the simplex. *) let partial_check_ self si acts trail : unit = @@ -418,76 +463,16 @@ module Make(A : ARG) : S with module A = A = struct let final_check_ (self:state) si (acts:SI.actions) (_trail:_ Iter.t) : unit = Log.debug 5 "(th-lra.final-check)"; Profile.with_ "lra.final-check" @@ fun () -> - (* FIXME + (* add congruence closure equalities *) Backtrack_stack.iter self.local_eqs - ~f:(fun (n1,n2) -> - let t1 = N.term n1 |> as_linexp_id in - let t2 = N.term n2 |> as_linexp_id in - let c = LConstr.eq0 LE.(t1 - t2) in - let lit = Tag.CC_eq (n1,n2) in - SimpSolver.add_constr simplex c lit); - *) + ~f:(fun (n1,n2) -> add_local_eq self si acts n1 n2); - Log.debug 5 "lra: call arith solver"; + Log.debug 5 "(th-lra: call arith solver)"; let model = check_simplex_ self si acts in Log.debugf 20 (fun k->k "(@[lra.model@ %a@])" SimpSolver.Subst.pp model); Log.debug 5 "lra: solver returns SAT"; - let n_th_comb = - T.Tbl.keys self.needs_th_combination |> Iter.length - in - if n_th_comb > 0 then ( - Log.debugf 5 - (fun k->k "(@[LRA.needs-th-combination@ :n-lits %d@])" n_th_comb); - ); - Log.debugf 50 - (fun k->k "(@[LRA.needs-th-combination@ :lits %a@])" - (Util.pp_iter @@ Fmt.within "`" "`" T.pp) (T.Tbl.keys self.needs_th_combination)); - - (* FIXME: theory combination - let lazy model = model in - Log.debugf 30 (fun k->k "(@[LRA.model@ %a@])" FM_A.pp_model model); - - (* theory combination: for [t1,t2] terms in [self.needs_th_combination] - that have same value, but are not provably equal, push - decision [t1=t2] into the SAT solver. *) - begin - let by_val: T.t list Q_map.t = - T.Tbl.to_iter self.needs_th_combination - |> Iter.map (fun (t,le) -> FM_A.eval_model model le, t) - |> Iter.fold - (fun m (q,t) -> - let l = Q_map.get_or ~default:[] q m in - Q_map.add q (t::l) m) - Q_map.empty - in - Q_map.iter - (fun _q ts -> - begin match ts with - | [] | [_] -> () - | ts -> - (* several terms! see if they are already equal *) - CCList.diagonal ts - |> List.iter - (fun (t1,t2) -> - Log.debugf 50 - (fun k->k "(@[LRA.th-comb.check-pair[val=%a]@ %a@ %a@])" - Q.pp_print _q T.pp t1 T.pp t2); - (* FIXME: we need these equalities to be considered - by the congruence closure *) - if not (SI.cc_are_equal si t1 t2) then ( - Log.debug 50 "LRA.th-comb.must-decide-equal"; - let t = A.mk_lra (SI.tst si) (LRA_pred (Eq, t1, t2)) in - let lit = SI.mk_lit si acts t in - SI.push_decision si acts lit - ) - ) - end) - by_val; - () - end; - *) - + do_th_combination self si acts model; () let create_and_setup si = @@ -502,8 +487,6 @@ module Make(A : ARG) : S with module A = A = struct if A.has_ty_real (N.term n1) then ( Backtrack_stack.push st.local_eqs (n1, n2) )); - (* SI.add_preprocess si (cnf st); *) - (* TODO: theory combination *) st let theory = diff --git a/src/arith/lra/simplex2.ml b/src/arith/lra/simplex2.ml index 224cb778..aa4ec765 100644 --- a/src/arith/lra/simplex2.ml +++ b/src/arith/lra/simplex2.ml @@ -64,6 +64,7 @@ module type S = sig module Subst : sig type t = num V_map.t + val eval : t -> V.t -> Q.t val pp : t Fmt.printer val to_string : t -> string end @@ -155,6 +156,7 @@ module Make(Var: VAR) module Subst = struct type t = num V_map.t + let eval self t = try V_map.find t self with Not_found -> Q.zero let pp out (self:t) : unit = let pp_pair out (v,n) = Fmt.fprintf out "(@[%a := %a@])" V.pp v pp_q_dbg n in @@ -533,7 +535,7 @@ module Make(Var: VAR) assert (Var_state.is_basic x_j); (* value of [x_j] by [a_ji * diff] *) let new_val = Erat.(x_j.value + a_ji * diff) in - Log.debugf 50 (fun k->k "new-val %a@ := %a" Var_state.pp x_j Erat.pp new_val); + (* Log.debugf 50 (fun k->k "new-val %a@ := %a" Var_state.pp x_j Erat.pp new_val); *) x_j.value <- new_val; done; x.value <- v; diff --git a/src/smtlib/Process.ml b/src/smtlib/Process.ml index fbecaa78..fa08d992 100644 --- a/src/smtlib/Process.ml +++ b/src/smtlib/Process.ml @@ -312,9 +312,7 @@ module Th_lra = Sidekick_arith_lra.Make(struct type term = S.T.Term.t type ty = S.T.Ty.t - let mk_and = Form.and_ - let mk_or = Form.or_ - + let mk_eq = Form.eq let mk_lra = T.lra let view_as_lra t = match T.view t with | T.LRA l -> l