mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-06 11:15:43 -05:00
wip(LRA): new simplex implementation
This commit is contained in:
parent
2a0e14a635
commit
dd4719b1a4
1 changed files with 624 additions and 9 deletions
|
|
@ -25,6 +25,7 @@ end
|
||||||
|
|
||||||
module type S = sig
|
module type S = sig
|
||||||
module V : VAR
|
module V : VAR
|
||||||
|
module V_map : CCMap.S with type key = V.t
|
||||||
|
|
||||||
type num = Q.t (** Numbers *)
|
type num = Q.t (** Numbers *)
|
||||||
|
|
||||||
|
|
@ -38,24 +39,72 @@ module type S = sig
|
||||||
rhs: num;
|
rhs: num;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
val leq : V.t -> num -> t
|
||||||
|
val lt : V.t -> num -> t
|
||||||
|
val geq : V.t -> num -> t
|
||||||
|
val gt : V.t -> num -> t
|
||||||
|
|
||||||
val pp : t Fmt.printer
|
val pp : t Fmt.printer
|
||||||
end
|
end
|
||||||
|
|
||||||
|
module Subst : sig
|
||||||
|
type t = num V_map.t
|
||||||
|
val pp : t Fmt.printer
|
||||||
|
val to_string : t -> string
|
||||||
|
end
|
||||||
|
|
||||||
type t
|
type t
|
||||||
|
|
||||||
val create : unit -> t
|
val create : unit -> t
|
||||||
(** Create a new simplex. *)
|
(** Create a new simplex. *)
|
||||||
|
|
||||||
|
val push_level : t -> unit
|
||||||
|
|
||||||
|
val pop_levels : t -> int -> unit
|
||||||
|
|
||||||
val define : t -> V.t -> (num * V.t) list -> unit
|
val define : t -> V.t -> (num * V.t) list -> unit
|
||||||
(** Define a basic variable in terms of other variables.
|
(** Define a basic variable in terms of other variables.
|
||||||
This is useful to "name" a linear expression and get back a variable
|
This is useful to "name" a linear expression and get back a variable
|
||||||
that can be used in a {!Constraint.t} *)
|
that can be used in a {!Constraint.t} *)
|
||||||
|
|
||||||
|
type unsat_cert = {
|
||||||
|
cert_bounds: (num * V.lit) list;
|
||||||
|
cert_defs: (V.t * (num * V.t) list) list; (* definitions used *)
|
||||||
|
}
|
||||||
|
|
||||||
|
module Unsat_cert : sig
|
||||||
|
type t = unsat_cert
|
||||||
|
val pp : t Fmt.printer
|
||||||
|
end
|
||||||
|
|
||||||
|
exception E_unsat of Unsat_cert.t
|
||||||
|
|
||||||
|
val add_constraint : t -> Constraint.t -> V.lit -> unit
|
||||||
|
(** Add a constraint to the simplex.
|
||||||
|
@raise Unsat if it's immediately obvious that this is not satisfiable. *)
|
||||||
|
|
||||||
|
val check_exn : t -> unit
|
||||||
|
(** Check the whole simplex for satisfiability.
|
||||||
|
@raise Unsat if the constraints are not satisfiable. *)
|
||||||
|
|
||||||
|
type result =
|
||||||
|
| Sat of Subst.t
|
||||||
|
| Unsat of Unsat_cert.t
|
||||||
|
|
||||||
|
val check : t -> result
|
||||||
|
(** Call {!check_exn} and return a model or a proof of unsat. *)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
(* TODO(optim): page 14, paragraph 2: we could detect which variables occur in no
|
||||||
|
atom after preprocessing; then these variables can be "inlined" (removed
|
||||||
|
by Gaussian elimination) as a preprocessing step, and this removes one column
|
||||||
|
and maybe one row if it was basic. *)
|
||||||
|
|
||||||
module Make(Var: VAR)
|
module Make(Var: VAR)
|
||||||
: S with module V = Var
|
: S with module V = Var
|
||||||
= struct
|
= struct
|
||||||
module V = Var
|
module V = Var
|
||||||
|
module V_map = CCMap.Make(Var)
|
||||||
|
|
||||||
type num = Q.t (** Numbers *)
|
type num = Q.t (** Numbers *)
|
||||||
|
|
||||||
|
|
@ -72,6 +121,20 @@ module Make(Var: VAR)
|
||||||
let pp out (self:t) =
|
let pp out (self:t) =
|
||||||
Fmt.fprintf out "(@[%a %s@ %a@])" V.pp self.lhs
|
Fmt.fprintf out "(@[%a %s@ %a@])" V.pp self.lhs
|
||||||
(Op.to_string self.op) Q.pp_print self.rhs
|
(Op.to_string self.op) Q.pp_print self.rhs
|
||||||
|
|
||||||
|
let leq lhs rhs = {lhs;rhs;op=Op.Leq}
|
||||||
|
let lt lhs rhs = {lhs;rhs;op=Op.Lt}
|
||||||
|
let geq lhs rhs = {lhs;rhs;op=Op.Geq}
|
||||||
|
let gt lhs rhs = {lhs;rhs;op=Op.Gt}
|
||||||
|
end
|
||||||
|
|
||||||
|
module Subst = struct
|
||||||
|
type t = num V_map.t
|
||||||
|
let pp out (self:t) : unit =
|
||||||
|
let pp_pair out (v,n) =
|
||||||
|
Fmt.fprintf out "(@[%a := %a@])" V.pp v Q.pp_print n in
|
||||||
|
Fmt.fprintf out "{@[%a@]}" (Fmt.iter pp_pair) (V_map.to_iter self)
|
||||||
|
let to_string = Fmt.to_string pp
|
||||||
end
|
end
|
||||||
|
|
||||||
(** An extended rational, used to introduce ε so we can use strict
|
(** An extended rational, used to introduce ε so we can use strict
|
||||||
|
|
@ -91,19 +154,24 @@ module Make(Var: VAR)
|
||||||
let zero : t = {base=Q.zero; eps_factor=Q.zero}
|
let zero : t = {base=Q.zero; eps_factor=Q.zero}
|
||||||
|
|
||||||
let[@inline] make base eps_factor : t = {base; eps_factor}
|
let[@inline] make base eps_factor : t = {base; eps_factor}
|
||||||
|
let[@inline] make_q x = make x Q.zero
|
||||||
let[@inline] base t = t.base
|
let[@inline] base t = t.base
|
||||||
let[@inline] eps_factor t = t.eps_factor
|
let[@inline] eps_factor t = t.eps_factor
|
||||||
let[@inline] mul k e = make Q.(k * e.base) Q.(k * e.eps_factor)
|
let[@inline] ( * ) k e = make Q.(k * e.base) Q.(k * e.eps_factor)
|
||||||
let[@inline] sum e1 e2 = make Q.(e1.base + e2.base) Q.(e1.eps_factor + e2.eps_factor)
|
let[@inline] ( / ) e k = make Q.(e.base / k) Q.(e.eps_factor / k)
|
||||||
|
let[@inline] (+) e1 e2 = make Q.(e1.base + e2.base) Q.(e1.eps_factor + e2.eps_factor)
|
||||||
|
let[@inline] (-) e1 e2 = make Q.(e1.base - e2.base) Q.(e1.eps_factor - e2.eps_factor)
|
||||||
let[@inline] compare e1 e2 = match Q.compare e1.base e2.base with
|
let[@inline] compare e1 e2 = match Q.compare e1.base e2.base with
|
||||||
| 0 -> Q.compare e1.eps_factor e2.eps_factor
|
| 0 -> Q.compare e1.eps_factor e2.eps_factor
|
||||||
| x -> x
|
| x -> x
|
||||||
|
|
||||||
let[@inline] lt a b = compare a b < 0
|
let[@inline] (<) a b = compare a b < 0
|
||||||
let[@inline] gt a b = compare a b > 0
|
let[@inline] (<=) a b = compare a b <= 0
|
||||||
|
let[@inline] (>) a b = compare a b > 0
|
||||||
|
let[@inline] (>=) a b = compare a b >= 0
|
||||||
|
|
||||||
let[@inline] min x y = if compare x y <= 0 then x else y
|
let[@inline] min x y = if x <= y then x else y
|
||||||
let[@inline] max x y = if compare x y >= 0 then x else y
|
let[@inline] max x y = if x >= y then x else y
|
||||||
|
|
||||||
let[@inline] evaluate (epsilon:Q.t) (e:t) : Q.t = Q.(e.base + epsilon * e.eps_factor)
|
let[@inline] evaluate (epsilon:Q.t) (e:t) : Q.t = Q.(e.base + epsilon * e.eps_factor)
|
||||||
|
|
||||||
|
|
@ -115,9 +183,556 @@ module Make(Var: VAR)
|
||||||
Q.pp_print (base e) Q.pp_print (eps_factor e)
|
Q.pp_print (base e) Q.pp_print (eps_factor e)
|
||||||
end
|
end
|
||||||
|
|
||||||
type t = unit
|
type var_idx = int
|
||||||
|
|
||||||
let create () : t = ()
|
(** {2 Matrix}
|
||||||
|
The matrix [A] from the paper, with m rows and n columns.
|
||||||
|
- m is the number of basic variables (defined in terms of non-basic variables)
|
||||||
|
- n is the total number of variables, basic and non-basic.
|
||||||
|
|
||||||
let define _ = assert false (* TODO *)
|
The invariant that the simplex maintains at all times is that [Ax = 0],
|
||||||
|
where [x] is the vector of values of all variables (basic and non-basic).
|
||||||
|
*)
|
||||||
|
|
||||||
|
module Matrix : sig
|
||||||
|
type t
|
||||||
|
|
||||||
|
val create : unit -> t
|
||||||
|
|
||||||
|
val pp : t Fmt.printer
|
||||||
|
val to_string : t -> string
|
||||||
|
|
||||||
|
val n_rows : t -> int
|
||||||
|
val n_cols : t -> int
|
||||||
|
|
||||||
|
val add_column : t -> unit
|
||||||
|
(** Add a non-basic variable (only adds a column) *)
|
||||||
|
|
||||||
|
val add_row_and_column : t -> int
|
||||||
|
(** Add a basic variable. returns the row index. *)
|
||||||
|
|
||||||
|
val get_row_var_idx : t -> int -> var_idx
|
||||||
|
(** Index of the basic variable for row [i] *)
|
||||||
|
|
||||||
|
val set_row_var_idx : t -> int -> var_idx -> unit
|
||||||
|
(** Set index of the basic variable for row [i] *)
|
||||||
|
|
||||||
|
val get : t -> int -> int -> num
|
||||||
|
|
||||||
|
val set : t -> int -> int -> num -> unit
|
||||||
|
|
||||||
|
val add : t -> int -> int -> num -> unit
|
||||||
|
|
||||||
|
val mult : t -> int -> int -> num -> unit
|
||||||
|
end = struct
|
||||||
|
type row = {
|
||||||
|
mutable var_idx: var_idx;
|
||||||
|
cols: num Vec.t;
|
||||||
|
}
|
||||||
|
type t = {
|
||||||
|
rows: row Vec.t
|
||||||
|
}
|
||||||
|
|
||||||
|
let create() : t = {rows=Vec.create()}
|
||||||
|
|
||||||
|
let[@inline] n_rows self = Vec.size self.rows
|
||||||
|
let n_cols self =
|
||||||
|
if n_rows self=0 then 0
|
||||||
|
else Vec.size (Vec.get self.rows 0).cols
|
||||||
|
|
||||||
|
let pp out self =
|
||||||
|
Fmt.fprintf out "{@[<v>";
|
||||||
|
Vec.iteri (fun i row ->
|
||||||
|
Fmt.fprintf out "@[<hov2>%-5d: %a@]@," i
|
||||||
|
(Fmt.iter ~sep:(Fmt.return "@ ") Q.pp_print) (Vec.to_seq row.cols))
|
||||||
|
self.rows;
|
||||||
|
Fmt.fprintf out "@]}"
|
||||||
|
let to_string = Fmt.to_string pp
|
||||||
|
|
||||||
|
let add_column self =
|
||||||
|
Vec.iter (fun r -> Vec.push r.cols Q.zero) self.rows
|
||||||
|
|
||||||
|
let add_row_and_column self : int =
|
||||||
|
let n = n_rows self in
|
||||||
|
let j = n_cols self in
|
||||||
|
add_column self;
|
||||||
|
let row = {var_idx=j; cols=Vec.make (j+1) Q.zero} in
|
||||||
|
Vec.push self.rows row;
|
||||||
|
n
|
||||||
|
|
||||||
|
let[@inline] get_row_var_idx self i = (Vec.get self.rows i).var_idx
|
||||||
|
let[@inline] set_row_var_idx self i n = (Vec.get self.rows i).var_idx <- n
|
||||||
|
|
||||||
|
let[@inline] get self i j : num = Vec.get (Vec.get self.rows i).cols j
|
||||||
|
|
||||||
|
let[@inline] set self i j n : unit =
|
||||||
|
let r = Vec.get self.rows i in
|
||||||
|
Vec.set r.cols j n
|
||||||
|
|
||||||
|
(* add [n] to [m_ij] *)
|
||||||
|
let add self i j n : unit =
|
||||||
|
let r = Vec.get self.rows i in
|
||||||
|
Vec.set r.cols j Q.(Vec.get r.cols j + n)
|
||||||
|
|
||||||
|
(* multiply [m_ij] by [c] *)
|
||||||
|
let mult self i j c : unit =
|
||||||
|
let r = Vec.get self.rows i in
|
||||||
|
let n_j = Vec.get r.cols j in
|
||||||
|
if Q.sign n_j <> 0 then (
|
||||||
|
Vec.set r.cols j Q.(n_j * c)
|
||||||
|
)
|
||||||
|
end
|
||||||
|
|
||||||
|
type bound = {
|
||||||
|
b_val: erat;
|
||||||
|
b_lit: Var.lit;
|
||||||
|
}
|
||||||
|
|
||||||
|
type var_state = {
|
||||||
|
var: V.t;
|
||||||
|
mutable value: erat;
|
||||||
|
idx: int; (* index in {!t.vars} *)
|
||||||
|
mutable basic_idx: int; (* index of the row in the matrix, if any. -1 otherwise *)
|
||||||
|
mutable l_bound: bound option;
|
||||||
|
mutable u_bound: bound option;
|
||||||
|
}
|
||||||
|
|
||||||
|
module Var_state = struct
|
||||||
|
type t = var_state
|
||||||
|
|
||||||
|
let[@inline] is_basic (self:t) : bool = self.basic_idx >= 0
|
||||||
|
let[@inline] is_n_basic (self:t) : bool = self.basic_idx < 0
|
||||||
|
let pp out self =
|
||||||
|
Fmt.fprintf out "(@[var %a@ :basic %B@ :value %a@ :lbound %a@ :ubound %a@])"
|
||||||
|
Var.pp self.var (is_basic self) Erat.pp self.value
|
||||||
|
Fmt.(Dump.option (map (fun b->b.b_val) Erat.pp)) self.l_bound
|
||||||
|
Fmt.(Dump.option (map (fun b->b.b_val) Erat.pp)) self.u_bound
|
||||||
|
end
|
||||||
|
|
||||||
|
type t = {
|
||||||
|
matrix: Matrix.t;
|
||||||
|
vars: var_state Vec.t; (* index -> var with this index *)
|
||||||
|
mutable var_tbl: var_state V_map.t; (* var -> its state *)
|
||||||
|
bound_stack: (var_state * [`Upper|`Lower] * bound option) Backtrack_stack.t;
|
||||||
|
}
|
||||||
|
|
||||||
|
let push_level self : unit =
|
||||||
|
Backtrack_stack.push_level self.bound_stack
|
||||||
|
|
||||||
|
let pop_levels self n : unit =
|
||||||
|
Backtrack_stack.pop_levels self.bound_stack n
|
||||||
|
~f:(fun (var, kind, bnd) ->
|
||||||
|
match kind with
|
||||||
|
| `Upper -> var.u_bound <- bnd
|
||||||
|
| `Lower -> var.l_bound <- bnd);
|
||||||
|
()
|
||||||
|
|
||||||
|
let pp out (self:t) : unit =
|
||||||
|
Fmt.fprintf out "(@[simplex@ :vars %a@ :matrix %a@])"
|
||||||
|
(Vec.pp Var_state.pp) self.vars
|
||||||
|
Matrix.pp self.matrix
|
||||||
|
|
||||||
|
let[@inline] has_var_ (self:t) x : bool = V_map.mem x self.var_tbl
|
||||||
|
let[@inline] find_var_ (self:t) x : var_state =
|
||||||
|
try V_map.find x self.var_tbl
|
||||||
|
with Not_found -> Error.errorf "variable is not in the simplex" Var.pp x
|
||||||
|
|
||||||
|
let define (self:t) (x:V.t) (le:_ list) : unit =
|
||||||
|
assert (not (has_var_ self x));
|
||||||
|
Log.debugf 5 (fun k->k "(@[simplex.define@ %a@ :le %a@])"
|
||||||
|
Var.pp x Fmt.(Dump.(list @@ pair Q.pp_print Var.pp)) le);
|
||||||
|
let n = Matrix.add_row_and_column self.matrix in
|
||||||
|
let vs = {
|
||||||
|
var=x; value=Erat.zero;
|
||||||
|
idx=Vec.size self.vars;
|
||||||
|
basic_idx=n;
|
||||||
|
l_bound=None;
|
||||||
|
u_bound=None;
|
||||||
|
} in
|
||||||
|
Vec.push self.vars vs;
|
||||||
|
self.var_tbl <- V_map.add x vs self.var_tbl;
|
||||||
|
(* set coefficients in the matrix's new row: [-x + le = 0] *)
|
||||||
|
Matrix.set self.matrix n vs.idx Q.minus_one;
|
||||||
|
List.iter
|
||||||
|
(fun (coeff,v2) ->
|
||||||
|
let vs2 = find_var_ self v2 in
|
||||||
|
Matrix.add self.matrix n vs2.idx coeff;
|
||||||
|
) le;
|
||||||
|
()
|
||||||
|
|
||||||
|
(* find the state for [x], or add [x] as a non-basic variable *)
|
||||||
|
let find_or_create_n_basic_var_ (self:t) (x:V.t) : var_state =
|
||||||
|
try V_map.find x self.var_tbl
|
||||||
|
with Not_found ->
|
||||||
|
Matrix.add_column self.matrix;
|
||||||
|
let vs = {
|
||||||
|
idx=Vec.size self.vars;
|
||||||
|
basic_idx= -1;
|
||||||
|
var=x;
|
||||||
|
l_bound=None;
|
||||||
|
u_bound=None;
|
||||||
|
value=Erat.zero;
|
||||||
|
} in
|
||||||
|
assert (Var_state.is_n_basic vs);
|
||||||
|
self.var_tbl <- V_map.add x vs self.var_tbl;
|
||||||
|
Vec.push self.vars vs;
|
||||||
|
vs
|
||||||
|
|
||||||
|
(* update the simplex so that non-basic [x] is now assigned value [n].
|
||||||
|
See page 14, figure 3.1.
|
||||||
|
*)
|
||||||
|
let update_n_basic (self:t) (x:var_state) (v:erat) : unit =
|
||||||
|
assert (Var_state.is_n_basic x);
|
||||||
|
let m = self.matrix in
|
||||||
|
let i = x.idx in
|
||||||
|
|
||||||
|
let diff = Erat.(v - x.value) in
|
||||||
|
|
||||||
|
for j=0 to Matrix.n_rows m - 1 do
|
||||||
|
let a_ji = Matrix.get m j i in
|
||||||
|
let x_j = Vec.get self.vars (Matrix.get_row_var_idx m j) in
|
||||||
|
assert (Var_state.is_basic x_j);
|
||||||
|
(* value of [x_j] by [a_ji * diff] *)
|
||||||
|
x_j.value <- Erat.(x_j.value + a_ji * diff);
|
||||||
|
done;
|
||||||
|
x.value <- v;
|
||||||
|
()
|
||||||
|
|
||||||
|
(* pivot [x_i] (basic) and [x_j] (non-basic), setting value of [x_i]
|
||||||
|
to [v] at the same time.
|
||||||
|
See page 14, figure 3.1 *)
|
||||||
|
let pivot_and_update (self:t) (x_i:var_state) (x_j:var_state) (v:erat) : unit =
|
||||||
|
assert (Var_state.is_basic x_i);
|
||||||
|
assert (Var_state.is_n_basic x_j);
|
||||||
|
let m = self.matrix in
|
||||||
|
let a_ij = Matrix.get m x_i.basic_idx x_j.idx in
|
||||||
|
assert (Q.sign a_ij <> 0);
|
||||||
|
let theta = Erat.((v - x_i.value) / a_ij) in
|
||||||
|
x_i.value <- v;
|
||||||
|
x_j.value <- Erat.(x_j.value + theta);
|
||||||
|
|
||||||
|
for k=0 to Matrix.n_rows m-1 do
|
||||||
|
if k <> x_i.basic_idx then (
|
||||||
|
let x_k = Vec.get self.vars (Matrix.get_row_var_idx m k) in
|
||||||
|
let a_kj = Matrix.get m x_k.basic_idx x_j.idx in
|
||||||
|
x_k.value <- Erat.(x_k.value + a_kj * theta);
|
||||||
|
)
|
||||||
|
done;
|
||||||
|
|
||||||
|
(* now pivot the variables so that [x_j]'s coeff is -1 *)
|
||||||
|
let new_coeff = Q.(minus_one / a_ij) in
|
||||||
|
for k=0 to Matrix.n_cols m-1 do
|
||||||
|
Matrix.mult m x_i.idx k new_coeff;
|
||||||
|
done;
|
||||||
|
x_j.basic_idx <- x_i.basic_idx;
|
||||||
|
x_i.basic_idx <- -1;
|
||||||
|
Matrix.set_row_var_idx m x_j.basic_idx x_j.idx;
|
||||||
|
|
||||||
|
assert (Var_state.is_basic x_j);
|
||||||
|
assert (Var_state.is_n_basic x_i);
|
||||||
|
|
||||||
|
()
|
||||||
|
|
||||||
|
type unsat_cert = {
|
||||||
|
cert_bounds: (num * V.lit) list;
|
||||||
|
cert_defs: (V.t * (num * V.t) list) list; (* definitions used *)
|
||||||
|
}
|
||||||
|
|
||||||
|
module Unsat_cert = struct
|
||||||
|
type t = unsat_cert
|
||||||
|
|
||||||
|
let pp out (self:t) =
|
||||||
|
let pp_bnd out (n,lit) =
|
||||||
|
Fmt.fprintf out "(@[%a@ coeff %a@])" V.pp_lit lit Q.pp_print n
|
||||||
|
in
|
||||||
|
Fmt.fprintf out "(@[cert@ :bounds %a@ :defs %a@])"
|
||||||
|
Fmt.(Dump.list pp_bnd) self.cert_bounds
|
||||||
|
Fmt.(Dump.list (Dump.pair V.pp (Dump.list (Dump.pair Q.pp_print V.pp)))) self.cert_defs
|
||||||
|
|
||||||
|
let mk ~defs ~bounds : t =
|
||||||
|
{ cert_defs=defs; cert_bounds=bounds }
|
||||||
|
end
|
||||||
|
|
||||||
|
exception E_unsat of Unsat_cert.t
|
||||||
|
|
||||||
|
let add_constraint (self:t) (c:Constraint.t) (lit:V.lit) : unit =
|
||||||
|
let open Constraint in
|
||||||
|
let vs = find_or_create_n_basic_var_ self c.lhs in
|
||||||
|
let is_lower_bnd, new_bnd_val =
|
||||||
|
match c.op with
|
||||||
|
| Leq -> false, Erat.make_q c.rhs
|
||||||
|
| Lt -> false, Erat.make c.rhs Q.minus_one
|
||||||
|
| Geq -> true, Erat.make_q c.rhs
|
||||||
|
| Gt -> true, Erat.make c.rhs Q.one
|
||||||
|
in
|
||||||
|
let new_bnd = {b_val=new_bnd_val; b_lit=lit} in
|
||||||
|
if is_lower_bnd then (
|
||||||
|
begin match vs.l_bound, vs.u_bound with
|
||||||
|
| _, Some upper when Erat.(new_bnd.b_val > upper.b_val) ->
|
||||||
|
(* [b_val <= x <= upper], but [b_val > upper] *)
|
||||||
|
let cert = Unsat_cert.mk ~defs:[]
|
||||||
|
~bounds:[(Q.one, upper.b_lit); (Q.one, lit)] in
|
||||||
|
raise (E_unsat cert)
|
||||||
|
| Some lower, _ when Erat.(lower.b_val >= new_bnd.b_val) ->
|
||||||
|
() (* subsumed by existing constraint, do nothing *)
|
||||||
|
| _ ->
|
||||||
|
(* save current bound *)
|
||||||
|
Backtrack_stack.push self.bound_stack (vs, `Lower, vs.l_bound);
|
||||||
|
vs.l_bound <- Some new_bnd;
|
||||||
|
|
||||||
|
if Var_state.is_n_basic vs &&
|
||||||
|
Erat.(vs.value < new_bnd.b_val) then (
|
||||||
|
(* line 5: need to update non-basic variable *)
|
||||||
|
update_n_basic self vs new_bnd.b_val
|
||||||
|
)
|
||||||
|
end
|
||||||
|
) else (
|
||||||
|
begin match vs.l_bound, vs.u_bound with
|
||||||
|
| Some lower, _ when Erat.(new_bnd.b_val < lower.b_val) ->
|
||||||
|
(* [lower <= x <= b_val], but [b_val < lower] *)
|
||||||
|
let cert = Unsat_cert.mk ~defs:[]
|
||||||
|
~bounds:[(Q.one, lower.b_lit); (Q.one, lit)] in
|
||||||
|
raise (E_unsat cert)
|
||||||
|
| _, Some upper when Erat.(upper.b_val <= new_bnd.b_val) ->
|
||||||
|
() (* subsumed *)
|
||||||
|
| _ ->
|
||||||
|
Backtrack_stack.push self.bound_stack (vs, `Upper, vs.u_bound);
|
||||||
|
vs.u_bound <- Some new_bnd;
|
||||||
|
|
||||||
|
if Var_state.is_n_basic vs &&
|
||||||
|
Erat.(vs.value > new_bnd.b_val) then (
|
||||||
|
(* line 5: need to update non-basic variable *)
|
||||||
|
update_n_basic self vs new_bnd.b_val
|
||||||
|
)
|
||||||
|
end
|
||||||
|
)
|
||||||
|
|
||||||
|
(* try to find basic variable that doesn't respect one of its bounds *)
|
||||||
|
let find_basic_var_ (self:t) : (var_state * [`Lower | `Upper] * bound) option =
|
||||||
|
let n = Matrix.n_rows self.matrix in
|
||||||
|
let rec aux i =
|
||||||
|
if i >= n then None
|
||||||
|
else (
|
||||||
|
let x_i = Vec.get self.vars (Matrix.get_row_var_idx self.matrix i) in
|
||||||
|
let v_i = x_i.value in
|
||||||
|
match x_i.l_bound, x_i.u_bound with
|
||||||
|
| Some l, _ when Erat.(l.b_val > v_i) -> Some (x_i, `Lower, l)
|
||||||
|
| _, Some u when Erat.(u.b_val < v_i) -> Some (x_i, `Upper, u)
|
||||||
|
| _ -> (aux[@tailcall]) (i+1)
|
||||||
|
)
|
||||||
|
in
|
||||||
|
aux 0
|
||||||
|
|
||||||
|
let find_n_basic_var_ (self:t) ~f : var_state option =
|
||||||
|
let rec aux j =
|
||||||
|
if j >= Vec.size self.vars then None
|
||||||
|
else (
|
||||||
|
let x_j = Vec.get self.vars j in
|
||||||
|
if Var_state.is_n_basic x_j && f x_j then Some x_j
|
||||||
|
else aux (j+1)
|
||||||
|
)
|
||||||
|
in
|
||||||
|
aux 0
|
||||||
|
|
||||||
|
(* true if [x.value < x.u_bound] *)
|
||||||
|
let has_upper_slack (x:var_state) : bool =
|
||||||
|
match x.u_bound with
|
||||||
|
| None -> true
|
||||||
|
| Some bnd -> Erat.(x.value < bnd.b_val)
|
||||||
|
|
||||||
|
(* true if [x.value > x.l_bound] *)
|
||||||
|
let has_lower_slack (x:var_state) : bool =
|
||||||
|
match x.l_bound with
|
||||||
|
| None -> true
|
||||||
|
| Some bnd -> Erat.(x.value > bnd.b_val)
|
||||||
|
|
||||||
|
(* make a certificate from the row of basic variable [x_i] which is outside
|
||||||
|
one of its bound, and whose row is tight on all non-basic variables *)
|
||||||
|
let cert_of_row_ (self:t) (x_i:var_state) ~is_lower : unsat_cert =
|
||||||
|
Log.debugf 50 (fun k->k "(@[simplex.cert-of-row[lower: %B]@ x_i=%a@ %a@])"
|
||||||
|
is_lower Var_state.pp x_i pp self);
|
||||||
|
assert (Var_state.is_basic x_i);
|
||||||
|
(* TODO: store initial definition for each matrix row *)
|
||||||
|
let defs = [] in
|
||||||
|
let bounds = [] in (* TODO: use all bounds in the row *)
|
||||||
|
Unsat_cert.mk ~defs ~bounds
|
||||||
|
|
||||||
|
(* main satisfiability check.
|
||||||
|
page 15, figure 3.2 *)
|
||||||
|
let check_exn (self:t) : unit =
|
||||||
|
let exception Stop in
|
||||||
|
Log.debugf 5 (fun k->k "(@[simplex2.check@ %a@])" Matrix.pp self.matrix);
|
||||||
|
|
||||||
|
let m = self.matrix in
|
||||||
|
try
|
||||||
|
while true do
|
||||||
|
Log.debugf 50 (fun k->k "(@[simplex2.check.iter@ %a@])" Matrix.pp self.matrix);
|
||||||
|
|
||||||
|
(* basic variable that doesn't respect its bound *)
|
||||||
|
let x_i, is_lower, bnd = match find_basic_var_ self with
|
||||||
|
| Some (x, `Lower, bnd) -> x, true, bnd
|
||||||
|
| Some (x, `Upper, bnd) -> x, false, bnd
|
||||||
|
| None -> raise_notrace Stop (* line 4: we are done *)
|
||||||
|
in
|
||||||
|
|
||||||
|
if is_lower then (
|
||||||
|
(* line 5 *)
|
||||||
|
let x_j =
|
||||||
|
match
|
||||||
|
find_n_basic_var_ self
|
||||||
|
~f:(fun x_j ->
|
||||||
|
let a_ij = Matrix.get m x_i.basic_idx x_j.idx in
|
||||||
|
(Q.sign a_ij > 0 && has_upper_slack x_j) ||
|
||||||
|
(Q.sign a_ij < 0 && has_lower_slack x_j))
|
||||||
|
with
|
||||||
|
| Some x -> x
|
||||||
|
| None ->
|
||||||
|
let cert = cert_of_row_ self x_i ~is_lower:true in
|
||||||
|
raise (E_unsat cert)
|
||||||
|
in
|
||||||
|
assert (Var_state.is_n_basic x_j);
|
||||||
|
|
||||||
|
(* line 9 *)
|
||||||
|
pivot_and_update self x_i x_j bnd.b_val
|
||||||
|
) else (
|
||||||
|
(* line 10 *)
|
||||||
|
let x_j =
|
||||||
|
match
|
||||||
|
find_n_basic_var_ self
|
||||||
|
~f:(fun x_j ->
|
||||||
|
let a_ij = Matrix.get m x_i.basic_idx x_j.idx in
|
||||||
|
(Q.sign a_ij < 0 && has_upper_slack x_j) ||
|
||||||
|
(Q.sign a_ij > 0 && has_lower_slack x_j))
|
||||||
|
with
|
||||||
|
| Some x -> x
|
||||||
|
| None ->
|
||||||
|
let cert = cert_of_row_ self x_i ~is_lower:false in
|
||||||
|
raise (E_unsat cert)
|
||||||
|
in
|
||||||
|
assert (Var_state.is_n_basic x_j);
|
||||||
|
|
||||||
|
(* line 14 *)
|
||||||
|
pivot_and_update self x_i x_j bnd.b_val
|
||||||
|
)
|
||||||
|
done;
|
||||||
|
with Stop -> ()
|
||||||
|
|
||||||
|
let create () : t =
|
||||||
|
let self = {
|
||||||
|
matrix=Matrix.create();
|
||||||
|
vars=Vec.create();
|
||||||
|
var_tbl=V_map.empty;
|
||||||
|
bound_stack=Backtrack_stack.create();
|
||||||
|
} in
|
||||||
|
self
|
||||||
|
|
||||||
|
type result =
|
||||||
|
| Sat of Subst.t
|
||||||
|
| Unsat of unsat_cert
|
||||||
|
|
||||||
|
(* Find an epsilon that is small enough for finding a solution, yet
|
||||||
|
it must be positive.
|
||||||
|
|
||||||
|
{!Erat.t} values are used to turn strict bounds ([X > 0]) into
|
||||||
|
non-strict bounds ([X >= 0 + ε]), because the simplex algorithm
|
||||||
|
only deals with non-strict bounds.
|
||||||
|
When a solution is found, we need to turn {!Erat.t} into {!Q.t} by
|
||||||
|
finding a rational value that is small enough that it will fit into
|
||||||
|
all the intervals of [t]. This rational will be the actual value of [ε].
|
||||||
|
*)
|
||||||
|
let solve_epsilon (self:t) : Q.t =
|
||||||
|
let emax =
|
||||||
|
Iter.fold
|
||||||
|
(fun emax x ->
|
||||||
|
let {base=v; eps_factor=e_v} = x.value in
|
||||||
|
(* lower bound *)
|
||||||
|
let emax = match x.l_bound with
|
||||||
|
| Some {b_val={base=low;eps_factor=e_low};_} when Q.(e_v < e_low) ->
|
||||||
|
Q.min emax Q.((low - v) / (e_v - e_low))
|
||||||
|
| _ -> emax
|
||||||
|
in
|
||||||
|
(* upper bound *)
|
||||||
|
let emax = match x.u_bound with
|
||||||
|
| Some { b_val={base=upp;eps_factor=e_upp}; _} when Q.(e_v > e_upp) ->
|
||||||
|
min emax Q.((upp - v) / (e_v - e_upp))
|
||||||
|
| _ -> emax
|
||||||
|
in
|
||||||
|
emax)
|
||||||
|
Q.inf (Vec.to_seq self.vars)
|
||||||
|
in
|
||||||
|
if Q.compare emax Q.one >= 0 then Q.one else emax
|
||||||
|
|
||||||
|
let model_ self =
|
||||||
|
let eps = solve_epsilon self in
|
||||||
|
let subst =
|
||||||
|
Vec.to_seq self.vars
|
||||||
|
|> Iter.fold
|
||||||
|
(fun subst x ->
|
||||||
|
let {base;eps_factor} = x.value in
|
||||||
|
let v = Q.(base + eps * eps_factor) in
|
||||||
|
V_map.add x.var v subst)
|
||||||
|
V_map.empty
|
||||||
|
in
|
||||||
|
subst
|
||||||
|
|
||||||
|
let check (self:t) : result =
|
||||||
|
try
|
||||||
|
check_exn self;
|
||||||
|
let m = model_ self in
|
||||||
|
Sat m
|
||||||
|
with E_unsat c -> Unsat c
|
||||||
|
|
||||||
|
(* TODO
|
||||||
|
|
||||||
|
(* maybe invert bounds, if [c < 0] *)
|
||||||
|
let scale_bounds c (l,u) : bound * bound =
|
||||||
|
match Q.compare c Q.zero with
|
||||||
|
| 0 ->
|
||||||
|
let b = { value = Erat.zero; reason = None; } in
|
||||||
|
b, b
|
||||||
|
| n when n<0 ->
|
||||||
|
{ u with value = Erat.mul c u.value; },
|
||||||
|
{ l with value = Erat.mul c l.value; }
|
||||||
|
| _ ->
|
||||||
|
{ l with value = Erat.mul c l.value; },
|
||||||
|
{ u with value = Erat.mul c u.value; }
|
||||||
|
|
||||||
|
|
||||||
|
let check_cert (t:t) (c:cert) =
|
||||||
|
let x = M.get c.cert_var t.var_states |> CCOpt.get_lazy (fun()->assert false) in
|
||||||
|
let { value = low_x; reason = low_x_reason; } = x.l_bound in
|
||||||
|
let { value = up_x; reason = upp_x_reason; } = x.u_bound in
|
||||||
|
begin match c.cert_expr with
|
||||||
|
| [] ->
|
||||||
|
if Erat.compare low_x up_x > 0
|
||||||
|
then `Ok (add_to_unsat_core (add_to_unsat_core [] low_x_reason) upp_x_reason)
|
||||||
|
else `Bad_bounds (str_of_erat low_x, str_of_erat up_x)
|
||||||
|
| expr ->
|
||||||
|
let e0 = deref_var_ t x (Q.neg Q.one) M.empty in
|
||||||
|
(* compute bounds for the expression [c.cert_expr],
|
||||||
|
and also compute [c.cert_expr - x] to check if it's 0] *)
|
||||||
|
let low, low_unsat_core, up, up_unsat_core, expr_minus_x =
|
||||||
|
List.fold_left
|
||||||
|
(fun (l, luc, u, uuc, expr_minus_x) (c, y) ->
|
||||||
|
let y = M.get y t.var_states |> CCOpt.get_lazy (fun ()->assert false) in
|
||||||
|
let ly, uy = scale_bounds c (get_bounds y) in
|
||||||
|
assert (Erat.compare ly.value uy.value <= 0);
|
||||||
|
let expr_minus_x = deref_var_ t y c expr_minus_x in
|
||||||
|
let luc = add_to_unsat_core luc ly.reason in
|
||||||
|
let uuc = add_to_unsat_core uuc uy.reason in
|
||||||
|
Erat.sum l ly.value, luc, Erat.sum u uy.value, uuc, expr_minus_x)
|
||||||
|
(Erat.zero, [], Erat.zero, [], e0)
|
||||||
|
expr
|
||||||
|
in
|
||||||
|
(* check that the expanded expression is [x], and that
|
||||||
|
one of the bounds on [x] is incompatible with bounds of [c.cert_expr] *)
|
||||||
|
if M.is_empty expr_minus_x then (
|
||||||
|
if Erat.compare low_x up > 0
|
||||||
|
then `Ok (add_to_unsat_core up_unsat_core low_x_reason)
|
||||||
|
else if Erat.compare up_x low < 0
|
||||||
|
then `Ok (add_to_unsat_core low_unsat_core upp_x_reason)
|
||||||
|
else `Bad_bounds (str_of_erat low, str_of_erat up)
|
||||||
|
) else `Diff_not_0 expr_minus_x
|
||||||
|
end
|
||||||
|
*)
|
||||||
end
|
end
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue