From b3a7acf95bf378519699eedb0c19711e21456872 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Tue, 17 Nov 2020 18:24:09 -0500 Subject: [PATCH] feat(LRA): handle congruence closure and theory combination in LRA - merges in the CC are handled by adding corresponding equalities locally - theory combination pushes the decision `a=b` into the SAT solver if a,b have the same model values and are not provably equal in the CC already. - also, fix model construction --- src/arith/base-term/Base_types.ml | 2 + src/arith/lra/Sidekick_arith_lra.ml | 166 ++++++++++++++++++++++------ src/arith/lra/fourier_motzkin.ml | 31 ++++-- 3 files changed, 159 insertions(+), 40 deletions(-) diff --git a/src/arith/base-term/Base_types.ml b/src/arith/base-term/Base_types.ml index ed7df0eb..660f19c1 100644 --- a/src/arith/base-term/Base_types.ml +++ b/src/arith/base-term/Base_types.ml @@ -914,6 +914,8 @@ end = struct | Eq (a,b) -> C.Eq (a, b) | Not u -> C.Not u | Ite (a,b,c) -> C.If (a,b,c) + | LRA (LRA_pred (Eq, a, b)) -> + C.Eq (a,b) (* need congruence closure on this one, for theory combination *) | LRA _ -> C.Opaque t (* no congruence here *) module As_key = struct diff --git a/src/arith/lra/Sidekick_arith_lra.ml b/src/arith/lra/Sidekick_arith_lra.ml index a7971d28..77e72195 100644 --- a/src/arith/lra/Sidekick_arith_lra.ml +++ b/src/arith/lra/Sidekick_arith_lra.ml @@ -67,17 +67,34 @@ module Make(A : ARG) : S with module A = A = struct module T = A.S.T.Term module Lit = A.S.Solver_internal.Lit module SI = A.S.Solver_internal + module N = A.S.Solver_internal.CC.N + + module Tag = struct + type t = + | Lit of Lit.t + | CC_eq of N.t * N.t + + let pp out = function + | Lit l -> Fmt.fprintf out "(@[lit %a@])" Lit.pp l + | CC_eq (n1,n2) -> Fmt.fprintf out "(@[cc-eq@ %a@ %a@])" N.pp n1 N.pp n2 + + let to_lits si = function + | Lit l -> [l] + | CC_eq (n1,n2) -> + SI.CC.explain_eq (SI.cc si) n1 n2 + end (* the fourier motzkin module *) module FM_A = FM.Make(struct module T = T - type tag = Lit.t - let pp_tag = Lit.pp + type tag = Tag.t + let pp_tag = Tag.pp end) (* linear expressions *) module LE = FM_A.LE + type proxy = T.t type state = { tst: T.state; simps: T.t T.Tbl.t; (* cache *) @@ -85,8 +102,9 @@ module Make(A : ARG) : S with module A = A = struct neq_encoded: unit T.Tbl.t; (* if [a != b] asserted and not in this table, add clause [a = b \/ ab] *) needs_th_combination: LE.t T.Tbl.t; (* terms that require theory combination *) - mutable t_defs: (T.t * LE.t) list; (* term definitions *) + t_defs: LE.t T.Tbl.t; (* term definitions *) pred_defs: (pred * LE.t * LE.t * T.t * T.t) T.Tbl.t; (* predicate definitions *) + local_eqs: (N.t * N.t) Backtrack_stack.t; (* inferred by the congruence closure *) } let create tst : state = @@ -95,10 +113,19 @@ module Make(A : ARG) : S with module A = A = struct gensym=A.Gensym.create tst; neq_encoded=T.Tbl.create 16; needs_th_combination=T.Tbl.create 8; - t_defs=[]; + t_defs=T.Tbl.create 8; pred_defs=T.Tbl.create 16; + local_eqs = Backtrack_stack.create(); } + let push_level self = + Backtrack_stack.push_level self.local_eqs; + () + + let pop_levels self n = + Backtrack_stack.pop_levels self.local_eqs n ~f:(fun _ -> ()); + () + (* FIXME let simplify (self:state) (simp:SI.Simplify.t) (t:T.t) : T.t option = let tst = self.tst in @@ -170,6 +197,8 @@ module Make(A : ARG) : S with module A = A = struct LE.( n * t ) | LRA_const q -> LE.const q + let as_linexp_id = as_linexp ~f:CCFun.id + (* TODO: keep the linexps until they're asserted; TODO: but use simplification in preprocess *) @@ -177,8 +206,14 @@ module Make(A : ARG) : S with module A = A = struct (* preprocess linear expressions away *) let preproc_lra (self:state) si ~recurse ~mk_lit:_ ~add_clause:_ (t:T.t) : T.t option = Log.debugf 50 (fun k->k "lra.preprocess %a" T.pp t); - let _tst = SI.tst si in + let tst = SI.tst si in match A.view_as_lra t with + | LRA_pred ((Eq|Neq) as pred, t1, t2) -> + (* keep equality as is, needed for congruence closure *) + let t1 = recurse t1 in + let t2 = recurse t2 in + let u = A.mk_lra tst (LRA_pred (pred, t1, t2)) in + if T.equal t u then None else Some u | LRA_pred (pred, t1, t2) -> let l1 = as_linexp ~f:recurse t1 in let l2 = as_linexp ~f:recurse t2 in @@ -189,10 +224,9 @@ module Make(A : ARG) : S with module A = A = struct Some proxy | LRA_op _ | LRA_mult _ -> let le = as_linexp ~f:recurse t in - (* TODO: reuse proxy if present? *) let proxy = fresh_term self ~pre:"_e_lra_" (T.ty t) in - self.t_defs <- (proxy, le) :: self.t_defs; - T.Tbl.add self.needs_th_combination t le; + T.Tbl.add self.t_defs proxy le; + T.Tbl.add self.needs_th_combination proxy le; Log.debugf 5 (fun k->k"@[lra.preprocess.step %a@ :into %a@ :def %a@]" T.pp t T.pp proxy LE.pp le); Some proxy @@ -213,17 +247,20 @@ module Make(A : ARG) : S with module A = A = struct let t = Lit.term lit in Log.debugf 50 (fun k->k "@[lra: check lit %a@ :t %a@ :sign %B@]" Lit.pp lit T.pp t (Lit.sign lit)); + + let check_pred pred a b = + let pred = if Lit.sign lit then pred else FM.Pred.neg pred in + Log.debugf 50 (fun k->k "pred = `%s`" (FM.Pred.to_string pred)); + if pred = Neq && not (T.Tbl.mem self.neq_encoded t) then ( + Some (lit, a, b) + ) else None + in + begin match T.Tbl.find self.pred_defs t with - | (pred, _, _, ta, tb) -> - let pred = if Lit.sign lit then pred else FM.Pred.neg pred in - Log.debugf 50 (fun k->k "pred = `%s`" (FM.Pred.to_string pred)); - if pred = Neq && not (T.Tbl.mem self.neq_encoded t) then ( - Some (lit, ta, tb) - ) else None + | (pred, _, _, ta, tb) -> check_pred pred ta tb | exception Not_found -> begin match A.view_as_lra t with - | LRA_pred (Neq, a, b) when not (T.Tbl.mem self.neq_encoded t) -> - Some (lit, a, b) + | LRA_pred (pred, a, b) -> check_pred pred a b | _ -> None end end) @@ -246,19 +283,28 @@ module Make(A : ARG) : S with module A = A = struct List.iter (fun l -> LTbl.replace tbl l ()) lits; LTbl.keys_list tbl + module Q_map = CCMap.Make(Q) + let final_check_ (self:state) si (acts:SI.actions) (trail:_ Iter.t) : unit = Log.debug 5 "(th-lra.final-check)"; encode_neq self si acts trail; let fm = FM_A.create() in (* first, add definitions *) begin - List.iter - (fun (t,le) -> + T.Tbl.iter + (fun t le -> let open LE.Infix in let c = FM_A.Constr.mk ?tag:None Eq (LE.var t) le in FM_A.assert_c fm c) self.t_defs end; + (* add congruence closure equalities *) + Backtrack_stack.iter self.local_eqs + ~f:(fun (n1,n2) -> + let t1 = N.term n1 |> as_linexp_id in + let t2 = N.term n2 |> as_linexp_id in + let c = FM_A.Constr.mk ~tag:(Tag.CC_eq (n1,n2)) Eq t1 t2 in + FM_A.assert_c fm c); (* add trail *) begin trail @@ -266,16 +312,25 @@ module Make(A : ARG) : S with module A = A = struct (fun lit -> let sign = Lit.sign lit in let t = Lit.term lit in + let assert_pred pred a b = + let pred = if sign then pred else FM.Pred.neg pred in + if pred = Neq then ( + Log.debugf 50 (fun k->k "skip neq in %a" T.pp t); + ) else ( + let c = FM_A.Constr.mk ~tag:(Tag.Lit lit) pred a b in + FM_A.assert_c fm c; + ) + in begin match T.Tbl.find self.pred_defs t with - | exception Not_found -> () - | (pred, a, b, _, _) -> - let pred = if sign then pred else FM.Pred.neg pred in - if pred = Neq then ( - Log.debugf 50 (fun k->k "skip neq in %a" T.pp t); - ) else ( - let c = FM_A.Constr.mk ~tag:lit pred a b in - FM_A.assert_c fm c; - ) + | (pred, a, b, _, _) -> assert_pred pred a b + | exception Not_found -> + begin match A.view_as_lra t with + | LRA_pred (pred, a, b) -> + let a = try T.Tbl.find self.t_defs a with _ -> as_linexp_id a in + let b = try T.Tbl.find self.t_defs b with _ -> as_linexp_id b in + assert_pred pred a b + | _ -> () + end end) end; Log.debug 5 "lra: call arith solver"; @@ -286,14 +341,56 @@ module Make(A : ARG) : S with module A = A = struct (fun k->k "(@[LRA.needs-th-combination:@ %a@])" (Util.pp_iter @@ Fmt.within "`" "`" T.pp) (T.Tbl.keys self.needs_th_combination)); Log.debugf 30 (fun k->k "(@[LRA.model@ %a@])" FM_A.pp_model model); - () (* TODO: get a model + model combination *) - | FM_A.Unsat lits -> + + (* theory combination: for [t1,t2] terms in [self.needs_th_combination] + that have same value, but are not provably equal, push + decision [t1=t2] into the SAT solver. *) + begin + let by_val: T.t list Q_map.t = + T.Tbl.to_iter self.needs_th_combination + |> Iter.map (fun (t,le) -> FM_A.eval_model model le, t) + |> Iter.fold + (fun m (q,t) -> + let l = Q_map.get_or ~default:[] q m in + Q_map.add q (t::l) m) + Q_map.empty + in + Q_map.iter + (fun _q ts -> + begin match ts with + | [] | [_] -> () + | ts -> + (* several terms! see if they are already equal *) + CCList.diagonal ts + |> List.iter + (fun (t1,t2) -> + Log.debugf 50 + (fun k->k "(@[LRA.th-comb.check-pair[val=%a]@ %a@ %a@])" + Q.pp_print _q T.pp t1 T.pp t2); + (* FIXME: we need these equalities to be considered + by the congruence closure *) + if not (SI.cc_are_equal si t1 t2) then ( + Log.debug 50 "LRA.th-comb.must-decide-equal"; + let t = A.mk_lra (SI.tst si) (LRA_pred (Eq, t1, t2)) in + let lit = SI.mk_lit si acts t in + SI.push_decision si acts lit + ) + ) + end) + by_val; + () + end; + () + | FM_A.Unsat tags -> (* we tagged assertions with their lit, so the certificate being an unsat core translates directly into a conflict clause *) Log.debugf 5 (fun k->k"lra: solver returns UNSAT@ with cert %a" - (Fmt.Dump.list Lit.pp) lits); + (Fmt.Dump.list Tag.pp) tags); let confl = - List.rev_map Lit.neg lits |> dedup_lits + tags + |> CCList.flat_map (fun t -> Tag.to_lits si t) + |> List.rev_map Lit.neg + |> dedup_lits in (* TODO: produce and store a proper LRA resolution proof *) SI.raise_conflict si acts confl SI.P.default @@ -306,6 +403,11 @@ module Make(A : ARG) : S with module A = A = struct (* TODO SI.add_simplifier si (simplify st); *) SI.add_preprocess si (preproc_lra st); SI.on_final_check si (final_check_ st); + SI.on_cc_post_merge si + (fun _ _ n1 n2 -> + if A.has_ty_real (N.term n1) then ( + Backtrack_stack.push st.local_eqs (n1, n2) + )); (* SI.add_preprocess si (cnf st); *) (* TODO: theory combination *) st @@ -313,6 +415,6 @@ module Make(A : ARG) : S with module A = A = struct let theory = A.S.mk_theory ~name:"th-lra" - ~create_and_setup + ~create_and_setup ~push_level ~pop_levels () end diff --git a/src/arith/lra/fourier_motzkin.ml b/src/arith/lra/fourier_motzkin.ml index 56d3cf4f..43707e48 100644 --- a/src/arith/lra/fourier_motzkin.ml +++ b/src/arith/lra/fourier_motzkin.ml @@ -87,6 +87,7 @@ module type S = sig type model val get_model : model -> term -> Q.t + val eval_model : model -> LE.t -> Q.t val pp_model : model Fmt.printer type res = @@ -320,19 +321,23 @@ module Make(A : ARG) (Fmt.Dump.list Constr.pp) self.empties (Util.pp_iter pp_idxkv) (T_map.to_iter self.idx) - (* TODO: be able to provide a model for SAT *) let build_model_ (self:pre_model) : _ T_map.t = - let l = T_map.to_iter self |> Iter.to_rev_list in + (* order matters: we need to compute values for lowest variables first *) + let l = T_map.to_iter self |> Iter.to_list in + (* INVARIANT: assert (CCList.is_sorted ~cmp:(fun (a,_) (b,_) -> T.compare a b) l); *) let m = ref T_map.empty in (* how to evaluate a linexpr in the model *) - let eval_le (le:LE.t) : Q.t = + let eval_le ~for_v (le:LE.t) : Q.t = let find x = + assert (T.compare for_v x > 0); try T_map.find x !m with Not_found -> + Log.debugf 50 (fun k->k "LRA.model: add default value for %a" T.pp x); m := T_map.add x Q.zero !m; (* remember this choice *) - Q.zero in + Q.zero + in T_map.to_iter le.LE.le |> Iter.fold (fun sum (t,coeff) -> Q.(sum + coeff * find t)) @@ -355,12 +360,13 @@ module Make(A : ARG) begin fun (v,cs_v) -> (* update [v] using its constraints [cs_v]. [m] is the model to update *) + Log.debugf 40 (fun k->k "LRA.model: compute value for %a" T.pp v); let val_v = match cs_v with - | lazy (PM_eq le) -> eval_le le + | lazy (PM_eq le) -> eval_le ~for_v:v le | lazy (PM_bounds {lower; upper}) -> - let lower = List.map (fun (s,le) -> s, eval_le le) lower in - let upper = List.map (fun (s,le) -> s, eval_le le) upper in + let lower = List.map (fun (s,le) -> s, eval_le ~for_v:v le) lower in + let upper = List.map (fun (s,le) -> s, eval_le ~for_v:v le) upper in let strict_low, lower = match lower with | [] -> NonStrict, Q.minus_inf | x :: l -> List.fold_left max_pair x l @@ -383,7 +389,10 @@ module Make(A : ARG) Q.zero (* no bounds *) ) in - assert (not (T_map.mem v !m)); (* by ordering *) + if T_map.mem v !m then ( + (* error: by ordering [v] should not have been touched yet *) + Error.errorf "LRA.build-model: variable %a already has a value" T.pp v + ); m := T_map.add v val_v !m; end l; @@ -394,6 +403,12 @@ module Make(A : ARG) try T_map.find v m with Not_found -> Q.zero + let eval_model m (le:LE.t) : Q.t = + T_map.fold + (fun v coeff sum -> + Q.(sum + coeff * get_model m v)) + le.LE.le le.LE.const + let pp_model out (m:model) : unit = let lazy m = m in let pp_pair out (v,q) = Fmt.fprintf out "(@[%a@ %a@])" T.pp v Q.pp_print q in