fix(lra): many fixes in simplex; some fixme/todo

This commit is contained in:
Simon Cruanes 2021-02-12 19:42:16 -05:00
parent 5fc8d746c2
commit f226c6b820

View file

@ -5,6 +5,8 @@
de Moura and Dutertre.
*)
open CCMonomorphic
module type VAR = Linear_expr_intf.VAR
(** {2 Basic operator} *)
@ -79,6 +81,9 @@ module type S = sig
exception E_unsat of Unsat_cert.t
val add_var : t -> V.t -> unit
(** Make sure the variable exists in the simplex. *)
val add_constraint : t -> Constraint.t -> V.lit -> unit
(** Add a constraint to the simplex.
@raise Unsat if it's immediately obvious that this is not satisfiable. *)
@ -93,6 +98,11 @@ module type S = sig
val check : t -> result
(** Call {!check_exn} and return a model or a proof of unsat. *)
(**/**)
val _check_invariants : t -> unit
(* check internal invariants *)
(**/**)
end
(* TODO(optim): page 14, paragraph 2: we could detect which variables occur in no
@ -108,6 +118,9 @@ module Make(Var: VAR)
type num = Q.t (** Numbers *)
let pp_q_float n out q = Fmt.fprintf out "%*.1f" n (Q.to_float q)
let pp_q_dbg = pp_q_float 1
module Constraint = struct
type op = Op.t
@ -120,7 +133,7 @@ module Make(Var: VAR)
let pp out (self:t) =
Fmt.fprintf out "(@[%a %s@ %a@])" V.pp self.lhs
(Op.to_string self.op) Q.pp_print self.rhs
(Op.to_string self.op) pp_q_dbg self.rhs
let leq lhs rhs = {lhs;rhs;op=Op.Leq}
let lt lhs rhs = {lhs;rhs;op=Op.Lt}
@ -132,7 +145,7 @@ module Make(Var: VAR)
type t = num V_map.t
let pp out (self:t) : unit =
let pp_pair out (v,n) =
Fmt.fprintf out "(@[%a := %a@])" V.pp v Q.pp_print n in
Fmt.fprintf out "(@[%a := %a@])" V.pp v pp_q_dbg n in
Fmt.fprintf out "{@[%a@]}" (Fmt.iter pp_pair) (V_map.to_iter self)
let to_string = Fmt.to_string pp
end
@ -177,10 +190,10 @@ module Make(Var: VAR)
let pp out e : unit =
if Q.equal Q.zero (eps_factor e)
then Q.pp_print out (base e)
then pp_q_dbg out (base e)
else
Fmt.fprintf out "(@[<h>%a + @<1>ε * %a@])"
Q.pp_print (base e) Q.pp_print (eps_factor e)
pp_q_dbg (base e) pp_q_dbg (eps_factor e)
end
type var_idx = int
@ -230,33 +243,36 @@ module Make(Var: VAR)
cols: num Vec.t;
}
type t = {
mutable n_cols: int;
rows: row Vec.t
}
let create() : t = {rows=Vec.create()}
let create() : t =
{n_cols=0; rows=Vec.create()}
let[@inline] n_rows self = Vec.size self.rows
let n_cols self =
if n_rows self=0 then 0
else Vec.size (Vec.get self.rows 0).cols
let[@inline] n_cols self = self.n_cols
let pp out self =
Fmt.fprintf out "{@[<v>";
Fmt.fprintf out "{@[<v>matrix[%dx%d]@," (n_rows self) (n_cols self);
Vec.iteri (fun i row ->
Fmt.fprintf out "@[<hov2>%-5d: %a@]@," i
(Fmt.iter ~sep:(Fmt.return "@ ") Q.pp_print) (Vec.to_seq row.cols))
Fmt.fprintf out "{@[<hov2>r%-3d: %a@]}@," i
(Fmt.iter ~sep:(Fmt.return "@ ") (pp_q_float 6)) (Vec.to_seq row.cols))
self.rows;
Fmt.fprintf out "@]}"
let to_string = Fmt.to_string pp
let add_column self =
self.n_cols <- 1 + self.n_cols;
Vec.iter (fun r -> Vec.push r.cols Q.zero) self.rows
let add_row_and_column self : int =
let n = n_rows self in
let j = n_cols self in
add_column self;
let row = {var_idx=j; cols=Vec.make (j+1) Q.zero} in
let cols = Vec.make (j+1) Q.zero in
for _k=0 to j do Vec.push cols Q.zero done;
let row = {var_idx=j; cols} in
Vec.push self.rows row;
n
@ -302,11 +318,25 @@ module Make(Var: VAR)
let[@inline] is_basic (self:t) : bool = self.basic_idx >= 0
let[@inline] is_n_basic (self:t) : bool = self.basic_idx < 0
let in_bounds_ self =
(match self.l_bound with None -> true | Some b -> Erat.(self.value >= b.b_val)) &&
(match self.u_bound with None -> true | Some b -> Erat.(self.value <= b.b_val))
let pp out self =
Fmt.fprintf out "(@[var %a@ :basic %B@ :value %a@ :lbound %a@ :ubound %a@])"
Var.pp self.var (is_basic self) Erat.pp self.value
Fmt.(Dump.option (map (fun b->b.b_val) Erat.pp)) self.l_bound
Fmt.(Dump.option (map (fun b->b.b_val) Erat.pp)) self.u_bound
let bnd_status = if in_bounds_ self then "" else "(out of bounds)" in
let pp_bnd what out = function
| None -> ()
| Some b -> Fmt.fprintf out "@ @[%s %a@]" what Erat.pp b.b_val
and pp_basic_idx out () =
if self.basic_idx < 0 then () else Fmt.int out self.basic_idx
in
Fmt.fprintf out
"(@[var[%s%a]%s %a@ :value %a%a%a@])"
(if is_basic self then "B" else "N") pp_basic_idx ()
bnd_status
Var.pp self.var Erat.pp self.value
(pp_bnd ":lower") self.l_bound (pp_bnd ":upper") self.u_bound
end
type t = {
@ -328,7 +358,7 @@ module Make(Var: VAR)
()
let pp out (self:t) : unit =
Fmt.fprintf out "(@[simplex@ :vars %a@ :matrix %a@])"
Fmt.fprintf out "(@[simplex@ @[<1>:vars@ [@[<hov>%a@]]@]@ @[<1>:matrix@ %a@]@])"
(Vec.pp Var_state.pp) self.vars
Matrix.pp self.matrix
@ -340,15 +370,17 @@ module Make(Var: VAR)
let define (self:t) (x:V.t) (le:_ list) : unit =
assert (not (has_var_ self x));
Log.debugf 5 (fun k->k "(@[simplex.define@ %a@ :le %a@])"
Var.pp x Fmt.(Dump.(list @@ pair Q.pp_print Var.pp)) le);
Var.pp x Fmt.(Dump.(list @@ pair pp_q_dbg Var.pp)) le);
let idx = Vec.size self.vars in
let n = Matrix.add_row_and_column self.matrix in
let vs = {
var=x; value=Erat.zero;
idx=Vec.size self.vars;
idx;
basic_idx=n;
l_bound=None;
u_bound=None;
} in
assert (Var_state.is_basic vs);
Vec.push self.vars vs;
self.var_tbl <- V_map.add x vs self.var_tbl;
(* set coefficients in the matrix's new row: [-x + le = 0] *)
@ -356,7 +388,14 @@ module Make(Var: VAR)
List.iter
(fun (coeff,v2) ->
let vs2 = find_var_ self v2 in
(* FIXME: if [vs2] is basic, instead recurse with vs2's row.
copy coefficients of [vs2]'s row but multiplied with [coeff] .
See t4_short *)
Matrix.add self.matrix n vs2.idx coeff;
vs.value <- Erat.(vs.value + coeff * vs2.value); (* update value of [v] *)
) le;
()
@ -383,6 +422,10 @@ module Make(Var: VAR)
*)
let update_n_basic (self:t) (x:var_state) (v:erat) : unit =
assert (Var_state.is_n_basic x);
Log.debugf 50
(fun k->k "(@[simplex.update-n-basic@ %a@ :new-val %a@])"
Var_state.pp x Erat.pp v);
let m = self.matrix in
let i = x.idx in
@ -422,7 +465,7 @@ module Make(Var: VAR)
(* now pivot the variables so that [x_j]'s coeff is -1 *)
let new_coeff = Q.(minus_one / a_ij) in
for k=0 to Matrix.n_cols m-1 do
Matrix.mult m x_i.idx k new_coeff;
Matrix.mult m x_i.basic_idx k new_coeff;
done;
x_j.basic_idx <- x_i.basic_idx;
x_i.basic_idx <- -1;
@ -443,11 +486,11 @@ module Make(Var: VAR)
let pp out (self:t) =
let pp_bnd out (n,lit) =
Fmt.fprintf out "(@[%a@ coeff %a@])" V.pp_lit lit Q.pp_print n
Fmt.fprintf out "(@[%a@ coeff %a@])" V.pp_lit lit pp_q_dbg n
in
Fmt.fprintf out "(@[cert@ :bounds %a@ :defs %a@])"
Fmt.(Dump.list pp_bnd) self.cert_bounds
Fmt.(Dump.list (Dump.pair V.pp (Dump.list (Dump.pair Q.pp_print V.pp)))) self.cert_defs
Fmt.(Dump.list (Dump.pair V.pp (Dump.list (Dump.pair pp_q_dbg V.pp)))) self.cert_defs
let mk ~defs ~bounds : t =
{ cert_defs=defs; cert_bounds=bounds }
@ -455,8 +498,13 @@ module Make(Var: VAR)
exception E_unsat of Unsat_cert.t
let add_var self (v:V.t) : unit =
ignore (find_or_create_n_basic_var_ self v : var_state);
()
let add_constraint (self:t) (c:Constraint.t) (lit:V.lit) : unit =
let open Constraint in
Log.debugf 5 (fun k->k "(@[simplex2.add-constraint@ %a@])" Constraint.pp c);
let vs = find_or_create_n_basic_var_ self c.lhs in
let is_lower_bnd, new_bnd_val =
match c.op with
@ -561,12 +609,12 @@ module Make(Var: VAR)
page 15, figure 3.2 *)
let check_exn (self:t) : unit =
let exception Stop in
Log.debugf 5 (fun k->k "(@[simplex2.check@ %a@])" Matrix.pp self.matrix);
Log.debugf 20 (fun k->k "(@[simplex2.check@ %a@])" pp self);
let m = self.matrix in
try
while true do
Log.debugf 50 (fun k->k "(@[simplex2.check.iter@ %a@])" Matrix.pp self.matrix);
Log.debugf 50 (fun k->k "(@[simplex2.check.iter@ %a@])" pp self);
(* basic variable that doesn't respect its bound *)
let x_i, is_lower, bnd = match find_basic_var_ self with
@ -615,7 +663,9 @@ module Make(Var: VAR)
pivot_and_update self x_i x_j bnd.b_val
)
done;
with Stop -> ()
with Stop ->
Log.debugf 50 (fun k->k "(@[simplex2.check.done@])");
()
let create () : t =
let self = {
@ -630,6 +680,10 @@ module Make(Var: VAR)
| Sat of Subst.t
| Unsat of unsat_cert
let default_eps =
let denom = 1 lsl 10 in
Q.(one / of_int denom)
(* Find an epsilon that is small enough for finding a solution, yet
it must be positive.
@ -641,29 +695,65 @@ module Make(Var: VAR)
all the intervals of [t]. This rational will be the actual value of [ε].
*)
let solve_epsilon (self:t) : Q.t =
let emax =
let eps =
Iter.fold
(fun emax x ->
let {base=v; eps_factor=e_v} = x.value in
(* lower bound *)
let emax = match x.l_bound with
| Some {b_val={base=low;eps_factor=e_low};_} when Q.(e_v < e_low) ->
Q.min emax Q.((low - v) / (e_v - e_low))
| _ -> emax
(fun eps x ->
assert Q.(eps >= zero);
assert (Var_state.in_bounds_ x);
let x_val = x.value in
Log.debugf 50 (fun k->k "v.base=%a, v.eps=%a, emax=%a"
pp_q_dbg x_val.base pp_q_dbg x_val.eps_factor
pp_q_dbg eps);
(* is lower bound *)
let eps = match x.l_bound with
| Some {b_val;_}
when Q.(Erat.evaluate eps b_val > Erat.evaluate eps x_val) ->
assert (Erat.(x.value >= b_val));
assert (Q.(b_val.eps_factor > x.value.eps_factor));
(* current epsilon is too big. we need to make it smaller
than [x.value - b_val].
- [b_val.base + eps * b_val.factor
<= x.base + eps * x.factor]
- [eps * (b_val.factor - x.factor) <= x.base - b_val.base]
- [eps <= (x.base - b_val.base) / (b_val.factor - x.factor)]
*)
let new_eps =
Q.((x_val.base - b_val.base) /
(b_val.eps_factor - x_val.eps_factor))
in
Q.min eps new_eps
| _ -> eps
in
(* upper bound *)
let emax = match x.u_bound with
| Some { b_val={base=upp;eps_factor=e_upp}; _} when Q.(e_v > e_upp) ->
min emax Q.((upp - v) / (e_v - e_upp))
| _ -> emax
let eps = match x.u_bound with
| Some { b_val; _}
when Q.(Erat.evaluate eps b_val < Erat.evaluate eps x_val) ->
assert (Erat.(x.value <= b_val));
(* current epsilon is too big. we need to make it smaller
than [b_val - x.value].
- [x.base + eps * x.factor
<= b_val.base + eps * b_val.factor]
- [eps * (x.factor - b_val.factor) <= b_val.base - x.base]
- [eps <= (b_val.base - x.base) / (x.factor - b_val.factor)]
*)
let new_eps =
Q.((b_val.base - x_val.base) /
(x_val.eps_factor - b_val.eps_factor))
in
Log.debugf 5 (fun k->k "new max=%.5f" @@ Q.to_float new_eps);
Q.min eps new_eps
| _ -> eps
in
emax)
Q.inf (Vec.to_seq self.vars)
eps)
default_eps (Vec.to_seq self.vars)
in
if Q.compare emax Q.one >= 0 then Q.one else emax
if Q.(eps >= one) then Q.one else eps
let model_ self =
let eps = solve_epsilon self in
Log.debugf 50 (fun k->k "(@[simplex2.model@ :epsilon-val %a@])" pp_q_dbg eps);
let subst =
Vec.to_seq self.vars
|> Iter.fold
@ -673,6 +763,8 @@ module Make(Var: VAR)
V_map.add x.var v subst)
V_map.empty
in
Log.debugf 5
(fun k->k "(@[simplex2.model@ %a@])" Subst.pp subst);
subst
let check (self:t) : result =
@ -735,4 +827,16 @@ module Make(Var: VAR)
) else `Diff_not_0 expr_minus_x
end
*)
let _check_invariants self : unit =
Vec.iteri (fun i v -> assert(v.idx = i)) self.vars;
let n = Vec.size self.vars in
assert (Matrix.n_rows self.matrix = 0 || Matrix.n_cols self.matrix = n);
for i = 0 to Matrix.n_rows self.matrix-1 do
let v = Vec.get self.vars (Matrix.get_row_var_idx self.matrix i) in
assert (Var_state.is_basic v);
assert (v.basic_idx = i);
done;
() (* TODO: more *)
end