wip: first implem of Fourier Motzkin

This commit is contained in:
Simon Cruanes 2020-10-10 01:22:22 -04:00
parent 9783c3ae1b
commit 93b56618f1
2 changed files with 197 additions and 20 deletions

View file

@ -236,7 +236,6 @@ module Make(A : ARG) : S with module A = A = struct
begin match T.Tbl.find self.pred_defs t with begin match T.Tbl.find self.pred_defs t with
| exception Not_found -> () | exception Not_found -> ()
| (pred, a, b) -> | (pred, a, b) ->
let open LE.Infix in
let pred = if sign then pred else FM.Pred.neg pred in let pred = if sign then pred else FM.Pred.neg pred in
let c = FM_A.Constr.mk ~tag:lit pred a b in let c = FM_A.Constr.mk ~tag:lit pred a b in
FM_A.assert_c fm c; FM_A.assert_c fm c;

View file

@ -46,13 +46,18 @@ module type S = sig
type t type t
type expr = A.T.t type term = A.T.t
module LE : sig module LE : sig
type t type t
val const : Q.t -> t val const : Q.t -> t
val var : expr -> t val zero : t
val var : term -> t
val neg : t -> t
val find_exn : term -> t -> Q.t
val find : term -> t -> Q.t option
module Infix : sig module Infix : sig
val (+) : t -> t -> t val (+) : t -> t -> t
@ -66,15 +71,11 @@ module type S = sig
(** {3 Arithmetic constraint} *) (** {3 Arithmetic constraint} *)
module Constr : sig module Constr : sig
type t = { type t
pred: Pred.t;
le: LE.t;
tag: A.tag option;
}
val mk : ?tag:A.tag -> Pred.t -> LE.t -> LE.t -> t
val pp : t Fmt.printer val pp : t Fmt.printer
val mk : ?tag:A.tag -> Pred.t -> LE.t -> LE.t -> t
val is_absurd : t -> bool
end end
val create : unit -> t val create : unit -> t
@ -97,7 +98,7 @@ module Make(A : ARG)
module T_set = CCSet.Make(A.T) module T_set = CCSet.Make(A.T)
module T_map = CCMap.Make(A.T) module T_map = CCMap.Make(A.T)
type expr = A.T.t type term = A.T.t
module LE = struct module LE = struct
module M = T_map module M = T_map
@ -108,8 +109,18 @@ module Make(A : ARG)
} }
let const x : t = {const=x; le=M.empty} let const x : t = {const=x; le=M.empty}
let zero = const Q.zero
let var x : t = {const=Q.zero; le=M.singleton x Q.one} let var x : t = {const=Q.zero; le=M.singleton x Q.one}
let find_exn v le = M.find v le.le
let find v le = M.get v le.le
let remove v le : t = {le with le=M.remove v le.le}
let neg a : t =
{const=Q.neg a.const;
le=M.map Q.neg a.le; }
let (+) a b : t = let (+) a b : t =
{const = Q.(a.const + b.const); {const = Q.(a.const + b.const);
le=M.merge_safe a.le b.le le=M.merge_safe a.le b.le
@ -139,6 +150,14 @@ module Make(A : ARG)
} }
) )
let max_var self : T.t option =
M.keys self.le |> Iter.max ~lt:(fun a b -> T.compare a b < 0)
(* ensure coeff of [v] is 1 *)
let normalize_wrt (v:T.t) le : t =
let q = find_exn v le in
Q.inv q * le
module Infix = struct module Infix = struct
let (+) = (+) let (+) = (+)
let (-) = (-) let (-) = (-)
@ -160,41 +179,200 @@ module Make(A : ARG)
type t = { type t = {
pred: Pred.t; pred: Pred.t;
le: LE.t; le: LE.t;
tag: A.tag option; tag: A.tag list;
} }
let pp out (c:t) : unit = let pp out (c:t) : unit =
Fmt.fprintf out "(@[constr@ :le %a@ :pred %s 0@])" Fmt.fprintf out "(@[constr@ :le %a@ :pred %s 0@])"
LE.pp c.le (Pred.to_string c.pred) LE.pp c.le (Pred.to_string c.pred)
let mk_ ~tag pred le : t = {pred; tag; le; }
let mk ?tag pred l1 l2 : t = let mk ?tag pred l1 l2 : t =
{pred; tag; le=LE.(l1 - l2); } mk_ ~tag:(CCOpt.to_list tag) pred LE.(l1 - l2)
let is_absurd (self:t) : bool =
T_map.is_empty self.le.le &&
let c = self.le.const in
begin match self.pred with
| Leq -> Q.compare c Q.zero > 0
| Lt -> Q.compare c Q.zero >= 0
| Geq -> Q.compare c Q.zero < 0
| Gt -> Q.compare c Q.zero <= 0
| Eq -> Q.compare c Q.zero <> 0
| Neq -> Q.compare c Q.zero = 0
end end
let is_trivial (self:t) : bool =
T_map.is_empty self.le.le && not (is_absurd self)
(* nornalize and return maximum variable *)
let normalize (self:t) : t =
match self.pred with
| Geq -> mk_ ~tag:self.tag Lt (LE.neg self.le)
| Gt -> mk_ ~tag:self.tag Leq (LE.neg self.le)
| _ -> self
let find_max (self:t) : T.t option * bool =
match LE.max_var self.le with
| None -> None, true
| Some t -> Some t, Q.sign (T_map.find t self.le.le) > 0
end
(** constraints for a variable (where the variable is maximal) *)
type c_for_var = {
occ_pos: Constr.t list;
occ_eq: Constr.t list;
occ_neg: Constr.t list;
}
type system = {
empties: Constr.t list; (* no variables, check first *)
idx: c_for_var T_map.t;
(* map [t] to [c1,c2] where [c1] are normalized constraints whose
maximum term is [t], with positive sign for [c1]
and negative for [c2] respectively. *)
}
type t = { type t = {
mutable cs: Constr.t list; mutable cs: Constr.t list;
mutable all_vars: T_set.t; mutable sys: system;
} }
let empty_sys : system = {empties=[]; idx=T_map.empty}
let empty_c_for_v : c_for_var =
{ occ_pos=[]; occ_neg=[]; occ_eq=[] }
let create () : t = { let create () : t = {
cs=[]; cs=[];
all_vars=T_set.empty; sys=empty_sys;
} }
let assert_c (self:t) c : unit = let add_sys (sys:system) (c:Constr.t) : system =
assert (match c.pred with Eq|Neq|Lt -> true | _ -> false);
if Constr.is_trivial c then (
Log.debugf 10 (fun k->k"(@[FM.drop-trivial@ %a@])" Constr.pp c);
sys
) else (
match Constr.find_max c with
| None, _ -> {sys with empties=c :: sys.empties}
| Some v, occ_pos ->
Log.debugf 30 (fun k->k "(@[FM.add-sys %a@ :max_var %a@ :occurs-pos %B@])"
Constr.pp c T.pp v occ_pos);
let cs = T_map.get_or ~default:empty_c_for_v v sys.idx in
let cs =
if c.pred = Eq then {cs with occ_eq = c :: cs.occ_eq}
else if occ_pos then {cs with occ_pos = c :: cs.occ_pos}
else {cs with occ_neg = c :: cs.occ_neg }
in
let idx = T_map.add v cs sys.idx in
{sys with idx}
)
let assert_c (self:t) c0 : unit =
Log.debugf 10 (fun k->k "(@[FM.add-constr@ %a@])" Constr.pp c0);
let c = Constr.normalize c0 in
if c.pred <> c0.pred then (
Log.debugf 30 (fun k->k "(@[FM.normalized %a@])" Constr.pp c);
);
assert (match c.pred with Eq | Leq | Lt -> true | _ -> false);
self.cs <- c :: self.cs; self.cs <- c :: self.cs;
self.all_vars <- c.Constr.le |> LE.vars |> T_set.add_iter self.all_vars; self.sys <- add_sys self.sys c;
() ()
let pp_system out (self:system) : unit =
let pp_idxkv out (t,{occ_eq; occ_pos; occ_neg}) =
Fmt.fprintf out "(@[for-var %a@ :occ-eq %a@ :occ-pos %a@ :occ-neg %a@])"
T.pp t
(Fmt.Dump.list Constr.pp) occ_eq
(Fmt.Dump.list Constr.pp) occ_pos
(Fmt.Dump.list Constr.pp) occ_neg
in
Fmt.fprintf out "(@[:empties %a@ :idx (@[%a@])@])"
(Fmt.Dump.list Constr.pp) self.empties
(Util.pp_iter pp_idxkv) (T_map.to_iter self.idx)
(* TODO: be able to provide a model for SAT *) (* TODO: be able to provide a model for SAT *)
type res = type res =
| Sat | Sat
| Unsat of A.tag list | Unsat of A.tag list
(* replace [x] with [by] inside [le] *)
let subst_le (x:T.t) (le:LE.t) ~by:(le1:LE.t) : LE.t =
let q = LE.find_exn x le in
let le = LE.remove x le in
LE.( le + q * le1 )
let subst_constr x c ~by : Constr.t =
let c = {c with Constr.le=subst_le x ~by c.Constr.le} in
Constr.normalize c
let rec solve_ (self:system) : res =
Log.debugf 50
(fun k->k "(@[FM.solve-rec@ :sys %a@])" pp_system self);
begin match List.find Constr.is_absurd self.empties with
| c ->
Log.debugf 10 (fun k->k"(@[FM.unsat@ :by-absurd %a@])" Constr.pp c);
Unsat c.tag
| exception Not_found ->
(* need to process biggest variable first *)
match T_map.max_binding_opt self.idx with
| None -> Sat
| Some (v, {occ_eq=c0 :: ceq'; occ_pos; occ_neg}) ->
(* remove [v] from [idx] *)
let sys = {self with idx=T_map.remove v self.idx} in
(* substitute using [c0] in the other constraints containing [v] *)
assert (c0.pred = Eq);
let c0 = LE.normalize_wrt v c0.le in
(* build [v = rhs] *)
let rhs = LE.neg @@ LE.remove v c0 in
Log.debugf 50
(fun k->k "(@[FM.subst-from-eq@ :v %a@ :rhs %a@])"
T.pp v LE.pp rhs);
(* perform substitution *)
let new_sys =
[Iter.of_list ceq'; Iter.of_list occ_pos; Iter.of_list occ_neg]
|> Iter.of_list
|> Iter.flatten
|> Iter.map (subst_constr v ~by:rhs)
|> Iter.fold add_sys sys
in
solve_ new_sys
| Some (v, {occ_eq=[]; occ_pos=l1; occ_neg=l2}) ->
Log.debugf 10
(fun k->k "(@[@{<yellow>FM.pivot@}@ :v %a@ :lpos %a@ :lneg %a@])"
T.pp v (Fmt.Dump.list Constr.pp) l1
(Fmt.Dump.list Constr.pp) l2);
(* remove [v] *)
let sys = {self with idx=T_map.remove v self.idx} in
let new_sys =
Iter.product (Iter.of_list l1) (Iter.of_list l2)
|> Iter.map
(fun (c1,c2) ->
let q1 = LE.find_exn v c1.Constr.le in
let le = LE.( c1.Constr.le + (Q.inv q1 * c2.Constr.le) ) in
let pred = match c1.Constr.pred, c2.Constr.pred with
| Eq, Eq -> Pred.Eq
| Lt, _ | _, Lt -> Pred.Lt
| Leq, _ | _, Leq -> Pred.Leq
| _ -> assert false
in
let c = Constr.mk_ ~tag:(c1.tag @ c2.tag) pred le in
Log.debugf 50 (fun k->k "(@[FM.resolve@ %a@ %a@ :yields@ %a@])"
Constr.pp c1 Constr.pp c2 Constr.pp c);
c)
|> Iter.fold add_sys sys
in
solve_ new_sys
end
let solve (self:t) : res = let solve (self:t) : res =
Log.debugf 5 Log.debugf 5
(fun k->k"(@[FM.solve@ %a@])" (Util.pp_list Constr.pp) self.cs); (fun k->k"(@[<hv>@{<Green>FM.solve@}@ %a@])" (Util.pp_list Constr.pp) self.cs);
assert false solve_ self.sys
end end