diff --git a/src/arith/lra/Sidekick_arith_lra.ml b/src/arith/lra/Sidekick_arith_lra.ml index 98d69326..85946fb4 100644 --- a/src/arith/lra/Sidekick_arith_lra.ml +++ b/src/arith/lra/Sidekick_arith_lra.ml @@ -82,7 +82,7 @@ module Make(A : ARG) : S with module A = A = struct neq_encoded: unit T.Tbl.t; (* if [a != b] asserted and not in this table, add clause [a = b \/ ab] *) mutable t_defs: (T.t * LE.t) list; (* term definitions *) - pred_defs: (pred * LE.t * LE.t) T.Tbl.t; (* predicate definitions *) + pred_defs: (pred * LE.t * LE.t * T.t * T.t) T.Tbl.t; (* predicate definitions *) } let create tst : state = @@ -178,7 +178,7 @@ module Make(A : ARG) : S with module A = A = struct let l1 = as_linexp ~f:recurse t1 in let l2 = as_linexp ~f:recurse t2 in let proxy = fresh_term self ~pre:"_pred_lra_" Ty.bool in - T.Tbl.add self.pred_defs proxy (pred, l1, l2); + T.Tbl.add self.pred_defs proxy (pred, l1, l2, t1, t2); Log.debugf 5 (fun k->k"@[lra.preprocess.step %a@ :into %a@ :def %a@]" T.pp t T.pp proxy pp_pred_def (pred,l1,l2)); Some proxy @@ -191,41 +191,59 @@ module Make(A : ARG) : S with module A = A = struct Some proxy | LRA_const _ | LRA_other _ -> None - (* partial check: just ensure [a != b] triggers the clause + (* ensure that [a != b] triggers the clause [a=b \/ ab] *) - let partial_check_ (self:state) si (acts:SI.actions) (trail:_ Iter.t) : unit = + let encode_neq self si acts trail : unit = let tst = self.tst in begin trail - |> Iter.filter (fun lit -> not (Lit.sign lit)) |> Iter.filter_map (fun lit -> let t = Lit.term lit in - match A.view_as_lra t with - | LRA_pred (Eq, a, b) when not (T.Tbl.mem self.neq_encoded t) -> - Some (lit, a,b) - | _ -> None) + Log.debugf 50 (fun k->k "@[lra: check lit %a@ :t %a@ :sign %B@]" + Lit.pp lit T.pp t (Lit.sign lit)); + begin match T.Tbl.find self.pred_defs t with + | (pred, _, _, ta, tb) -> + let pred = if Lit.sign lit then pred else FM.Pred.neg pred in + Log.debugf 50 (fun k->k "pred = `%s`" (FM.Pred.to_string pred)); + if pred = Neq && not (T.Tbl.mem self.neq_encoded t) then ( + Some (lit, ta, tb) + ) else None + | exception Not_found -> + begin match A.view_as_lra t with + | LRA_pred (Neq, a, b) when not (T.Tbl.mem self.neq_encoded t) -> + Some (lit, a, b) + | _ -> None + end + end) |> Iter.iter (fun (lit,a,b) -> + Log.debugf 50 (fun k->k "encode neq in %a" Lit.pp lit); let c = [ - Lit.abs lit; + Lit.neg lit; SI.mk_lit si acts (A.mk_lra tst (LRA_pred (Lt, a, b))); SI.mk_lit si acts (A.mk_lra tst (LRA_pred (Lt, b, a))); ] in SI.add_clause_permanent si acts c; - T.Tbl.add self.neq_encoded (Lit.term lit) (); + T.Tbl.add self.neq_encoded (Lit.term (Lit.abs lit)) (); ) end + let dedup_lits lits : _ list = + let module LTbl = CCHashtbl.Make(Lit) in + let tbl = LTbl.create 16 in + List.iter (fun l -> LTbl.replace tbl l ()) lits; + LTbl.keys_list tbl + let final_check_ (self:state) si (acts:SI.actions) (trail:_ Iter.t) : unit = Log.debug 5 "(th-lra.final-check)"; + encode_neq self si acts trail; let fm = FM_A.create() in (* first, add definitions *) begin List.iter (fun (t,le) -> let open LE.Infix in - let le = le - LE.var t in let c = FM_A.Constr.mk ?tag:None Eq (LE.var t) le in FM_A.assert_c fm c) self.t_defs @@ -239,7 +257,7 @@ module Make(A : ARG) : S with module A = A = struct let t = Lit.term lit in begin match T.Tbl.find self.pred_defs t with | exception Not_found -> () - | (pred, a, b) -> + | (pred, a, b, _, _) -> let pred = if sign then pred else FM.Pred.neg pred in if pred = Neq then ( Log.debugf 50 (fun k->k "skip neq in %a" T.pp t); @@ -259,7 +277,9 @@ module Make(A : ARG) : S with module A = A = struct unsat core translates directly into a conflict clause *) Log.debugf 5 (fun k->k"lra: solver returns UNSAT@ with cert %a" (Fmt.Dump.list Lit.pp) lits); - let confl = List.rev_map Lit.neg lits in + let confl = + List.rev_map Lit.neg lits |> dedup_lits + in (* TODO: produce and store a proper LRA resolution proof *) SI.raise_conflict si acts confl SI.P.default end; @@ -270,7 +290,6 @@ module Make(A : ARG) : S with module A = A = struct let st = create (SI.tst si) in (* TODO SI.add_simplifier si (simplify st); *) SI.add_preprocess si (preproc_lra st); - SI.on_partial_check si (partial_check_ st); SI.on_final_check si (final_check_ st); (* SI.add_preprocess si (cnf st); *) (* TODO: theory combination *)