fix(lra): many fixes in simplex; some fixme/todo

This commit is contained in:
Simon Cruanes 2021-02-12 19:42:16 -05:00
parent 5fc8d746c2
commit f226c6b820

View file

@ -5,6 +5,8 @@
de Moura and Dutertre. de Moura and Dutertre.
*) *)
open CCMonomorphic
module type VAR = Linear_expr_intf.VAR module type VAR = Linear_expr_intf.VAR
(** {2 Basic operator} *) (** {2 Basic operator} *)
@ -79,6 +81,9 @@ module type S = sig
exception E_unsat of Unsat_cert.t exception E_unsat of Unsat_cert.t
val add_var : t -> V.t -> unit
(** Make sure the variable exists in the simplex. *)
val add_constraint : t -> Constraint.t -> V.lit -> unit val add_constraint : t -> Constraint.t -> V.lit -> unit
(** Add a constraint to the simplex. (** Add a constraint to the simplex.
@raise Unsat if it's immediately obvious that this is not satisfiable. *) @raise Unsat if it's immediately obvious that this is not satisfiable. *)
@ -93,6 +98,11 @@ module type S = sig
val check : t -> result val check : t -> result
(** Call {!check_exn} and return a model or a proof of unsat. *) (** Call {!check_exn} and return a model or a proof of unsat. *)
(**/**)
val _check_invariants : t -> unit
(* check internal invariants *)
(**/**)
end end
(* TODO(optim): page 14, paragraph 2: we could detect which variables occur in no (* TODO(optim): page 14, paragraph 2: we could detect which variables occur in no
@ -108,6 +118,9 @@ module Make(Var: VAR)
type num = Q.t (** Numbers *) type num = Q.t (** Numbers *)
let pp_q_float n out q = Fmt.fprintf out "%*.1f" n (Q.to_float q)
let pp_q_dbg = pp_q_float 1
module Constraint = struct module Constraint = struct
type op = Op.t type op = Op.t
@ -120,7 +133,7 @@ module Make(Var: VAR)
let pp out (self:t) = let pp out (self:t) =
Fmt.fprintf out "(@[%a %s@ %a@])" V.pp self.lhs Fmt.fprintf out "(@[%a %s@ %a@])" V.pp self.lhs
(Op.to_string self.op) Q.pp_print self.rhs (Op.to_string self.op) pp_q_dbg self.rhs
let leq lhs rhs = {lhs;rhs;op=Op.Leq} let leq lhs rhs = {lhs;rhs;op=Op.Leq}
let lt lhs rhs = {lhs;rhs;op=Op.Lt} let lt lhs rhs = {lhs;rhs;op=Op.Lt}
@ -132,7 +145,7 @@ module Make(Var: VAR)
type t = num V_map.t type t = num V_map.t
let pp out (self:t) : unit = let pp out (self:t) : unit =
let pp_pair out (v,n) = let pp_pair out (v,n) =
Fmt.fprintf out "(@[%a := %a@])" V.pp v Q.pp_print n in Fmt.fprintf out "(@[%a := %a@])" V.pp v pp_q_dbg n in
Fmt.fprintf out "{@[%a@]}" (Fmt.iter pp_pair) (V_map.to_iter self) Fmt.fprintf out "{@[%a@]}" (Fmt.iter pp_pair) (V_map.to_iter self)
let to_string = Fmt.to_string pp let to_string = Fmt.to_string pp
end end
@ -177,10 +190,10 @@ module Make(Var: VAR)
let pp out e : unit = let pp out e : unit =
if Q.equal Q.zero (eps_factor e) if Q.equal Q.zero (eps_factor e)
then Q.pp_print out (base e) then pp_q_dbg out (base e)
else else
Fmt.fprintf out "(@[<h>%a + @<1>ε * %a@])" Fmt.fprintf out "(@[<h>%a + @<1>ε * %a@])"
Q.pp_print (base e) Q.pp_print (eps_factor e) pp_q_dbg (base e) pp_q_dbg (eps_factor e)
end end
type var_idx = int type var_idx = int
@ -230,33 +243,36 @@ module Make(Var: VAR)
cols: num Vec.t; cols: num Vec.t;
} }
type t = { type t = {
mutable n_cols: int;
rows: row Vec.t rows: row Vec.t
} }
let create() : t = {rows=Vec.create()} let create() : t =
{n_cols=0; rows=Vec.create()}
let[@inline] n_rows self = Vec.size self.rows let[@inline] n_rows self = Vec.size self.rows
let n_cols self = let[@inline] n_cols self = self.n_cols
if n_rows self=0 then 0
else Vec.size (Vec.get self.rows 0).cols
let pp out self = let pp out self =
Fmt.fprintf out "{@[<v>"; Fmt.fprintf out "{@[<v>matrix[%dx%d]@," (n_rows self) (n_cols self);
Vec.iteri (fun i row -> Vec.iteri (fun i row ->
Fmt.fprintf out "@[<hov2>%-5d: %a@]@," i Fmt.fprintf out "{@[<hov2>r%-3d: %a@]}@," i
(Fmt.iter ~sep:(Fmt.return "@ ") Q.pp_print) (Vec.to_seq row.cols)) (Fmt.iter ~sep:(Fmt.return "@ ") (pp_q_float 6)) (Vec.to_seq row.cols))
self.rows; self.rows;
Fmt.fprintf out "@]}" Fmt.fprintf out "@]}"
let to_string = Fmt.to_string pp let to_string = Fmt.to_string pp
let add_column self = let add_column self =
self.n_cols <- 1 + self.n_cols;
Vec.iter (fun r -> Vec.push r.cols Q.zero) self.rows Vec.iter (fun r -> Vec.push r.cols Q.zero) self.rows
let add_row_and_column self : int = let add_row_and_column self : int =
let n = n_rows self in let n = n_rows self in
let j = n_cols self in let j = n_cols self in
add_column self; add_column self;
let row = {var_idx=j; cols=Vec.make (j+1) Q.zero} in let cols = Vec.make (j+1) Q.zero in
for _k=0 to j do Vec.push cols Q.zero done;
let row = {var_idx=j; cols} in
Vec.push self.rows row; Vec.push self.rows row;
n n
@ -302,11 +318,25 @@ module Make(Var: VAR)
let[@inline] is_basic (self:t) : bool = self.basic_idx >= 0 let[@inline] is_basic (self:t) : bool = self.basic_idx >= 0
let[@inline] is_n_basic (self:t) : bool = self.basic_idx < 0 let[@inline] is_n_basic (self:t) : bool = self.basic_idx < 0
let in_bounds_ self =
(match self.l_bound with None -> true | Some b -> Erat.(self.value >= b.b_val)) &&
(match self.u_bound with None -> true | Some b -> Erat.(self.value <= b.b_val))
let pp out self = let pp out self =
Fmt.fprintf out "(@[var %a@ :basic %B@ :value %a@ :lbound %a@ :ubound %a@])" let bnd_status = if in_bounds_ self then "" else "(out of bounds)" in
Var.pp self.var (is_basic self) Erat.pp self.value let pp_bnd what out = function
Fmt.(Dump.option (map (fun b->b.b_val) Erat.pp)) self.l_bound | None -> ()
Fmt.(Dump.option (map (fun b->b.b_val) Erat.pp)) self.u_bound | Some b -> Fmt.fprintf out "@ @[%s %a@]" what Erat.pp b.b_val
and pp_basic_idx out () =
if self.basic_idx < 0 then () else Fmt.int out self.basic_idx
in
Fmt.fprintf out
"(@[var[%s%a]%s %a@ :value %a%a%a@])"
(if is_basic self then "B" else "N") pp_basic_idx ()
bnd_status
Var.pp self.var Erat.pp self.value
(pp_bnd ":lower") self.l_bound (pp_bnd ":upper") self.u_bound
end end
type t = { type t = {
@ -328,7 +358,7 @@ module Make(Var: VAR)
() ()
let pp out (self:t) : unit = let pp out (self:t) : unit =
Fmt.fprintf out "(@[simplex@ :vars %a@ :matrix %a@])" Fmt.fprintf out "(@[simplex@ @[<1>:vars@ [@[<hov>%a@]]@]@ @[<1>:matrix@ %a@]@])"
(Vec.pp Var_state.pp) self.vars (Vec.pp Var_state.pp) self.vars
Matrix.pp self.matrix Matrix.pp self.matrix
@ -340,15 +370,17 @@ module Make(Var: VAR)
let define (self:t) (x:V.t) (le:_ list) : unit = let define (self:t) (x:V.t) (le:_ list) : unit =
assert (not (has_var_ self x)); assert (not (has_var_ self x));
Log.debugf 5 (fun k->k "(@[simplex.define@ %a@ :le %a@])" Log.debugf 5 (fun k->k "(@[simplex.define@ %a@ :le %a@])"
Var.pp x Fmt.(Dump.(list @@ pair Q.pp_print Var.pp)) le); Var.pp x Fmt.(Dump.(list @@ pair pp_q_dbg Var.pp)) le);
let idx = Vec.size self.vars in
let n = Matrix.add_row_and_column self.matrix in let n = Matrix.add_row_and_column self.matrix in
let vs = { let vs = {
var=x; value=Erat.zero; var=x; value=Erat.zero;
idx=Vec.size self.vars; idx;
basic_idx=n; basic_idx=n;
l_bound=None; l_bound=None;
u_bound=None; u_bound=None;
} in } in
assert (Var_state.is_basic vs);
Vec.push self.vars vs; Vec.push self.vars vs;
self.var_tbl <- V_map.add x vs self.var_tbl; self.var_tbl <- V_map.add x vs self.var_tbl;
(* set coefficients in the matrix's new row: [-x + le = 0] *) (* set coefficients in the matrix's new row: [-x + le = 0] *)
@ -356,7 +388,14 @@ module Make(Var: VAR)
List.iter List.iter
(fun (coeff,v2) -> (fun (coeff,v2) ->
let vs2 = find_var_ self v2 in let vs2 = find_var_ self v2 in
(* FIXME: if [vs2] is basic, instead recurse with vs2's row.
copy coefficients of [vs2]'s row but multiplied with [coeff] .
See t4_short *)
Matrix.add self.matrix n vs2.idx coeff; Matrix.add self.matrix n vs2.idx coeff;
vs.value <- Erat.(vs.value + coeff * vs2.value); (* update value of [v] *)
) le; ) le;
() ()
@ -383,6 +422,10 @@ module Make(Var: VAR)
*) *)
let update_n_basic (self:t) (x:var_state) (v:erat) : unit = let update_n_basic (self:t) (x:var_state) (v:erat) : unit =
assert (Var_state.is_n_basic x); assert (Var_state.is_n_basic x);
Log.debugf 50
(fun k->k "(@[simplex.update-n-basic@ %a@ :new-val %a@])"
Var_state.pp x Erat.pp v);
let m = self.matrix in let m = self.matrix in
let i = x.idx in let i = x.idx in
@ -422,7 +465,7 @@ module Make(Var: VAR)
(* now pivot the variables so that [x_j]'s coeff is -1 *) (* now pivot the variables so that [x_j]'s coeff is -1 *)
let new_coeff = Q.(minus_one / a_ij) in let new_coeff = Q.(minus_one / a_ij) in
for k=0 to Matrix.n_cols m-1 do for k=0 to Matrix.n_cols m-1 do
Matrix.mult m x_i.idx k new_coeff; Matrix.mult m x_i.basic_idx k new_coeff;
done; done;
x_j.basic_idx <- x_i.basic_idx; x_j.basic_idx <- x_i.basic_idx;
x_i.basic_idx <- -1; x_i.basic_idx <- -1;
@ -443,11 +486,11 @@ module Make(Var: VAR)
let pp out (self:t) = let pp out (self:t) =
let pp_bnd out (n,lit) = let pp_bnd out (n,lit) =
Fmt.fprintf out "(@[%a@ coeff %a@])" V.pp_lit lit Q.pp_print n Fmt.fprintf out "(@[%a@ coeff %a@])" V.pp_lit lit pp_q_dbg n
in in
Fmt.fprintf out "(@[cert@ :bounds %a@ :defs %a@])" Fmt.fprintf out "(@[cert@ :bounds %a@ :defs %a@])"
Fmt.(Dump.list pp_bnd) self.cert_bounds Fmt.(Dump.list pp_bnd) self.cert_bounds
Fmt.(Dump.list (Dump.pair V.pp (Dump.list (Dump.pair Q.pp_print V.pp)))) self.cert_defs Fmt.(Dump.list (Dump.pair V.pp (Dump.list (Dump.pair pp_q_dbg V.pp)))) self.cert_defs
let mk ~defs ~bounds : t = let mk ~defs ~bounds : t =
{ cert_defs=defs; cert_bounds=bounds } { cert_defs=defs; cert_bounds=bounds }
@ -455,8 +498,13 @@ module Make(Var: VAR)
exception E_unsat of Unsat_cert.t exception E_unsat of Unsat_cert.t
let add_var self (v:V.t) : unit =
ignore (find_or_create_n_basic_var_ self v : var_state);
()
let add_constraint (self:t) (c:Constraint.t) (lit:V.lit) : unit = let add_constraint (self:t) (c:Constraint.t) (lit:V.lit) : unit =
let open Constraint in let open Constraint in
Log.debugf 5 (fun k->k "(@[simplex2.add-constraint@ %a@])" Constraint.pp c);
let vs = find_or_create_n_basic_var_ self c.lhs in let vs = find_or_create_n_basic_var_ self c.lhs in
let is_lower_bnd, new_bnd_val = let is_lower_bnd, new_bnd_val =
match c.op with match c.op with
@ -561,12 +609,12 @@ module Make(Var: VAR)
page 15, figure 3.2 *) page 15, figure 3.2 *)
let check_exn (self:t) : unit = let check_exn (self:t) : unit =
let exception Stop in let exception Stop in
Log.debugf 5 (fun k->k "(@[simplex2.check@ %a@])" Matrix.pp self.matrix); Log.debugf 20 (fun k->k "(@[simplex2.check@ %a@])" pp self);
let m = self.matrix in let m = self.matrix in
try try
while true do while true do
Log.debugf 50 (fun k->k "(@[simplex2.check.iter@ %a@])" Matrix.pp self.matrix); Log.debugf 50 (fun k->k "(@[simplex2.check.iter@ %a@])" pp self);
(* basic variable that doesn't respect its bound *) (* basic variable that doesn't respect its bound *)
let x_i, is_lower, bnd = match find_basic_var_ self with let x_i, is_lower, bnd = match find_basic_var_ self with
@ -615,7 +663,9 @@ module Make(Var: VAR)
pivot_and_update self x_i x_j bnd.b_val pivot_and_update self x_i x_j bnd.b_val
) )
done; done;
with Stop -> () with Stop ->
Log.debugf 50 (fun k->k "(@[simplex2.check.done@])");
()
let create () : t = let create () : t =
let self = { let self = {
@ -630,6 +680,10 @@ module Make(Var: VAR)
| Sat of Subst.t | Sat of Subst.t
| Unsat of unsat_cert | Unsat of unsat_cert
let default_eps =
let denom = 1 lsl 10 in
Q.(one / of_int denom)
(* Find an epsilon that is small enough for finding a solution, yet (* Find an epsilon that is small enough for finding a solution, yet
it must be positive. it must be positive.
@ -641,29 +695,65 @@ module Make(Var: VAR)
all the intervals of [t]. This rational will be the actual value of [ε]. all the intervals of [t]. This rational will be the actual value of [ε].
*) *)
let solve_epsilon (self:t) : Q.t = let solve_epsilon (self:t) : Q.t =
let emax = let eps =
Iter.fold Iter.fold
(fun emax x -> (fun eps x ->
let {base=v; eps_factor=e_v} = x.value in assert Q.(eps >= zero);
(* lower bound *) assert (Var_state.in_bounds_ x);
let emax = match x.l_bound with
| Some {b_val={base=low;eps_factor=e_low};_} when Q.(e_v < e_low) -> let x_val = x.value in
Q.min emax Q.((low - v) / (e_v - e_low)) Log.debugf 50 (fun k->k "v.base=%a, v.eps=%a, emax=%a"
| _ -> emax pp_q_dbg x_val.base pp_q_dbg x_val.eps_factor
pp_q_dbg eps);
(* is lower bound *)
let eps = match x.l_bound with
| Some {b_val;_}
when Q.(Erat.evaluate eps b_val > Erat.evaluate eps x_val) ->
assert (Erat.(x.value >= b_val));
assert (Q.(b_val.eps_factor > x.value.eps_factor));
(* current epsilon is too big. we need to make it smaller
than [x.value - b_val].
- [b_val.base + eps * b_val.factor
<= x.base + eps * x.factor]
- [eps * (b_val.factor - x.factor) <= x.base - b_val.base]
- [eps <= (x.base - b_val.base) / (b_val.factor - x.factor)]
*)
let new_eps =
Q.((x_val.base - b_val.base) /
(b_val.eps_factor - x_val.eps_factor))
in
Q.min eps new_eps
| _ -> eps
in in
(* upper bound *) (* upper bound *)
let emax = match x.u_bound with let eps = match x.u_bound with
| Some { b_val={base=upp;eps_factor=e_upp}; _} when Q.(e_v > e_upp) -> | Some { b_val; _}
min emax Q.((upp - v) / (e_v - e_upp)) when Q.(Erat.evaluate eps b_val < Erat.evaluate eps x_val) ->
| _ -> emax assert (Erat.(x.value <= b_val));
(* current epsilon is too big. we need to make it smaller
than [b_val - x.value].
- [x.base + eps * x.factor
<= b_val.base + eps * b_val.factor]
- [eps * (x.factor - b_val.factor) <= b_val.base - x.base]
- [eps <= (b_val.base - x.base) / (x.factor - b_val.factor)]
*)
let new_eps =
Q.((b_val.base - x_val.base) /
(x_val.eps_factor - b_val.eps_factor))
in
Log.debugf 5 (fun k->k "new max=%.5f" @@ Q.to_float new_eps);
Q.min eps new_eps
| _ -> eps
in in
emax) eps)
Q.inf (Vec.to_seq self.vars) default_eps (Vec.to_seq self.vars)
in in
if Q.compare emax Q.one >= 0 then Q.one else emax if Q.(eps >= one) then Q.one else eps
let model_ self = let model_ self =
let eps = solve_epsilon self in let eps = solve_epsilon self in
Log.debugf 50 (fun k->k "(@[simplex2.model@ :epsilon-val %a@])" pp_q_dbg eps);
let subst = let subst =
Vec.to_seq self.vars Vec.to_seq self.vars
|> Iter.fold |> Iter.fold
@ -673,6 +763,8 @@ module Make(Var: VAR)
V_map.add x.var v subst) V_map.add x.var v subst)
V_map.empty V_map.empty
in in
Log.debugf 5
(fun k->k "(@[simplex2.model@ %a@])" Subst.pp subst);
subst subst
let check (self:t) : result = let check (self:t) : result =
@ -735,4 +827,16 @@ module Make(Var: VAR)
) else `Diff_not_0 expr_minus_x ) else `Diff_not_0 expr_minus_x
end end
*) *)
let _check_invariants self : unit =
Vec.iteri (fun i v -> assert(v.idx = i)) self.vars;
let n = Vec.size self.vars in
assert (Matrix.n_rows self.matrix = 0 || Matrix.n_cols self.matrix = n);
for i = 0 to Matrix.n_rows self.matrix-1 do
let v = Vec.get self.vars (Matrix.get_row_var_idx self.matrix i) in
assert (Var_state.is_basic v);
assert (v.basic_idx = i);
done;
() (* TODO: more *)
end end