wip: refactor(lra): import Simplex from funarith, replace FM with it

This commit is contained in:
Simon Cruanes 2020-11-13 22:35:59 -05:00
parent df25e84a01
commit 5ff0fff85b
12 changed files with 1473 additions and 431 deletions

View file

@ -17,6 +17,7 @@ depends: [
"ocaml" { >= "4.03" }
"zarith"
"alcotest" {with-test}
"qcheck" {with-test & >= "0.16" }
]
tags: [ "sat" "smt" ]
homepage: "https://github.com/c-cube/sidekick"

View file

@ -4,7 +4,7 @@ module Fmt = CCFormat
module CC_view = Sidekick_core.CC_view
type lra_pred = Sidekick_arith_lra.FM.Pred.t = Lt | Leq | Geq | Gt | Neq | Eq
type lra_pred = Sidekick_arith_lra.Predicate.t = Leq | Geq | Lt | Gt | Eq | Neq
type lra_op = Sidekick_arith_lra.op = Plus | Minus
type 'a lra_view = 'a Sidekick_arith_lra.lra_view =

View file

@ -1,408 +0,0 @@
module type ARG = sig
(** terms *)
module T : sig
type t
val equal : t -> t -> bool
val hash : t -> int
val compare : t -> t -> int
val pp : t Fmt.printer
end
type tag
val pp_tag : tag Fmt.printer
end
module Pred : sig
type t = Lt | Leq | Geq | Gt | Neq | Eq
val neg : t -> t
val pp : t Fmt.printer
val to_string : t -> string
end = struct
type t = Lt | Leq | Geq | Gt | Neq | Eq
let to_string = function
| Lt -> "<"
| Leq -> "<="
| Eq -> "="
| Neq -> "!="
| Gt -> ">"
| Geq -> ">="
let neg = function
| Leq -> Gt
| Lt -> Geq
| Eq -> Neq
| Neq -> Eq
| Geq -> Lt
| Gt -> Leq
let pp out p = Fmt.string out (to_string p)
end
module type S = sig
module A : ARG
type t
type term = A.T.t
module LE : sig
type t
val const : Q.t -> t
val zero : t
val var : term -> t
val neg : t -> t
val find_exn : term -> t -> Q.t
val find : term -> t -> Q.t option
val mem : term -> t -> bool
(* val map : (term -> term) -> t -> t *)
module Infix : sig
val (+) : t -> t -> t
val (-) : t -> t -> t
val ( * ) : Q.t -> t -> t
end
include module type of Infix
val pp : t Fmt.printer
end
(** {3 Arithmetic constraint} *)
module Constr : sig
type t
val pp : t Fmt.printer
val mk : ?tag:A.tag -> Pred.t -> LE.t -> LE.t -> t
val is_absurd : t -> bool
end
val create : unit -> t
val assert_c : t -> Constr.t -> unit
type res =
| Sat
| Unsat of A.tag list
val solve : t -> res
end
module Make(A : ARG)
: S with module A = A
= struct
module A = A
module T = A.T
module T_set = CCSet.Make(A.T)
module T_map = CCMap.Make(A.T)
type term = A.T.t
module LE = struct
module M = T_map
type t = {
le: Q.t M.t;
const: Q.t;
}
let const x : t = {const=x; le=M.empty}
let zero = const Q.zero
let var x : t = {const=Q.zero; le=M.singleton x Q.one}
let[@inline] find_exn v le = M.find v le.le
let[@inline] find v le = M.get v le.le
let[@inline] mem v le = M.mem v le.le
let[@inline] remove v le : t = {le with le=M.remove v le.le}
let neg a : t =
{const=Q.neg a.const;
le=M.map Q.neg a.le; }
let (+) a b : t =
{const = Q.(a.const + b.const);
le=M.merge_safe a.le b.le
~f:(fun _ -> function
| `Left x | `Right x -> Some x
| `Both (x,y) ->
let z = Q.(x + y) in
if Q.sign z = 0 then None else Some z)
}
let (-) a b : t =
{const = Q.(a.const - b.const);
le=M.merge_safe a.le b.le
~f:(fun _ -> function
| `Left x -> Some x
| `Right x -> Some (Q.neg x)
| `Both (x,y) ->
let z = Q.(x - y) in
if Q.sign z = 0 then None else Some z)
}
let ( * ) x a : t =
if Q.sign x = 0 then const Q.zero
else (
{const=Q.( a.const * x );
le=M.map (fun y -> Q.(x * y)) a.le
}
)
let max_var self : T.t option =
M.keys self.le |> Iter.max ~lt:(fun a b -> T.compare a b < 0)
(* ensure coeff of [v] is 1 *)
let normalize_wrt (v:T.t) le : t =
let q = find_exn v le in
Q.inv q * le
module Infix = struct
let (+) = (+)
let (-) = (-)
let ( * ) = ( * )
end
let vars self = T_map.keys self.le
let pp out (self:t) : unit =
let pp_pair out (e,q) =
if Q.equal Q.one q then T.pp out e
else Fmt.fprintf out "%a * %a" Q.pp_print q T.pp e
in
let pp_sum out le =
(Util.pp_iter ~sep:" + " pp_pair) out (M.to_iter le)
in
if Q.sign self.const = 0 then (
Fmt.fprintf out "(@[%a@])" pp_sum self.le
) else (
Fmt.fprintf out "(@[%a@ + %a@])" Q.pp_print self.const pp_sum self.le
)
end
(** {2 Constraints} *)
module Constr = struct
type t = {
pred: Pred.t;
le: LE.t;
tag: A.tag list;
}
let pp out (c:t) : unit =
Fmt.fprintf out "(@[constr@ :le %a@ :pred %s 0@])"
LE.pp c.le (Pred.to_string c.pred)
let mk_ ~tag pred le : t = {pred; tag; le; }
let mk ?tag pred l1 l2 : t =
mk_ ~tag:(CCOpt.to_list tag) pred LE.(l1 - l2)
let is_absurd (self:t) : bool =
T_map.is_empty self.le.le &&
let c = self.le.const in
begin match self.pred with
| Leq -> Q.compare c Q.zero > 0
| Lt -> Q.compare c Q.zero >= 0
| Geq -> Q.compare c Q.zero < 0
| Gt -> Q.compare c Q.zero <= 0
| Eq -> Q.compare c Q.zero <> 0
| Neq -> Q.compare c Q.zero = 0
end
let is_trivial (self:t) : bool =
T_map.is_empty self.le.le && not (is_absurd self)
(* nornalize and return maximum variable *)
let normalize (self:t) : t =
match self.pred with
| Geq -> mk_ ~tag:self.tag Leq (LE.neg self.le)
| Gt -> mk_ ~tag:self.tag Lt (LE.neg self.le)
| _ -> self
let find_max (self:t) : T.t option * bool =
match LE.max_var self.le with
| None -> None, true
| Some t -> Some t, Q.sign (T_map.find t self.le.le) > 0
end
(** constraints for a variable (where the variable is maximal) *)
type c_for_var = {
occ_pos: Constr.t list;
occ_eq: Constr.t list;
occ_neg: Constr.t list;
}
type system = {
empties: Constr.t list; (* no variables, check first *)
idx: c_for_var T_map.t;
(* map [t] to [cft] where [cft] are normalized constraints whose
maximum term is [t], with positive sign for [cft.occ_pos]
and negative for [cft.neg_pos] respectively. *)
}
type t = {
mutable cs: Constr.t list;
mutable sys: system;
}
let empty_sys : system = {empties=[]; idx=T_map.empty}
let empty_c_for_v : c_for_var =
{ occ_pos=[]; occ_neg=[]; occ_eq=[] }
let create () : t = {
cs=[];
sys=empty_sys;
}
let add_sys (sys:system) (c:Constr.t) : system =
assert (match c.pred with Eq|Leq|Lt -> true | _ -> false);
if Constr.is_trivial c then (
Log.debugf 10 (fun k->k"(@[FM.drop-trivial@ %a@])" Constr.pp c);
sys
) else (
match Constr.find_max c with
| None, _ -> {sys with empties=c :: sys.empties}
| Some v, occ_pos ->
Log.debugf 30 (fun k->k "(@[FM.add-sys %a@ :max_var %a@ :occurs-pos %B@])"
Constr.pp c T.pp v occ_pos);
let cs = T_map.get_or ~default:empty_c_for_v v sys.idx in
let cs =
if c.pred = Eq then {cs with occ_eq = c :: cs.occ_eq}
else if occ_pos then {cs with occ_pos = c :: cs.occ_pos}
else {cs with occ_neg = c :: cs.occ_neg }
in
let idx = T_map.add v cs sys.idx in
{sys with idx}
)
let assert_c (self:t) c0 : unit =
Log.debugf 10 (fun k->k "(@[FM.add-constr@ %a@ :tags %a@])"
Constr.pp c0 (Fmt.Dump.list A.pp_tag) c0.tag);
let c = Constr.normalize c0 in
if c.pred <> c0.pred then (
Log.debugf 30 (fun k->k "(@[FM.normalized %a@])" Constr.pp c);
);
assert (match c.pred with Eq | Leq | Lt -> true | _ -> false);
self.cs <- c :: self.cs;
self.sys <- add_sys self.sys c;
()
let pp_system out (self:system) : unit =
let pp_idxkv out (t,{occ_eq; occ_pos; occ_neg}) =
Fmt.fprintf out "(@[for-var %a@ :occ-eq %a@ :occ-pos %a@ :occ-neg %a@])"
T.pp t
(Fmt.Dump.list Constr.pp) occ_eq
(Fmt.Dump.list Constr.pp) occ_pos
(Fmt.Dump.list Constr.pp) occ_neg
in
Fmt.fprintf out "(@[:empties %a@ :idx (@[%a@])@])"
(Fmt.Dump.list Constr.pp) self.empties
(Util.pp_iter pp_idxkv) (T_map.to_iter self.idx)
(* TODO: be able to provide a model for SAT *)
type res =
| Sat
| Unsat of A.tag list
(* replace [x] with [by] inside [le] *)
let subst_le (x:T.t) (le:LE.t) ~by:(le1:LE.t) : LE.t =
let q = LE.find_exn x le in
let le = LE.remove x le in
LE.( le + q * le1 )
let subst_constr x c ~by : Constr.t =
let c = {c with Constr.le=subst_le x ~by c.Constr.le} in
Constr.normalize c
let rec solve_ (self:system) : res =
Log.debugf 50
(fun k->k "(@[FM.solve-rec@ :sys %a@])" pp_system self);
begin match List.find Constr.is_absurd self.empties with
| c ->
Log.debugf 10 (fun k->k"(@[FM.unsat@ :by-absurd %a@])" Constr.pp c);
Unsat c.tag
| exception Not_found ->
(* need to process biggest variable first *)
match T_map.max_binding_opt self.idx with
| None -> Sat
| Some (v, {occ_eq=c0 :: ceq'; occ_pos; occ_neg}) ->
(* at least one equality constraint, use it as a substitution *)
(* remove [v] from [idx] *)
let sys = {self with idx=T_map.remove v self.idx} in
(* substitute using [c0] in the other constraints containing [v] *)
assert (c0.pred = Eq);
let c0 = LE.normalize_wrt v c0.le in
(* turn equation [c0] into [v = rhs] *)
let rhs = LE.neg @@ LE.remove v c0 in
Log.debugf 50
(fun k->k "(@[FM.subst-from-eq@ :v %a@ :rhs %a@])"
T.pp v LE.pp rhs);
(* perform substitution in other constraints. Note that [v] cannot
occur in constraints in the rest of [sys] because it's the
maximal variable of the system, so it would be the maximum
variable of these other constraints too.
*)
let new_sys =
[Iter.of_list ceq'; Iter.of_list occ_pos; Iter.of_list occ_neg]
|> Iter.of_list
|> Iter.flatten
|> Iter.map (subst_constr v ~by:rhs)
|> Iter.fold add_sys sys
in
solve_ new_sys
| Some (v, {occ_eq=[]; occ_pos=l_pos; occ_neg=l_neg}) ->
Log.debugf 10
(fun k->k "(@[@{<yellow>FM.pivot@}@ :v %a@ :lpos %a@ :lneg %a@])"
T.pp v (Fmt.Dump.list Constr.pp) l_pos
(Fmt.Dump.list Constr.pp) l_neg);
(* remove [v] *)
let sys = {self with idx=T_map.remove v self.idx} in
(* TODO: store all lower bound constraints for [v], so we can use
their max to build the model once we have values for lower
variables *)
let new_sys =
Iter.product (Iter.of_list l_pos) (Iter.of_list l_neg)
|> Iter.map
(fun (c1,c2) ->
let q1 = LE.find_exn v c1.Constr.le in
let q2 = LE.find_exn v c2.Constr.le in
assert (Q.sign q1 > 0 && Q.sign q2 < 0);
let le = LE.( c1.Constr.le + (Q.(q1 / abs q2) * c2.Constr.le) ) in
Log.debugf 50 (fun k->k "coeff=%a; le: %a" Q.pp_print (Q.inv q1) LE.pp le);
let pred = match c1.Constr.pred, c2.Constr.pred with
| Lt, _ | _, Lt -> Pred.Lt
| Leq, Leq -> Pred.Leq
| _ ->
Log.debugf 1
(fun k->k "unexpected pair in pivot@ :c1 %a@ :c2 %a"
Constr.pp c1 Constr.pp c2);
assert false
in
let c = Constr.mk_ ~tag:(c1.tag @ c2.tag) pred le in
Log.debugf 50 (fun k->k "(@[FM.resolve@ %a@ %a@ :yields@ %a@])"
Constr.pp c1 Constr.pp c2 Constr.pp c);
assert (not (LE.mem v c.Constr.le));
c)
|> Iter.fold add_sys sys
in
solve_ new_sys
end
let solve (self:t) : res =
Log.debugf 5
(fun k->k"(@[<hv>@{<Green>FM.solve@}@ %a@])" (Util.pp_list Constr.pp) self.cs);
solve_ self.sys
end

View file

@ -0,0 +1,220 @@
(*
copyright (c) 2014-2018, Guillaume Bury, Simon Cruanes
*)
module type COEFF = Linear_expr_intf.COEFF
module type VAR = Linear_expr_intf.VAR
module type FRESH = Linear_expr_intf.FRESH
module type VAR_GEN = Linear_expr_intf.VAR_GEN
module type VAR_EXTENDED = Linear_expr_intf.VAR_EXTENDED
module type S = Linear_expr_intf.S
type bool_op = Linear_expr_intf.bool_op = Leq | Geq | Lt | Gt | Eq | Neq
module Make(C : COEFF)(Var : VAR) = struct
module C = C
module Var_map = CCMap.Make(Var)
module Var = Var
type var = Var.t
type subst = C.t Var_map.t
(** Linear combination of variables. *)
module Comb = struct
(* A map from variables to their coefficient in the linear combination. *)
type t = C.t Var_map.t
let compare = Var_map.compare C.compare
let empty = Var_map.empty
let is_empty = Var_map.is_empty
let monomial c x =
if C.equal c C.zero then empty else Var_map.singleton x c
let monomial1 x = Var_map.singleton x C.one
let add c x e =
let c' = Var_map.get_or ~default:C.zero x e in
let c' = C.(c + c') in
if C.equal C.zero c' then Var_map.remove x e else Var_map.add x c' e
let[@inline] map2 ~fr ~f a b =
Var_map.merge_safe
~f:(fun _ rhs -> match rhs with
| `Left x -> Some x
| `Right x -> Some (fr x)
| `Both (x,y) -> f x y)
a b
let[@inline] some_if_nzero x =
if C.equal C.zero x then None else Some x
let filter_map ~f m =
Var_map.fold
(fun x y m -> match f y with
| None -> m
| Some z -> Var_map.add x z m)
m Var_map.empty
module Infix = struct
let (+) = map2 ~fr:(fun x->x) ~f:(fun a b -> some_if_nzero C.(a + b))
let (-) = map2 ~fr:C.neg ~f:(fun a b -> some_if_nzero C.(a - b))
let ( * ) q = filter_map ~f:(fun x -> some_if_nzero C.(x * q))
end
include Infix
let of_list l = List.fold_left (fun e (c,x) -> add c x e) empty l
let to_list e = Var_map.bindings e |> List.rev_map CCPair.swap
let to_map e = e
let of_map e = Var_map.filter (fun _ c -> not (C.equal C.zero c)) e
let pp_pair =
Fmt.(pair ~sep:(return "@ * ") C.pp Var.pp)
let pp out (e:t) =
Fmt.(hovbox @@ list ~sep:(return "@ + ") pp_pair) out (to_list e)
let eval (subst : subst) (e:t) : C.t =
Var_map.fold
(fun x c acc -> C.(acc + c * (Var_map.find x subst)))
e C.zero
end
(** A linear arithmetic expression, composed of a combination of variables
with coefficients and a constant offset. *)
module Expr = struct
type t = {
const : C.t;
comb : Comb.t
}
let[@inline] const e = e.const
let[@inline] comb e = e.comb
let compare e e' =
CCOrd.(C.compare e.const e'.const
<?> (Comb.compare, e.comb, e'.comb))
let pp fmt e =
Format.fprintf fmt "@[<hov>%a@ + %a" Comb.pp e.comb C.pp e.const
let[@inline] make comb const : t = { comb; const; }
let of_const = make Comb.empty
let of_comb c = make c C.zero
let monomial c x = of_comb (Comb.monomial c x)
let monomial1 x = of_comb (Comb.monomial1 x)
let of_list c l = make (Comb.of_list l) c
let zero = of_const C.zero
let is_zero e = C.equal C.zero e.const && Comb.is_empty e.comb
let map2 f g e e' = make (f e.comb e'.comb) (g e.const e'.const)
module Infix = struct
let (+) = map2 Comb.(+) C.(+)
let (-) = map2 Comb.(-) C.(-)
let ( * ) c e =
if C.equal C.zero c
then zero
else make Comb.(c * e.comb) C.(c * e.const)
end
include Infix
let eval subst e = C.(e.const + Comb.eval subst e.comb)
end
module Constr = struct
type op = bool_op = Leq | Geq | Lt | Gt | Eq | Neq
(** Constraints are expressions implicitly compared to zero. *)
type t = {
expr: Expr.t;
op: op;
}
let compare c c' =
CCOrd.(compare c.op c'.op
<?> (Expr.compare, c.expr, c'.expr))
let pp_op out o =
Fmt.string out (match o with
| Leq -> "=<" | Geq -> ">=" | Lt -> "<"
| Gt -> ">" | Eq -> "=" | Neq -> "!=")
let pp out c =
Format.fprintf out "(@[%a@ %a 0@])"
Expr.pp c.expr pp_op c.op
let op t = t.op
let expr t = t.expr
let[@inline] of_expr expr op = { expr; op; }
let make comb op const = of_expr (Expr.make comb (C.neg const)) op
let geq e c = make e Geq c
let leq e c = make e Leq c
let gt e c = make e Gt c
let lt e c = make e Lt c
let eq e c = make e Eq c
let neq e c = make e Neq c
let geq0 e = of_expr e Geq
let leq0 e = of_expr e Leq
let gt0 e = of_expr e Gt
let lt0 e = of_expr e Lt
let eq0 e = of_expr e Eq
let neq0 e = of_expr e Neq
let[@inline] split {expr = {Expr.const; comb}; op} =
comb, op, C.neg const
let eval subst c =
let v = Expr.eval subst c.expr in
begin match c.op with
| Leq -> C.compare v C.zero <= 0
| Geq -> C.compare v C.zero >= 0
| Lt -> C.compare v C.zero < 0
| Gt -> C.compare v C.zero > 0
| Eq -> C.compare v C.zero = 0
| Neq -> C.compare v C.zero <> 0
end
end
end[@@inline]
module Make_var_gen(Var : VAR)
: VAR_EXTENDED with type user_var = Var.t
and type lit = Var.lit
= struct
type user_var = Var.t
type t =
| User of user_var
| Internal of int
let compare (a:t) b : int = match a, b with
| User a, User b -> Var.compare a b
| User _, Internal _ -> -1
| Internal _, User _ -> 1
| Internal i, Internal j -> CCInt.compare i j
let pp out = function
| User v -> Var.pp out v
| Internal i -> Format.fprintf out "internal_v_%d" i
type lit = Var.lit
let pp_lit = Var.pp_lit
module Fresh = struct
type t = int ref
let create() = ref 0
let copy r = ref !r
let fresh r = Internal (CCRef.get_then_incr r)
end
end[@@inline]

View file

@ -0,0 +1,26 @@
(*
copyright (c) 2014-2018, Guillaume Bury, Simon Cruanes
*)
(** Arithmetic expressions *)
module type COEFF = Linear_expr_intf.COEFF
module type VAR = Linear_expr_intf.VAR
module type FRESH = Linear_expr_intf.FRESH
module type VAR_GEN = Linear_expr_intf.VAR_GEN
module type VAR_EXTENDED = Linear_expr_intf.VAR_EXTENDED
module type S = Linear_expr_intf.S
type nonrec bool_op = Linear_expr_intf.bool_op = Leq | Geq | Lt | Gt | Eq | Neq
module Make(C : COEFF)(Var : VAR)
: S with module C = C
and module Var = Var
and module Var_map = CCMap.Make(Var)
module Make_var_gen(Var : VAR)
: VAR_EXTENDED
with type user_var = Var.t
and type lit = Var.lit

View file

@ -0,0 +1,306 @@
(*
copyright (c) 2014-2018, Guillaume Bury, Simon Cruanes
*)
(** {1 Linear expressions interface} *)
(** {2 Coefficients}
Coefficients are used in expressions. They usually
are either rationals, or integers.
*)
module type COEFF = sig
type t
val equal : t -> t -> bool
(** Equality on coefficients. *)
val compare : t -> t -> int
(** Comparison on coefficients. *)
val pp : t Fmt.printer
(** Printer for coefficients. *)
val zero : t
(** The zero coefficient. *)
val one : t
(** The one coefficient (to rule them all, :p). *)
val neg : t -> t
(** Unary negation *)
val (+) : t -> t -> t
val (-) : t -> t -> t
val ( * ) : t -> t -> t
(** Standard operations on coefficients. *)
end
(** {2 Variable interface}
Standard interface for variables that are meant to be used
in expressions.
*)
module type VAR = sig
type t
(** Variable type. *)
val compare : t -> t -> int
(** Standard comparison function on variables. *)
val pp : t Fmt.printer
(** Printer for variables. *)
type lit
val pp_lit : lit Fmt.printer
end
(** {2 Fresh variables}
Standard interface for variables with an infinite number
of 'fresh' variables. A 'fresh' variable should be distinct
from any other.
*)
module type FRESH = sig
type var
(** The type of variables. *)
type t
(** A type of state for creating fresh variables. *)
val copy : t -> t
(** Copy state *)
val fresh : t -> var
(** Create a fresh variable using an existing variable as base.
TODO: need some explaining, about the difference with {!create}. *)
end
(** {2 Generative Variable interface}
Standard interface for variables that are meant to be used
in expressions. Furthermore, fresh variables can be generated
(which is useful to refactor and/or put problems in specific
formats used by algorithms).
*)
module type VAR_GEN = sig
include VAR
(** Generate fresh variables on demand *)
module Fresh : FRESH with type var := t
end
module type VAR_EXTENDED = sig
type user_var (** original variables *)
type t =
| User of user_var
| Internal of int
include VAR_GEN with type t := t
end
type bool_op = Predicate.t = Leq | Geq | Lt | Gt | Eq | Neq
(** {2 Linear expressions & formulas} *)
(** Linear expressions & formulas.
This modules defines linear expressions (which are linear
combinations of variables), and linear constraints, where
the value of a linear expressions is constrained.
*)
module type S = sig
module C : COEFF
(** Coeficients used. Can be integers as well as rationals. *)
module Var : VAR
(** Variables used in expressions. *)
type var = Var.t
(** The type of variables appearing in expressions. *)
module Var_map : CCMap.S with type key = var
(** Maps from variables, used for expressions as well as substitutions. *)
type subst = C.t Var_map.t
(** Type for substitutions. *)
(** Combinations.
This module defines linear combnations as mapping from variables
to coefficients. This allows for very fast computations.
*)
module Comb : sig
type t = private C.t Var_map.t
(** The type of linear combinations. *)
val compare : t -> t -> int
(** Comparisons on linear combinations. *)
val pp : t Fmt.printer
(** Printer for linear combinations. *)
val is_empty : t -> bool
(** Is the given expression empty ?*)
(** {5 Creation} *)
val empty : t
(** The empty linear combination. *)
val monomial : C.t -> var -> t
(** [monome n v] creates the linear combination [n * v] *)
val monomial1 : var -> t
(** [monome1 v] creates the linear combination [1 * v] *)
val add : C.t -> var -> t -> t
(** [add n v t] adds the monome [n * v] to the combination [t]. *)
(** Infix operations on combinations
This module defines usual operations on linear combinations,
as infix operators to ease reading of complex computations. *)
module Infix : sig
val (+) : t -> t -> t
(** Addition between combinations. *)
val (-) : t -> t -> t
(** Substraction between combinations. *)
val ( * ) : C.t -> t -> t
(** Multiplication by a constant. *)
end
include module type of Infix
(** Include the previous module. *)
val of_list : (C.t * var) list -> t
val to_list : t -> (C.t * var) list
(** Converters to and from lists of monomes. *)
val of_map : C.t Var_map.t -> t
val to_map : t -> C.t Var_map.t
(** {5 Semantics} *)
val eval : subst -> t -> C.t
(** Evaluate a linear combination given a substitution for its variables.
TODO: document potential exceptions raised ?*)
end
(** {2 Linear expressions.} *)
(** Linear expressions represent linear arithmetic expressions as
a linear combination and a constant. *)
module Expr : sig
type t
(** The type of linear expressions. *)
val comb : t -> Comb.t
val const : t -> C.t
val is_zero : t -> bool
val compare : t -> t -> int
(** Standard comparison function on expressions. *)
val pp : t Fmt.printer
(** Standard printing function on expressions. *)
val zero : t
(** The expression [2]. *)
val of_const : C.t -> t
(** The constant expression. *)
val of_comb : Comb.t -> t
(** Combination without constant *)
val of_list : C.t -> (C.t * Var.t) list -> t
val make : Comb.t -> C.t -> t
(** [make c n] makes the linear expression [c + n]. *)
val monomial : C.t -> var -> t
val monomial1 : var -> t
(** Infix operations on expressions
This module defines usual operations on linear expressions,
as infix operators to ease reading of complex computations. *)
module Infix : sig
val (+) : t -> t -> t
(** Addition between expressions. *)
val (-) : t -> t -> t
(** Substraction between expressions. *)
val ( * ) : C.t -> t -> t
(** Multiplication by a constant. *)
end
include module type of Infix
(** Include the previous module. *)
(** {5 Semantics} *)
val eval : subst -> t -> C.t
(** Evaluate a linear expression given a substitution for its variables.
TODO: document potential exceptions raised ?*)
end
(** {2 Linear constraints.}
Represents constraints on linear expressions. *)
module Constr : sig
type op = bool_op
(** Arithmetic comparison operators. *)
type t = {
expr: Expr.t;
op: op;
}
(** Linear constraints. Expressions are implicitly compared to zero. *)
val compare : t -> t -> int
(** Standard comparison function. *)
val pp : t Fmt.printer
(** Standard printing function. *)
val of_expr : Expr.t -> bool_op -> t
val make : Comb.t -> bool_op -> C.t -> t
(** Create a constraint from a linear expression/combination and a constant. *)
val geq : Comb.t -> C.t -> t
val leq : Comb.t -> C.t -> t
val gt: Comb.t -> C.t -> t
val lt : Comb.t -> C.t -> t
val eq : Comb.t -> C.t -> t
val neq : Comb.t -> C.t -> t
val geq0 : Expr.t -> t
val leq0 : Expr.t -> t
val gt0 : Expr.t -> t
val lt0 : Expr.t -> t
val eq0 : Expr.t -> t
val neq0 : Expr.t -> t
val op : t -> bool_op
val expr : t -> Expr.t
(** Extract the given part from a constraint. *)
val split : t -> Comb.t * bool_op * C.t
(** Split the linear combinations from the constant *)
val eval : subst -> t -> bool
(** Evaluate the given constraint under a substitution. *)
end
end

View file

@ -0,0 +1,17 @@
type t = Leq | Geq | Lt | Gt | Eq | Neq
let neg = function
| Leq -> Gt
| Lt -> Geq
| Eq -> Neq
| Neq -> Eq
| Geq -> Lt
| Gt -> Leq
let to_string = function
| Leq -> "=<" | Geq -> ">=" | Lt -> "<"
| Gt -> ">" | Eq -> "=" | Neq -> "!="
let pp out (self:t) = Fmt.string out (to_string self)

View file

@ -6,9 +6,11 @@
open Sidekick_core
module FM = Fourier_motzkin
module Simplex = Simplex
module Predicate = Predicate
module Linear_expr = Linear_expr
type pred = FM.Pred.t = Lt | Leq | Geq | Gt | Neq | Eq
type pred = Linear_expr_intf.bool_op = Leq | Geq | Lt | Gt | Eq | Neq
type op = Plus | Minus
type 'a lra_view =
@ -31,6 +33,7 @@ module type ARG = sig
module S : Sidekick_core.SOLVER
type term = S.T.Term.t
type ty = S.T.Ty.t
val view_as_lra : term -> term lra_view
(** Project the term into the theory view *)
@ -38,11 +41,17 @@ module type ARG = sig
val mk_lra : S.T.Term.state -> term lra_view -> term
(** Make a term from the given theory view *)
val ty_lra : S.T.Term.state -> ty
module Gensym : sig
type t
val create : S.T.Term.state -> t
val tst : t -> S.T.Term.state
val copy : t -> t
val fresh_term : t -> pre:string -> S.T.Ty.t -> term
(** Make a fresh term of the given type *)
end
@ -65,15 +74,32 @@ module Make(A : ARG) : S with module A = A = struct
module Lit = A.S.Solver_internal.Lit
module SI = A.S.Solver_internal
(* the fourier motzkin module *)
module FM_A = FM.Make(struct
module T = T
type tag = Lit.t
let pp_tag = Lit.pp
end)
module SimpVar
: Linear_expr.VAR_GEN
with type t = A.term
and type Fresh.t = A.Gensym.t
and type lit = Lit.t
= struct
type t = A.term
let pp = A.S.T.Term.pp
let compare = A.S.T.Term.compare
type lit = Lit.t
let pp_lit = Lit.pp
module Fresh = struct
type t = A.Gensym.t
let copy = A.Gensym.copy
let fresh (st:t) =
let ty = A.ty_lra (A.Gensym.tst st) in
A.Gensym.fresh_term ~pre:"_lra" st ty
end
end
module SimpSolver = Simplex.Make_full(SimpVar)
(* linear expressions *)
module LE = FM_A.LE
module LComb = SimpSolver.L.Comb
module LE = SimpSolver.L.Expr
module LConstr = SimpSolver.L.Constr
type state = {
tst: T.state;
@ -144,13 +170,13 @@ module Make(A : ARG) : S with module A = A = struct
mk_lit t
let pp_pred_def out (p,l1,l2) : unit =
Fmt.fprintf out "(@[%a@ :l1 %a@ :l2 %a@])" FM.Pred.pp p LE.pp l1 LE.pp l2
Fmt.fprintf out "(@[%a@ :l1 %a@ :l2 %a@])" Predicate.pp p LE.pp l1 LE.pp l2
(* turn the term into a linear expression. Apply [f] on leaves. *)
let rec as_linexp ~f (t:T.t) : LE.t =
let open LE.Infix in
match A.view_as_lra t with
| LRA_other _ -> LE.var (f t)
| LRA_other _ -> LE.monomial1 (f t)
| LRA_pred _ ->
Error.errorf "type error: in linexp, LRA predicate %a" T.pp t
| LRA_op (op, t1, t2) ->
@ -163,7 +189,7 @@ module Make(A : ARG) : S with module A = A = struct
| LRA_mult (n, x) ->
let t = as_linexp ~f x in
LE.( n * t )
| LRA_const q -> LE.const q
| LRA_const q -> LE.of_const q
(* TODO: keep the linexps until they're asserted;
TODO: but use simplification in preprocess
@ -219,15 +245,15 @@ module Make(A : ARG) : S with module A = A = struct
let final_check_ (self:state) si (acts:SI.actions) (trail:_ Iter.t) : unit =
Log.debug 5 "(th-lra.final-check)";
let fm = FM_A.create() in
let simplex = SimpSolver.create self.gensym in
(* first, add definitions *)
begin
List.iter
(fun (t,le) ->
let open LE.Infix in
let le = le - LE.var t in
let c = FM_A.Constr.mk ?tag:None Eq (LE.var t) le in
FM_A.assert_c fm c)
let le = le - LE.monomial1 t in
let c = LConstr.eq0 le in
SimpSolver.add_constr simplex c)
self.t_defs
end;
(* add trail *)
@ -240,28 +266,34 @@ module Make(A : ARG) : S with module A = A = struct
begin match T.Tbl.find self.pred_defs t with
| exception Not_found -> ()
| (pred, a, b) ->
let pred = if sign then pred else FM.Pred.neg pred in
(* FIXME: generic negation+printer in Linear_expr_intf;
actually move predicates to their own module *)
let pred = if sign then pred else Predicate.neg pred in
if pred = Neq then (
Log.debugf 50 (fun k->k "skip neq in %a" T.pp t);
) else (
let c = FM_A.Constr.mk ~tag:lit pred a b in
FM_A.assert_c fm c;
(* TODO: tag *)
let c = LConstr.of_expr LE.(a-b) pred in
SimpSolver.add_constr simplex c;
)
end)
end;
Log.debug 5 "lra: call arith solver";
begin match FM_A.solve fm with
| FM_A.Sat ->
begin match SimpSolver.solve simplex with
| SimpSolver.Solution _m ->
Log.debug 5 "lra: solver returns SAT";
() (* TODO: get a model + model combination *)
| FM_A.Unsat lits ->
| SimpSolver.Unsatisfiable _cert ->
(* we tagged assertions with their lit, so the certificate being an
unsat core translates directly into a conflict clause *)
assert false
(* TODO
Log.debugf 5 (fun k->k"lra: solver returns UNSAT@ with cert %a"
(Fmt.Dump.list Lit.pp) lits);
let confl = List.rev_map Lit.neg lits in
(* TODO: produce and store a proper LRA resolution proof *)
SI.raise_conflict si acts confl SI.P.default
*)
end;
()

688
src/arith/lra/simplex.ml Normal file
View file

@ -0,0 +1,688 @@
(*
copyright (c) 2014-2018, Guillaume Bury, Simon Cruanes
*)
(* OPTIMS:
* - distinguish separate systems (that do not interact), such as in { 1 <= 3x = 3y <= 2; z <= 3} ?
* - Implement gomorry cuts ?
*)
open Containers
module type VAR = Linear_expr_intf.VAR
module type FRESH = Linear_expr_intf.FRESH
module type VAR_GEN = Linear_expr_intf.VAR_GEN
module type S = Simplex_intf.S
module type S_FULL = Simplex_intf.S_FULL
module Vec = CCVector
module Matrix : sig
type 'a t
val create : unit -> 'a t
val get : 'a t -> int -> int -> 'a
val set : 'a t -> int -> int -> 'a -> unit
val get_row : 'a t -> int -> 'a Vec.vector
val copy : 'a t -> 'a t
val n_row : _ t -> int
val n_col : _ t -> int
val push_row : 'a t -> 'a -> unit (* new row, filled with element *)
val push_col : 'a t -> 'a -> unit (* new column, filled with element *)
(**/**)
val check_invariants : _ t -> bool
(**/**)
end = struct
type 'a t = {
mutable n_col: int; (* num of columns *)
tab: 'a Vec.vector Vec.vector;
}
let[@inline] create() : _ = {tab=Vec.create(); n_col=0}
let[@inline] get m i j = Vec.get (Vec.get m.tab i) j
let[@inline] get_row m i = Vec.get m.tab i
let[@inline] set (m:_ t) i j x = Vec.set (Vec.get m.tab i) j x
let[@inline] copy m = {m with tab=Vec.map Vec.copy m.tab}
let[@inline] n_row m = Vec.length m.tab
let[@inline] n_col m = m.n_col
let push_row m x = Vec.push m.tab (Vec.make (n_col m) x)
let push_col m x =
m.n_col <- m.n_col + 1;
Vec.iter (fun row -> Vec.push row x) m.tab
let check_invariants m = Vec.for_all (fun r -> Vec.length r = n_col m) m.tab
end
(* use non-polymorphic comparison ops *)
open Int.Infix
(* Simplex Implementation *)
module Make_inner
(Var: VAR)
(VMap : CCMap.S with type key=Var.t)
(Param: sig type t val copy : t -> t end)
= struct
module Var_map = VMap
module M = Var_map
(* Exceptions *)
exception Unsat of Var.t
exception AbsurdBounds of Var.t
exception NoneSuitable
type param = Param.t
type var = Var.t
type lit = Var.lit
type basic_var = var
type nbasic_var = var
type erat = {
base: Q.t; (* reference number *)
eps_factor: Q.t; (* coefficient for epsilon, the infinitesimal *)
}
(** Epsilon-rationals, used for strict bounds *)
module Erat = struct
type t = erat
let zero : t = {base=Q.zero; eps_factor=Q.zero}
let[@inline] make base eps_factor : t = {base; eps_factor}
let[@inline] base t = t.base
let[@inline] eps_factor t = t.eps_factor
let[@inline] mul k e = make Q.(k * e.base) Q.(k * e.eps_factor)
let[@inline] sum e1 e2 = make Q.(e1.base + e2.base) Q.(e1.eps_factor + e2.eps_factor)
let[@inline] compare e1 e2 = match Q.compare e1.base e2.base with
| 0 -> Q.compare e1.eps_factor e2.eps_factor
| x -> x
let lt a b = compare a b < 0
let gt a b = compare a b > 0
let[@inline] min x y = if compare x y <= 0 then x else y
let[@inline] max x y = if compare x y >= 0 then x else y
let[@inline] evaluate (epsilon:Q.t) (e:t) : Q.t = Q.(e.base + epsilon * e.eps_factor)
let pp out e =
if Q.equal Q.zero (eps_factor e)
then Q.pp_print out (base e)
else
Format.fprintf out "(@[<h>%a + @<1>ε * %a@])"
Q.pp_print (base e) Q.pp_print (eps_factor e)
end
let str_of_var = Format.to_string Var.pp
let str_of_erat = Format.to_string Erat.pp
let str_of_q = Format.to_string Q.pp_print
type t = {
param: param;
tab : Q.t Matrix.t; (* the matrix of coefficients *)
basic : basic_var Vec.vector; (* basic variables *)
nbasic : nbasic_var Vec.vector; (* non basic variables *)
mutable assign : Erat.t M.t; (* assignments *)
mutable bounds : (Erat.t * Erat.t) M.t; (* (lower, upper) bounds for variables *)
mutable idx_basic : int M.t; (* basic var -> its index in [basic] *)
mutable idx_nbasic : int M.t; (* non basic var -> its index in [nbasic] *)
}
type cert = {
cert_var: var;
cert_expr: (Q.t * var) list;
cert_core: lit list;
}
type res =
| Solution of Q.t Var_map.t
| Unsatisfiable of cert
let create param : t = {
param: param;
tab = Matrix.create ();
basic = Vec.create ();
nbasic = Vec.create ();
assign = M.empty;
bounds = M.empty;
idx_basic = M.empty;
idx_nbasic = M.empty;
}
let copy t = {
param = Param.copy t.param;
tab = Matrix.copy t.tab;
basic = Vec.copy t.basic;
nbasic = Vec.copy t.nbasic;
assign = t.assign;
bounds = t.bounds;
idx_nbasic = t.idx_nbasic;
idx_basic = t.idx_basic;
}
let index_basic (t:t) (x:basic_var) : int =
match M.find x t.idx_basic with
| n -> n
| exception Not_found -> -1
let index_nbasic (t:t) (x:nbasic_var) : int =
match M.find x t.idx_nbasic with
| n -> n
| exception Not_found -> -1
let[@inline] mem_basic (t:t) (x:var) : bool = M.mem x t.idx_basic
let[@inline] mem_nbasic (t:t) (x:var) : bool = M.mem x t.idx_nbasic
(* check invariants, for test purposes *)
let check_invariants (t:t) : bool =
Matrix.check_invariants t.tab &&
Vec.for_all (fun v -> mem_basic t v) t.basic &&
Vec.for_all (fun v -> mem_nbasic t v) t.nbasic &&
Vec.for_all (fun v -> not (mem_nbasic t v)) t.basic &&
Vec.for_all (fun v -> not (mem_basic t v)) t.nbasic &&
Vec.for_all (fun v -> Var_map.mem v t.assign) t.nbasic &&
Vec.for_all (fun v -> not (Var_map.mem v t.assign)) t.basic &&
true
(* find the definition of the basic variable [x],
as a linear combination of non basic variables *)
let find_expr_basic_opt t (x:var) : Q.t Vec.vector option =
begin match index_basic t x with
| -1 -> None
| i -> Some (Matrix.get_row t.tab i)
end
let find_expr_basic t (x:basic_var) : Q.t Vec.vector =
begin match find_expr_basic_opt t x with
| None -> assert false
| Some e -> e
end
(* build the expression [y = \sum_i (if x_i=y then 1 else 0)·x_i] *)
let find_expr_nbasic t (x:nbasic_var) : Q.t Vec.vector =
Vec.map
(fun y -> if Var.compare x y = 0 then Q.one else Q.zero)
t.nbasic
(* TODO: avoid double lookup in maps *)
(* find expression of [x] *)
let find_expr_total (t:t) (x:var) : Q.t Vec.vector =
if mem_basic t x then
find_expr_basic t x
else (
assert (mem_nbasic t x);
find_expr_nbasic t x
)
(* compute value of basic variable.
It can be computed by using [x]'s definition
in terms of nbasic variables, which have values *)
let value_basic (t:t) (x:basic_var) : Erat.t =
assert (mem_basic t x);
let res = ref Erat.zero in
let expr = find_expr_basic t x in
for i = 0 to Vec.length expr - 1 do
let val_nbasic_i =
try M.find (Vec.get t.nbasic i) t.assign
with Not_found -> assert false
in
res := Erat.sum !res (Erat.mul (Vec.get expr i) val_nbasic_i)
done;
!res
(* extract a value for [x] *)
let[@inline] value (t:t) (x:var) : Erat.t =
try M.find x t.assign (* nbasic variables are assigned *)
with Not_found -> value_basic t x
(* trivial bounds *)
let empty_bounds : Erat.t * Erat.t = Q.(Erat.make minus_inf zero, Erat.make inf zero)
(* find bounds of [x] *)
let[@inline] get_bounds (t:t) (x:var) : Erat.t * Erat.t =
try M.find x t.bounds
with Not_found -> empty_bounds
(* is [value x] within the bounds for [x]? *)
let is_within_bounds (t:t) (x:var) : bool * Erat.t =
let v = value t x in
let low, upp = get_bounds t x in
if Erat.compare v low < 0 then
false, low
else if Erat.compare v upp > 0 then
false, upp
else
true, v
(* add nbasic variables *)
let add_vars (t:t) (l:var list) : unit =
(* add new variable to idx and array for nbasic, removing duplicates
and variables already present *)
let idx_nbasic, _, l =
List.fold_left
(fun ((idx_nbasic, offset, l) as acc) x ->
if mem_basic t x then acc
else if M.mem x idx_nbasic then acc
else (
(* allocate new index for [x] *)
M.add x offset idx_nbasic, offset+1, x::l
))
(t.idx_nbasic, Vec.length t.nbasic, [])
l
in
(* add new columns to the matrix *)
let old_dim = Matrix.n_col t.tab in
List.iter (fun _ -> Matrix.push_col t.tab Q.zero) l;
assert (old_dim + List.length l = Matrix.n_col t.tab);
Vec.append_list t.nbasic (List.rev l);
(* assign these variables *)
t.assign <- List.fold_left (fun acc y -> M.add y Erat.zero acc) t.assign l;
t.idx_nbasic <- idx_nbasic;
()
(* define basic variable [x] by [eq] in [t] *)
let add_eq (t:t) (x, eq : basic_var * _ list) : unit =
if mem_basic t x || mem_nbasic t x then (
invalid_arg (Format.sprintf "Variable `%a` already defined." Var.pp x);
);
add_vars t (List.map snd eq);
(* add [x] as a basic var *)
t.idx_basic <- M.add x (Vec.length t.basic) t.idx_basic;
Vec.push t.basic x;
(* add new row for defining [x] *)
assert (Matrix.n_col t.tab > 0);
Matrix.push_row t.tab Q.zero;
let row_i = Matrix.n_row t.tab - 1 in
assert (row_i >= 0);
(* now put into the row the coefficients corresponding to [eq],
expanding basic variables to their definition *)
List.iter
(fun (c, x) ->
let expr = find_expr_total t x in
assert (Vec.length expr = Matrix.n_col t.tab);
Vec.iteri
(fun j c' ->
if not (Q.equal Q.zero c') then (
Matrix.set t.tab row_i j Q.(Matrix.get t.tab row_i j + c * c')
))
expr)
eq;
()
(* add bounds to [x] in [t] *)
let add_bound_aux (t:t) (x:var) (low:Erat.t) (upp:Erat.t) : unit =
add_vars t [x];
let l, u = get_bounds t x in
t.bounds <- M.add x (Erat.max l low, Erat.min u upp) t.bounds
let add_bounds (t:t) ?strict_lower:(slow=false) ?strict_upper:(supp=false) (x, l, u) : unit =
let e1 = if slow then Q.one else Q.zero in
let e2 = if supp then Q.neg Q.one else Q.zero in
add_bound_aux t x (Erat.make l e1) (Erat.make u e2);
if mem_nbasic t x then (
let b, v = is_within_bounds t x in
if not b then (
t.assign <- M.add x v t.assign;
)
)
let add_lower_bound t ?strict x l = add_bounds t ?strict_lower:strict (x,l,Q.inf)
let add_upper_bound t ?strict x u = add_bounds t ?strict_upper:strict (x,Q.minus_inf,u)
(* full assignment *)
let full_assign (t:t) : (var * Erat.t) Iter.t =
Iter.append (Vec.to_iter t.nbasic) (Vec.to_iter t.basic)
|> Iter.map (fun x -> x, value t x)
let[@inline] min x y = if Q.compare x y < 0 then x else y
(* Find an epsilon that is small enough for finding a solution, yet
it must be positive.
{!Erat.t} values are used to turn strict bounds ([X > 0]) into
non-strict bounds ([X >= 0 + ε]), because the simplex algorithm
only deals with non-strict bounds.
When a solution is found, we need to turn {!Erat.t} into {!Q.t} by
finding a rational value that is small enough that it will fit into
all the intervals of [t]. This rational will be the actual value of [ε].
*)
let solve_epsilon (t:t) : Q.t =
let emax =
M.fold
(fun x ({base=low;eps_factor=e_low}, {base=upp;eps_factor=e_upp}) emax ->
let {base=v; eps_factor=e_v} = value t x in
(* lower bound *)
let emax =
if Q.compare low Q.minus_inf > 0 && Q.compare e_v e_low < 0
then min emax Q.((low - v) / (e_v - e_low))
else emax
in
(* upper bound *)
if Q.compare upp Q.inf < 0 && Q.compare e_v e_upp > 0
then min emax Q.((upp - v) / (e_v - e_upp))
else emax)
t.bounds
Q.inf
in
if Q.compare emax Q.one >= 0 then Q.one else emax
let get_full_assign_seq (t:t) : _ Iter.t =
let e = solve_epsilon t in
let f = Erat.evaluate e in
full_assign t
|> Iter.map (fun (x,v) -> x, f v)
let get_full_assign t : Q.t Var_map.t = Var_map.of_iter (get_full_assign_seq t)
(* Find nbasic variable suitable for pivoting with [x].
A nbasic variable [y] is suitable if it "goes into the right direction"
(its coefficient in the definition of [x] is of the adequate sign)
and if it hasn't reached its bound in this direction.
precondition: [x] is a basic variable whose value in current assignment
is outside its bounds
We return the smallest (w.r.t Var.compare) suitable variable.
This is important for termination.
*)
let find_suitable_nbasic_for_pivot (t:t) (x:basic_var) : nbasic_var * Q.t =
assert (mem_basic t x);
let _, v = is_within_bounds t x in
let b = Erat.compare (value t x) v < 0 in
(* is nbasic var [y], with coeff [a] in definition of [x], suitable? *)
let test (y:nbasic_var) (a:Q.t) : bool =
assert (mem_nbasic t y);
let v = value t y in
let low, upp = get_bounds t y in
if b then (
(Erat.lt v upp && Q.compare a Q.zero > 0) ||
(Erat.gt v low && Q.compare a Q.zero < 0)
) else (
(Erat.gt v low && Q.compare a Q.zero > 0) ||
(Erat.lt v upp && Q.compare a Q.zero < 0)
)
in
let nbasic_vars = t.nbasic in
let expr = find_expr_basic t x in
(* find best suitable variable *)
let rec aux i =
if i = Vec.length nbasic_vars then (
assert (i = Vec.length expr);
None
) else (
let y = Vec.get nbasic_vars i in
let a = Vec.get expr i in
if test y a then (
(* see if other variables are better suited *)
begin match aux (i+1) with
| None -> Some (y,a)
| Some (z, _) as res_tail ->
if Var.compare y z <= 0
then Some (y,a)
else res_tail
end
) else (
aux (i+1)
)
)
in
begin match aux 0 with
| Some res -> res
| None -> raise NoneSuitable
end
(* pivot to exchange [x] and [y] *)
let pivot (t:t) (x:basic_var) (y:nbasic_var) (a:Q.t) : unit =
(* swap values ([x] becomes assigned) *)
let val_x = value t x in
t.assign <- t.assign |> M.remove y |> M.add x val_x;
(* Matrixrix Pivot operation *)
let kx = index_basic t x in
let ky = index_nbasic t y in
for j = 0 to Vec.length t.nbasic - 1 do
if Var.compare y (Vec.get t.nbasic j) = 0 then (
Matrix.set t.tab kx j Q.(one / a)
) else (
Matrix.set t.tab kx j Q.(neg (Matrix.get t.tab kx j) / a)
)
done;
for i = 0 to Vec.length t.basic - 1 do
if i <> kx then (
let c = Matrix.get t.tab i ky in
Matrix.set t.tab i ky Q.zero;
for j = 0 to Vec.length t.nbasic - 1 do
Matrix.set t.tab i j Q.(Matrix.get t.tab i j + c * Matrix.get t.tab kx j)
done
)
done;
(* Switch x and y in basic and nbasic vars *)
Vec.set t.basic kx y;
Vec.set t.nbasic ky x;
t.idx_basic <- t.idx_basic |> M.remove x |> M.add y kx;
t.idx_nbasic <- t.idx_nbasic |> M.remove y |> M.add x ky;
()
(* find minimum element of [arr] (wrt [cmp]) that satisfies predicate [f] *)
let find_min_filter ~cmp (f:'a -> bool) (arr:('a,_) Vec.t) : 'a option =
(* find the first element that satisfies [f] *)
let rec aux_find_first i =
if i = Vec.length arr then None
else (
let x = Vec.get arr i in
if f x
then aux_compare_with x (i+1)
else aux_find_first (i+1)
)
(* find if any element of [l] satisfies [f] and is smaller than [x] *)
and aux_compare_with x i =
if i = Vec.length arr then Some x
else (
let y = Vec.get arr i in
let best = if f y && cmp y x < 0 then y else x in
aux_compare_with best (i+1)
)
in
aux_find_first 0
(* check bounds *)
let check_bounds (t:t) : unit =
M.iter (fun x (l, u) -> if Erat.gt l u then raise (AbsurdBounds x)) t.bounds
(* actual solving algorithm *)
let solve_aux (t:t) : unit =
check_bounds t;
(* select the smallest basic variable that is not satisfied in the current
assignment. *)
let rec aux_select_basic_var () =
match
find_min_filter ~cmp:Var.compare
(fun x -> not (fst (is_within_bounds t x)))
t.basic
with
| Some x -> aux_pivot_on_basic x
| None -> ()
(* remove the basic variable *)
and aux_pivot_on_basic x =
let _b, v = is_within_bounds t x in
assert (not _b);
match find_suitable_nbasic_for_pivot t x with
| y, a ->
(* exchange [x] and [y] by pivoting *)
pivot t x y a;
(* assign [x], now a nbasic variable, to the faulty bound [v] *)
t.assign <- M.add x v t.assign;
(* next iteration *)
aux_select_basic_var ()
| exception NoneSuitable ->
raise (Unsat x)
in
aux_select_basic_var ();
()
(* main method for the user to call *)
let solve (t:t) : res =
try
solve_aux t;
Solution (get_full_assign t)
with
| Unsat x ->
let cert_expr =
List.combine
(Vec.to_list (find_expr_basic t x))
(Vec.to_list t.nbasic)
in
Unsatisfiable { cert_var=x; cert_expr; cert_core=[]; } (* FIXME *)
| AbsurdBounds x ->
Unsatisfiable { cert_var=x; cert_expr=[]; cert_core=[]; }
(* add [c·x] to [m] *)
let add_expr_ (x:var) (c:Q.t) (m:Q.t M.t) =
let c' = M.get_or ~default:Q.zero x m in
let c' = Q.(c + c') in
if Q.equal Q.zero c' then M.remove x m else M.add x c' m
(* dereference basic variables from [c·x], and add the result to [m] *)
let rec deref_var_ t x c m = match find_expr_basic_opt t x with
| None -> add_expr_ x c m
| Some expr_x ->
let m = ref m in
Vec.iteri
(fun i c_i ->
let y_i = Vec.get t.nbasic i in
m := deref_var_ t y_i Q.(c * c_i) !m)
expr_x;
!m
(* maybe invert bounds, if [c < 0] *)
let scale_bounds c (l,u) : erat * erat =
match Q.compare c Q.zero with
| 0 -> Erat.zero, Erat.zero
| n when n<0 -> Erat.mul c u, Erat.mul c l
| _ -> Erat.mul c l, Erat.mul c u
let check_cert (t:t) (c:cert) =
let x = c.cert_var in
let low_x, up_x = get_bounds t x in
begin match c.cert_expr with
| [] ->
if Erat.compare low_x up_x > 0 then `Ok
else `Bad_bounds (str_of_erat low_x, str_of_erat up_x)
| expr ->
let e0 = deref_var_ t x (Q.neg Q.one) M.empty in
(* compute bounds for the expression [c.cert_expr],
and also compute [c.cert_expr - x] to check if it's 0] *)
let low, up, expr_minus_x =
List.fold_left
(fun (l,u,expr_minus_x) (c, y) ->
let ly, uy = scale_bounds c (get_bounds t y) in
assert (Erat.compare ly uy <= 0);
let expr_minus_x = deref_var_ t y c expr_minus_x in
Erat.sum l ly, Erat.sum u uy, expr_minus_x)
(Erat.zero, Erat.zero, e0)
expr
in
(* check that the expanded expression is [x], and that
one of the bounds on [x] is incompatible with bounds of [c.cert_expr] *)
if M.is_empty expr_minus_x then (
if Erat.compare low_x up > 0 || Erat.compare up_x low < 0
then `Ok
else `Bad_bounds (str_of_erat low, str_of_erat up)
) else `Diff_not_0 expr_minus_x
end
(* printer *)
let matrix_pp_width = ref 8
let fmt_head = format_of_string "|%*s|| "
let fmt_cell = format_of_string "%*s| "
let pp_cert out (c:cert) = match c.cert_expr with
| [] -> Format.fprintf out "(@[inconsistent-bounds %a@])" Var.pp c.cert_var
| _ ->
let pp_pair = Format.(hvbox ~i:2 @@ pair ~sep:(return "@ * ") Q.pp_print Var.pp) in
Format.fprintf out "(@[<hv>cert@ :var %a@ :linexp %a@])"
Var.pp c.cert_var
Format.(within "[" "]" @@ hvbox @@ list ~sep:(return "@ + ") pp_pair)
c.cert_expr
let pp_mat out t =
let open Format in
fprintf out "@[<v>";
(* header *)
fprintf out fmt_head !matrix_pp_width "";
Vec.iter (fun x -> fprintf out fmt_cell !matrix_pp_width (str_of_var x)) t.nbasic;
fprintf out "@,";
(* rows *)
for i=0 to Matrix.n_row t.tab-1 do
if i>0 then fprintf out "@,";
let v = Vec.get t.basic i in
fprintf out fmt_head !matrix_pp_width (str_of_var v);
let row = Matrix.get_row t.tab i in
Vec.iter (fun q -> fprintf out fmt_cell !matrix_pp_width (str_of_q q)) row;
done;
fprintf out "@]"
let pp_assign =
let open Format in
let pp_pair =
within "(" ")" @@ hvbox @@ pair ~sep:(return "@ := ") Var.pp Erat.pp
in
map Var_map.to_seq @@ within "(" ")" @@ hvbox @@ seq pp_pair
let pp_bounds =
let open Format in
let pp_pairs out (x,(l,u)) =
fprintf out "(@[%a =< %a =< %a@])" Erat.pp l Var.pp x Erat.pp u
in
map Var_map.to_seq @@ within "(" ")" @@ hvbox @@ seq pp_pairs
let pp_full_state out (t:t) : unit =
(* print main matrix *)
Format.fprintf out
"(@[<hv>simplex@ :n-row %d :n-col %d@ :mat %a@ :assign %a@ :bounds %a@])"
(Matrix.n_row t.tab) (Matrix.n_col t.tab) pp_mat t pp_assign t.assign
pp_bounds t.bounds
end
module Make(Var:VAR) =
Make_inner(Var)(CCMap.Make(Var))(struct
type t = unit
let copy ()=()
end)
module Make_full_for_expr(V : VAR_GEN)
(L : Linear_expr.S
with type Var.t = V.t
and type C.t = Q.t
and type Var.lit = V.lit)
= struct
include Make_inner(V)(L.Var_map)(V.Fresh)
module L = L
type op = Predicate.t = Leq | Geq | Lt | Gt | Eq | Neq
type constr = L.Constr.t
(* add a constraint *)
let add_constr (t:t) (c:constr) : unit =
let (x:var) = V.Fresh.fresh t.param in
let e, op, q = L.Constr.split c in
add_eq t (x, L.Comb.to_list e);
begin match op with
| Leq -> add_upper_bound t ~strict:false x q
| Geq -> add_lower_bound t ~strict:false x q
| Lt -> add_upper_bound t ~strict:true x q
| Gt -> add_lower_bound t ~strict:true x q
| Eq -> add_bounds t ~strict_lower:false ~strict_upper:false (x,q,q)
| Neq -> assert false
end
end
module Make_full(V : VAR_GEN)
= Make_full_for_expr(V)(Linear_expr.Make(struct include Q let pp = pp_print end)(V))

31
src/arith/lra/simplex.mli Normal file
View file

@ -0,0 +1,31 @@
(** Solving Linear systems of rational equations. *)
module type VAR = Linear_expr_intf.VAR
module type FRESH = Linear_expr_intf.FRESH
module type VAR_GEN = Linear_expr_intf.VAR_GEN
module type S = Simplex_intf.S
module type S_FULL = Simplex_intf.S_FULL
(** Low level simplex interface *)
module Make(V : VAR) :
S with type var = V.t
and type lit = V.lit
and type param = unit
and module Var_map = CCMap.Make(V)
(** High-level simplex interface *)
module Make_full_for_expr(V : VAR_GEN)
(L : Linear_expr.S with type Var.t = V.t and type Var.lit = V.lit and type C.t = Q.t)
: S_FULL with type var = V.t
and type lit = V.lit
and module L = L
and module Var_map = L.Var_map
and type param = V.Fresh.t
module Make_full(V : VAR_GEN)
: S_FULL with type var = V.t
and type lit = V.lit
and type L.var = V.t
and type param = V.Fresh.t

View file

@ -0,0 +1,124 @@
(*
copyright (c) 2014-2018, Guillaume Bury, Simon Cruanes
*)
(** {1 Modular and incremental implementation of the general simplex}. *)
(** The simplex is used as a decision procedure for linear rational arithmetic
problems.
More information can be found on the particular flavor of this
implementation at https://gbury.eu/public/papers/stage-m2.pdf
*)
module type S = sig
(** The given type of the variables *)
type var
(** A map on variables *)
module Var_map : CCMap.S with type key = var
(** Parameter required at the creation of the simplex *)
type param
type lit
(** The type of a (possibly not solved) linear system *)
type t
(** An unsatisfiability explanation is a couple [(x, expr)]. If [expr] is the
empty list, then there is a contradiction between two given bounds of [x].
Else, the explanation is an equality [x = expr] that is valid
(it can be derived from the original equations of the system) from which a
bound can be deduced which contradicts an already given bound of the
system. *)
type cert = {
cert_var: var;
cert_expr: (Q.t * var) list;
cert_core: lit list;
}
(** Generic type returned when solving the simplex. A solution is a list of
bindings that satisfies all the constraints inside the system. If the
system is unsatisfiable, an explanation of type ['cert] is returned. *)
type res =
| Solution of Q.t Var_map.t
| Unsatisfiable of cert
(** {3 Simplex construction} *)
(** The empty system.
@param fresh the state for generating fresh variables on demand. *)
val create : param -> t
(** Returns a copy of the given system *)
val copy : t -> t
(** [add_eq s (x, eq)] adds the equation [x=eq] to [s] *)
val add_eq : t -> var * (Q.t * var) list -> unit
(** [add_bounds (x, lower, upper)] adds to [s]
the bounds [lower] and [upper] for the given variable [x].
If the bound is loose on one side
(no upper bounds for instance), the values [Q.inf] and
[Q.minus_inf] can be used. By default, in a system, all variables
have no bounds, i.e have lower bound [Q.minus_inf] and upper bound
[Q.inf].
Optional parameters allow to make the the bounds strict. Defaults to false,
so that bounds are large by default. *)
val add_bounds : t -> ?strict_lower:bool -> ?strict_upper:bool -> var * Q.t * Q.t -> unit
val add_lower_bound : t -> ?strict:bool -> var -> Q.t -> unit
val add_upper_bound : t -> ?strict:bool -> var -> Q.t -> unit
(** {3 Simplex solving} *)
(** [solve s] solves the system [s] and returns a solution, if one exists.
This function may change the internal representation of the system to
that of an equivalent one
(permutation of basic and non basic variables and pivot operation
on the tableaux).
*)
val solve : t -> res
val check_cert :
t ->
cert ->
[`Ok | `Bad_bounds of string * string | `Diff_not_0 of Q.t Var_map.t]
(** checks that the certificat indeed yields to a contradiction
in the current state of the simplex.
@return [`Ok] if the certificate is valid. *)
(* TODO: push/pop? at least on bounds *)
val pp_cert : cert CCFormat.printer
val pp_full_state : t CCFormat.printer
(**/**)
val check_invariants : t -> bool (* check that all invariants hold *)
val matrix_pp_width : int ref (* horizontal filling when we print the matrix *)
(**/**)
end
(* TODO: benchmark
- copy current implem;
- move random generator somewhere shared;
- compare cur & old implem;
- optimize (remove find_expr?))
*)
module type S_FULL = sig
include S
module L : Linear_expr_intf.S
with type C.t = Q.t and type Var.t = var and type Var.lit = lit
type op = Predicate.t = Leq | Geq | Lt | Gt | Eq | Neq
type constr = L.Constr.t
val add_constr : t -> constr -> unit
(** Add a constraint to a simplex state. *)
end

View file

@ -306,6 +306,7 @@ module Th_lra = Sidekick_arith_lra.Make(struct
module S = Solver
module T = BT.Term
type term = S.T.Term.t
type ty = S.T.Ty.t
let mk_lra = T.lra
let view_as_lra t = match T.view t with
@ -313,6 +314,8 @@ module Th_lra = Sidekick_arith_lra.Make(struct
| T.Eq (a,b) when Ty.equal (T.ty a) Ty.real -> LRA_pred (Eq, a, b)
| _ -> LRA_other t
let ty_lra _st = Ty.real
module Gensym = struct
type t = {
tst: T.state;
@ -320,6 +323,8 @@ module Th_lra = Sidekick_arith_lra.Make(struct
}
let create tst : t = {tst; fresh=0}
let tst self = self.tst
let copy s = {s with tst=s.tst}
let fresh_term (self:t) ~pre (ty:Ty.t) : T.t =
let name = Printf.sprintf "_sk_lra_%s%d" pre self.fresh in