From 93b56618f1b13ef11d3835363d838d5fcba351c1 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sat, 10 Oct 2020 01:22:22 -0400 Subject: [PATCH] wip: first implem of Fourier Motzkin --- src/th-lra/Sidekick_lra.ml | 1 - src/th-lra/fourier_motzkin.ml | 216 +++++++++++++++++++++++++++++++--- 2 files changed, 197 insertions(+), 20 deletions(-) diff --git a/src/th-lra/Sidekick_lra.ml b/src/th-lra/Sidekick_lra.ml index 7fbb52e3..4699d9ee 100644 --- a/src/th-lra/Sidekick_lra.ml +++ b/src/th-lra/Sidekick_lra.ml @@ -236,7 +236,6 @@ module Make(A : ARG) : S with module A = A = struct begin match T.Tbl.find self.pred_defs t with | exception Not_found -> () | (pred, a, b) -> - let open LE.Infix in let pred = if sign then pred else FM.Pred.neg pred in let c = FM_A.Constr.mk ~tag:lit pred a b in FM_A.assert_c fm c; diff --git a/src/th-lra/fourier_motzkin.ml b/src/th-lra/fourier_motzkin.ml index 8fda33f3..1b4ecc8e 100644 --- a/src/th-lra/fourier_motzkin.ml +++ b/src/th-lra/fourier_motzkin.ml @@ -46,13 +46,18 @@ module type S = sig type t - type expr = A.T.t + type term = A.T.t module LE : sig type t val const : Q.t -> t - val var : expr -> t + val zero : t + val var : term -> t + val neg : t -> t + + val find_exn : term -> t -> Q.t + val find : term -> t -> Q.t option module Infix : sig val (+) : t -> t -> t @@ -66,15 +71,11 @@ module type S = sig (** {3 Arithmetic constraint} *) module Constr : sig - type t = { - pred: Pred.t; - le: LE.t; - tag: A.tag option; - } - - val mk : ?tag:A.tag -> Pred.t -> LE.t -> LE.t -> t + type t val pp : t Fmt.printer + val mk : ?tag:A.tag -> Pred.t -> LE.t -> LE.t -> t + val is_absurd : t -> bool end val create : unit -> t @@ -97,7 +98,7 @@ module Make(A : ARG) module T_set = CCSet.Make(A.T) module T_map = CCMap.Make(A.T) - type expr = A.T.t + type term = A.T.t module LE = struct module M = T_map @@ -108,8 +109,18 @@ module Make(A : ARG) } let const x : t = {const=x; le=M.empty} + let zero = const Q.zero let var x : t = {const=Q.zero; le=M.singleton x Q.one} + let find_exn v le = M.find v le.le + let find v le = M.get v le.le + + let remove v le : t = {le with le=M.remove v le.le} + + let neg a : t = + {const=Q.neg a.const; + le=M.map Q.neg a.le; } + let (+) a b : t = {const = Q.(a.const + b.const); le=M.merge_safe a.le b.le @@ -139,6 +150,14 @@ module Make(A : ARG) } ) + let max_var self : T.t option = + M.keys self.le |> Iter.max ~lt:(fun a b -> T.compare a b < 0) + + (* ensure coeff of [v] is 1 *) + let normalize_wrt (v:T.t) le : t = + let q = find_exn v le in + Q.inv q * le + module Infix = struct let (+) = (+) let (-) = (-) @@ -160,41 +179,200 @@ module Make(A : ARG) type t = { pred: Pred.t; le: LE.t; - tag: A.tag option; + tag: A.tag list; } let pp out (c:t) : unit = Fmt.fprintf out "(@[constr@ :le %a@ :pred %s 0@])" LE.pp c.le (Pred.to_string c.pred) - + let mk_ ~tag pred le : t = {pred; tag; le; } let mk ?tag pred l1 l2 : t = - {pred; tag; le=LE.(l1 - l2); } + mk_ ~tag:(CCOpt.to_list tag) pred LE.(l1 - l2) + + let is_absurd (self:t) : bool = + T_map.is_empty self.le.le && + let c = self.le.const in + begin match self.pred with + | Leq -> Q.compare c Q.zero > 0 + | Lt -> Q.compare c Q.zero >= 0 + | Geq -> Q.compare c Q.zero < 0 + | Gt -> Q.compare c Q.zero <= 0 + | Eq -> Q.compare c Q.zero <> 0 + | Neq -> Q.compare c Q.zero = 0 + end + + let is_trivial (self:t) : bool = + T_map.is_empty self.le.le && not (is_absurd self) + + (* nornalize and return maximum variable *) + let normalize (self:t) : t = + match self.pred with + | Geq -> mk_ ~tag:self.tag Lt (LE.neg self.le) + | Gt -> mk_ ~tag:self.tag Leq (LE.neg self.le) + | _ -> self + + let find_max (self:t) : T.t option * bool = + match LE.max_var self.le with + | None -> None, true + | Some t -> Some t, Q.sign (T_map.find t self.le.le) > 0 end + (** constraints for a variable (where the variable is maximal) *) + type c_for_var = { + occ_pos: Constr.t list; + occ_eq: Constr.t list; + occ_neg: Constr.t list; + } + + type system = { + empties: Constr.t list; (* no variables, check first *) + idx: c_for_var T_map.t; + (* map [t] to [c1,c2] where [c1] are normalized constraints whose + maximum term is [t], with positive sign for [c1] + and negative for [c2] respectively. *) + } + type t = { mutable cs: Constr.t list; - mutable all_vars: T_set.t; + mutable sys: system; } + let empty_sys : system = {empties=[]; idx=T_map.empty} + let empty_c_for_v : c_for_var = + { occ_pos=[]; occ_neg=[]; occ_eq=[] } + let create () : t = { cs=[]; - all_vars=T_set.empty; + sys=empty_sys; } - let assert_c (self:t) c : unit = + let add_sys (sys:system) (c:Constr.t) : system = + assert (match c.pred with Eq|Neq|Lt -> true | _ -> false); + if Constr.is_trivial c then ( + Log.debugf 10 (fun k->k"(@[FM.drop-trivial@ %a@])" Constr.pp c); + sys + ) else ( + match Constr.find_max c with + | None, _ -> {sys with empties=c :: sys.empties} + | Some v, occ_pos -> + Log.debugf 30 (fun k->k "(@[FM.add-sys %a@ :max_var %a@ :occurs-pos %B@])" + Constr.pp c T.pp v occ_pos); + let cs = T_map.get_or ~default:empty_c_for_v v sys.idx in + let cs = + if c.pred = Eq then {cs with occ_eq = c :: cs.occ_eq} + else if occ_pos then {cs with occ_pos = c :: cs.occ_pos} + else {cs with occ_neg = c :: cs.occ_neg } + in + let idx = T_map.add v cs sys.idx in + {sys with idx} + ) + + let assert_c (self:t) c0 : unit = + Log.debugf 10 (fun k->k "(@[FM.add-constr@ %a@])" Constr.pp c0); + let c = Constr.normalize c0 in + if c.pred <> c0.pred then ( + Log.debugf 30 (fun k->k "(@[FM.normalized %a@])" Constr.pp c); + ); + assert (match c.pred with Eq | Leq | Lt -> true | _ -> false); self.cs <- c :: self.cs; - self.all_vars <- c.Constr.le |> LE.vars |> T_set.add_iter self.all_vars; + self.sys <- add_sys self.sys c; () + let pp_system out (self:system) : unit = + let pp_idxkv out (t,{occ_eq; occ_pos; occ_neg}) = + Fmt.fprintf out "(@[for-var %a@ :occ-eq %a@ :occ-pos %a@ :occ-neg %a@])" + T.pp t + (Fmt.Dump.list Constr.pp) occ_eq + (Fmt.Dump.list Constr.pp) occ_pos + (Fmt.Dump.list Constr.pp) occ_neg + in + Fmt.fprintf out "(@[:empties %a@ :idx (@[%a@])@])" + (Fmt.Dump.list Constr.pp) self.empties + (Util.pp_iter pp_idxkv) (T_map.to_iter self.idx) + (* TODO: be able to provide a model for SAT *) type res = | Sat | Unsat of A.tag list + (* replace [x] with [by] inside [le] *) + let subst_le (x:T.t) (le:LE.t) ~by:(le1:LE.t) : LE.t = + let q = LE.find_exn x le in + let le = LE.remove x le in + LE.( le + q * le1 ) + + let subst_constr x c ~by : Constr.t = + let c = {c with Constr.le=subst_le x ~by c.Constr.le} in + Constr.normalize c + + let rec solve_ (self:system) : res = + Log.debugf 50 + (fun k->k "(@[FM.solve-rec@ :sys %a@])" pp_system self); + begin match List.find Constr.is_absurd self.empties with + | c -> + Log.debugf 10 (fun k->k"(@[FM.unsat@ :by-absurd %a@])" Constr.pp c); + Unsat c.tag + | exception Not_found -> + (* need to process biggest variable first *) + match T_map.max_binding_opt self.idx with + | None -> Sat + | Some (v, {occ_eq=c0 :: ceq'; occ_pos; occ_neg}) -> + (* remove [v] from [idx] *) + let sys = {self with idx=T_map.remove v self.idx} in + + (* substitute using [c0] in the other constraints containing [v] *) + assert (c0.pred = Eq); + let c0 = LE.normalize_wrt v c0.le in + (* build [v = rhs] *) + let rhs = LE.neg @@ LE.remove v c0 in + Log.debugf 50 + (fun k->k "(@[FM.subst-from-eq@ :v %a@ :rhs %a@])" + T.pp v LE.pp rhs); + (* perform substitution *) + + let new_sys = + [Iter.of_list ceq'; Iter.of_list occ_pos; Iter.of_list occ_neg] + |> Iter.of_list + |> Iter.flatten + |> Iter.map (subst_constr v ~by:rhs) + |> Iter.fold add_sys sys + in + solve_ new_sys + + | Some (v, {occ_eq=[]; occ_pos=l1; occ_neg=l2}) -> + Log.debugf 10 + (fun k->k "(@[@{FM.pivot@}@ :v %a@ :lpos %a@ :lneg %a@])" + T.pp v (Fmt.Dump.list Constr.pp) l1 + (Fmt.Dump.list Constr.pp) l2); + + (* remove [v] *) + let sys = {self with idx=T_map.remove v self.idx} in + + let new_sys = + Iter.product (Iter.of_list l1) (Iter.of_list l2) + |> Iter.map + (fun (c1,c2) -> + let q1 = LE.find_exn v c1.Constr.le in + let le = LE.( c1.Constr.le + (Q.inv q1 * c2.Constr.le) ) in + let pred = match c1.Constr.pred, c2.Constr.pred with + | Eq, Eq -> Pred.Eq + | Lt, _ | _, Lt -> Pred.Lt + | Leq, _ | _, Leq -> Pred.Leq + | _ -> assert false + in + let c = Constr.mk_ ~tag:(c1.tag @ c2.tag) pred le in + Log.debugf 50 (fun k->k "(@[FM.resolve@ %a@ %a@ :yields@ %a@])" + Constr.pp c1 Constr.pp c2 Constr.pp c); + c) + |> Iter.fold add_sys sys + in + solve_ new_sys + end + let solve (self:t) : res = Log.debugf 5 - (fun k->k"(@[FM.solve@ %a@])" (Util.pp_list Constr.pp) self.cs); - assert false + (fun k->k"(@[@{FM.solve@}@ %a@])" (Util.pp_list Constr.pp) self.cs); + solve_ self.sys end