From 9783c3ae1bc4cd9e287197552862532a3f4c8b37 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sat, 10 Oct 2020 00:00:20 -0400 Subject: [PATCH] wip: reimplement a fourier motzkin module, from scratch --- src/base-term/Base_types.ml | 4 +- src/th-lra/Sidekick_lra.ml | 148 +++++++++++-------------- src/th-lra/dune | 2 +- src/th-lra/fourier_motzkin.ml | 200 ++++++++++++++++++++++++++++++++++ 4 files changed, 269 insertions(+), 85 deletions(-) create mode 100644 src/th-lra/fourier_motzkin.ml diff --git a/src/base-term/Base_types.ml b/src/base-term/Base_types.ml index 5f118c74..f86f3be4 100644 --- a/src/base-term/Base_types.ml +++ b/src/base-term/Base_types.ml @@ -4,7 +4,7 @@ module Fmt = CCFormat module CC_view = Sidekick_core.CC_view -type lra_pred = Sidekick_lra.pred = Lt | Leq | Eq | Neq | Geq | Gt +type lra_pred = Sidekick_lra.FM.Pred.t = Lt | Leq | Geq | Gt | Neq | Eq type lra_op = Sidekick_lra.op = Plus | Minus type 'a lra_view = 'a Sidekick_lra.lra_view = @@ -894,6 +894,8 @@ end = struct | Not u -> u, false | App_fun ({fun_view=Fun_def def; _}, args) -> def.abs ~self:t args (* TODO: pass state *) + | LRA (LRA_pred (Neq, a, b)) -> + lra tst (LRA_pred (Eq,a,b)), false (* != is just not eq *) | _ -> t, true let[@inline] is_true t = match view t with Bool true -> true | _ -> false diff --git a/src/th-lra/Sidekick_lra.ml b/src/th-lra/Sidekick_lra.ml index 9b09f1e7..7fbb52e3 100644 --- a/src/th-lra/Sidekick_lra.ml +++ b/src/th-lra/Sidekick_lra.ml @@ -6,7 +6,9 @@ open Sidekick_core -type pred = Lt | Leq | Eq | Neq | Geq | Gt +module FM = Fourier_motzkin + +type pred = FM.Pred.t = Lt | Leq | Geq | Gt | Neq | Eq type op = Plus | Minus type 'a lra_view = @@ -25,24 +27,6 @@ let map_view f (l:_ lra_view) : _ lra_view = | LRA_other x -> LRA_other (f x) end -(* TODO: upstream *) -let neg_pred = function - | Leq -> Gt - | Lt -> Geq - | Eq -> Neq - | Neq -> Eq - | Geq -> Lt - | Gt -> Leq - -let pred_to_funarith = function - | Leq -> `Leq - | Lt -> `Lt - | Geq -> `Geq - | Gt -> `Gt - | Eq -> `Eq - | Neq -> `Neq - - module type ARG = sig module S : Sidekick_core.SOLVER @@ -81,47 +65,30 @@ module Make(A : ARG) : S with module A = A = struct module Lit = A.S.Solver_internal.Lit module SI = A.S.Solver_internal - type simp_var = - | V_fresh of int - | V_t of T.t + (* the fourier motzkin module *) + module FM_A = FM.Make(struct + module T = T + type tag = Lit.t + end) - (** Simplex variables *) - module Simp_vars = struct - type t = simp_var - let compare a b = - match a, b with - | V_fresh i, V_fresh j -> CCInt.compare i j - | V_fresh _, V_t _ -> -1 - | V_t _, V_fresh _ -> 1 - | V_t t1, V_t t2 -> T.compare t1 t2 - let pp out = function - | V_fresh i -> Fmt.fprintf out "$fresh_%d" i - | V_t t -> T.pp out t - module Fresh = struct - type t = int ref - let create() : t = ref 0 - let fresh n = V_fresh (CCRef.get_then_incr n) - end - end - - module Simplex = Funarith_zarith.Simplex.Make_full(Simp_vars) - module LE = Simplex.L.Expr - module LComb = Simplex.L.Comb - module Constr = Simplex.L.Constr + (* linear expressions *) + module LE = FM_A.LE type state = { tst: T.state; simps: T.t T.Tbl.t; (* cache *) - simplex: Simplex.t; gensym: A.Gensym.t; + neq_encoded: unit T.Tbl.t; + (* if [a != b] asserted and not in this table, add clause [a = b \/ ab] *) mutable t_defs: (T.t * LE.t) list; (* term definitions *) pred_defs: (pred * LE.t * LE.t) T.Tbl.t; (* predicate definitions *) } let create tst : state = - { tst; simps=T.Tbl.create 128; + { tst; + simps=T.Tbl.create 128; gensym=A.Gensym.create tst; - simplex=Simplex.create(); + neq_encoded=T.Tbl.create 16; t_defs=[]; pred_defs=T.Tbl.create 16; } @@ -179,7 +146,7 @@ module Make(A : ARG) : S with module A = A = struct let rec as_linexp (t:T.t) : LE.t = let open LE.Infix in match A.view_as_lra t with - | LRA_other _ -> LE.of_list Q.zero [Q.one, V_t t] + | LRA_other _ -> LE.var t | LRA_pred _ -> Error.errorf "type error: in linexp, LRA predicate %a" T.pp t | LRA_op (op, t1, t2) -> @@ -192,7 +159,7 @@ module Make(A : ARG) : S with module A = A = struct | LRA_mult (n, x) -> let t = as_linexp x in LE.( n * t ) - | LRA_const q -> LE.of_const q + | LRA_const q -> LE.const q (* TODO: keep the linexps until they're asserted; TODO: but use simplification in preprocess @@ -220,18 +187,43 @@ module Make(A : ARG) : S with module A = A = struct Some proxy | LRA_const _ | LRA_other _ -> None - let final_check_ (self:state) _si (_acts:SI.actions) (trail:_ Iter.t) : unit = + (* partial check: just ensure [a != b] triggers the clause + [a=b \/ ab] *) + let partial_check_ (self:state) si (acts:SI.actions) (trail:_ Iter.t) : unit = + let tst = self.tst in + begin + trail + |> Iter.filter (fun lit -> not (Lit.sign lit)) + |> Iter.filter_map + (fun lit -> + let t = Lit.term lit in + match A.view_as_lra t with + | LRA_pred (Eq, a, b) when not (T.Tbl.mem self.neq_encoded t) -> + Some (lit, a,b) + | _ -> None) + |> Iter.iter + (fun (lit,a,b) -> + let c = [ + Lit.abs lit; + SI.mk_lit si acts (A.mk_lra tst (LRA_pred (Lt, a, b))); + SI.mk_lit si acts (A.mk_lra tst (LRA_pred (Lt, b, a))); + ] in + SI.add_clause_permanent si acts c; + T.Tbl.add self.neq_encoded (Lit.term lit) (); + ) + end + + let final_check_ (self:state) si (acts:SI.actions) (trail:_ Iter.t) : unit = Log.debug 5 "(th-lra.final-check)"; - let simplex = Simplex.create() in + let fm = FM_A.create() in (* first, add definitions *) begin List.iter (fun (t,le) -> let open LE.Infix in - let c = - Constr.of_expr (le - LE.of_comb (LComb.monomial1 (V_t t))) `Eq - in - Simplex.add_constr simplex c) + let le = le - LE.var t in + let c = FM_A.Constr.mk ?tag:None Eq (LE.var t) le in + FM_A.assert_c fm c) self.t_defs end; (* add trail *) @@ -245,35 +237,24 @@ module Make(A : ARG) : S with module A = A = struct | exception Not_found -> () | (pred, a, b) -> let open LE.Infix in - let e = a - b in - let pred = if sign then pred else neg_pred pred in - let pred = match pred_to_funarith pred with - | `Neq -> Sidekick_util.Error.errorf "cannot handle negative LEQ equality" - | (`Eq | `Geq | `Gt | `Leq | `Lt) as p -> p - in - let c = Constr.of_expr e pred in - Simplex.add_constr simplex c; + let pred = if sign then pred else FM.Pred.neg pred in + let c = FM_A.Constr.mk ~tag:lit pred a b in + FM_A.assert_c fm c; end) end; - Log.debug 5 "lra: call simplex"; - begin match Simplex.solve simplex with - | Simplex.Solution _ -> - Log.debug 5 "lra: simplex returns SAT"; - () (* TODO: model combination *) - | Simplex.Unsatisfiable cert -> - Log.debugf 5 (fun k->k"lra: simplex returns UNSAT@ with cert %a" Simplex.pp_cert cert); - (* find what terms are involved *) - let asserts = - cert.Simplex.cert_expr - |> Iter.of_list - |> Iter.filter_map - (function - | V_t -> Some t - | V_fresh _ -> None) - |> Iter.to_rev_list - in - Simplex.cert - () (* TODO: produce conflict *) + Log.debug 5 "lra: call arith solver"; + begin match FM_A.solve fm with + | FM_A.Sat -> + Log.debug 5 "lra: solver returns SAT"; + () (* TODO: get a model + model combination *) + | FM_A.Unsat lits -> + (* we tagged assertions with their lit, so the certificate being an + unsat core translates directly into a conflict clause *) + Log.debugf 5 (fun k->k"lra: solver returns UNSAT@ with cert %a" + (Fmt.Dump.list Lit.pp) lits); + let confl = List.rev_map Lit.neg lits in + (* TODO: produce and store a proper LRA resolution proof *) + SI.raise_conflict si acts confl SI.P.default end; () @@ -282,6 +263,7 @@ module Make(A : ARG) : S with module A = A = struct let st = create (SI.tst si) in (* TODO SI.add_simplifier si (simplify st); *) SI.add_preprocess si (preproc_lra st); + SI.on_partial_check si (partial_check_ st); SI.on_final_check si (final_check_ st); (* SI.add_preprocess si (cnf st); *) (* TODO: theory combination *) diff --git a/src/th-lra/dune b/src/th-lra/dune index a80c8575..cdab035a 100644 --- a/src/th-lra/dune +++ b/src/th-lra/dune @@ -4,4 +4,4 @@ (public_name sidekick.th-lra) (optional) ; only if deps present (flags :standard -warn-error -a+8 -open Sidekick_util) - (libraries containers sidekick.core zarith funarith.zarith funarith)) + (libraries containers sidekick.core zarith)) diff --git a/src/th-lra/fourier_motzkin.ml b/src/th-lra/fourier_motzkin.ml new file mode 100644 index 00000000..8fda33f3 --- /dev/null +++ b/src/th-lra/fourier_motzkin.ml @@ -0,0 +1,200 @@ + + +module type ARG = sig + (** terms *) + module T : sig + type t + + val equal : t -> t -> bool + val hash : t -> int + val compare : t -> t -> int + val pp : t Fmt.printer + end + + type tag +end + +module Pred : sig + type t = Lt | Leq | Geq | Gt | Neq | Eq + + val neg : t -> t + val pp : t Fmt.printer + val to_string : t -> string +end = struct + type t = Lt | Leq | Geq | Gt | Neq | Eq + let to_string = function + | Lt -> "<" + | Leq -> "<=" + | Eq -> "=" + | Neq -> "!=" + | Gt -> ">" + | Geq -> ">=" + + let neg = function + | Leq -> Gt + | Lt -> Geq + | Eq -> Neq + | Neq -> Eq + | Geq -> Lt + | Gt -> Leq + + let pp out p = Fmt.string out (to_string p) +end + +module type S = sig + module A : ARG + + type t + + type expr = A.T.t + + module LE : sig + type t + + val const : Q.t -> t + val var : expr -> t + + module Infix : sig + val (+) : t -> t -> t + val (-) : t -> t -> t + val ( * ) : Q.t -> t -> t + end + include module type of Infix + + val pp : t Fmt.printer + end + + (** {3 Arithmetic constraint} *) + module Constr : sig + type t = { + pred: Pred.t; + le: LE.t; + tag: A.tag option; + } + + val mk : ?tag:A.tag -> Pred.t -> LE.t -> LE.t -> t + + val pp : t Fmt.printer + end + + val create : unit -> t + + val assert_c : t -> Constr.t -> unit + + type res = + | Sat + | Unsat of A.tag list + + val solve : t -> res +end + +module Make(A : ARG) + : S with module A = A += struct + module A = A + module T = A.T + + module T_set = CCSet.Make(A.T) + module T_map = CCMap.Make(A.T) + + type expr = A.T.t + + module LE = struct + module M = T_map + + type t = { + le: Q.t M.t; + const: Q.t; + } + + let const x : t = {const=x; le=M.empty} + let var x : t = {const=Q.zero; le=M.singleton x Q.one} + + let (+) a b : t = + {const = Q.(a.const + b.const); + le=M.merge_safe a.le b.le + ~f:(fun _ -> function + | `Left x | `Right x -> Some x + | `Both (x,y) -> + let z = Q.(x + y) in + if Q.sign z = 0 then None else Some z) + } + + let (-) a b : t = + {const = Q.(a.const - b.const); + le=M.merge_safe a.le b.le + ~f:(fun _ -> function + | `Left x -> Some x + | `Right x -> Some (Q.neg x) + | `Both (x,y) -> + let z = Q.(x - y) in + if Q.sign z = 0 then None else Some z) + } + + let ( * ) x a : t = + if Q.sign x = 0 then const Q.zero + else ( + {const=Q.( a.const * x ); + le=M.map (fun y -> Q.(x * y)) a.le + } + ) + + module Infix = struct + let (+) = (+) + let (-) = (-) + let ( * ) = ( * ) + end + + let vars self = T_map.keys self.le + + let pp out (self:t) : unit = + let pp_pair out (e,q) = + if Q.equal Q.one q then T.pp out e + else Fmt.fprintf out "%a * %a" Q.pp_print q T.pp e + in + Fmt.fprintf out "(@[%a@ + %a@])" + Q.pp_print self.const (Util.pp_iter ~sep:" + " pp_pair) (M.to_iter self.le) + end + + module Constr = struct + type t = { + pred: Pred.t; + le: LE.t; + tag: A.tag option; + } + + let pp out (c:t) : unit = + Fmt.fprintf out "(@[constr@ :le %a@ :pred %s 0@])" + LE.pp c.le (Pred.to_string c.pred) + + + let mk ?tag pred l1 l2 : t = + {pred; tag; le=LE.(l1 - l2); } + end + + type t = { + mutable cs: Constr.t list; + mutable all_vars: T_set.t; + } + + let create () : t = { + cs=[]; + all_vars=T_set.empty; + } + + let assert_c (self:t) c : unit = + self.cs <- c :: self.cs; + self.all_vars <- c.Constr.le |> LE.vars |> T_set.add_iter self.all_vars; + () + + (* TODO: be able to provide a model for SAT *) + type res = + | Sat + | Unsat of A.tag list + + let solve (self:t) : res = + Log.debugf 5 + (fun k->k"(@[FM.solve@ %a@])" (Util.pp_list Constr.pp) self.cs); + assert false +end +