fix(simplex2): correct pivot; refactor; better printing

This commit is contained in:
Simon Cruanes 2021-02-15 13:29:12 -05:00
parent f226c6b820
commit 4d9f99e65d

View file

@ -182,6 +182,7 @@ module Make(Var: VAR)
let[@inline] (<=) a b = compare a b <= 0 let[@inline] (<=) a b = compare a b <= 0
let[@inline] (>) a b = compare a b > 0 let[@inline] (>) a b = compare a b > 0
let[@inline] (>=) a b = compare a b >= 0 let[@inline] (>=) a b = compare a b >= 0
let[@inline] (=) a b = compare a b = 0
let[@inline] min x y = if x <= y then x else y let[@inline] min x y = if x <= y then x else y
let[@inline] max x y = if x >= y then x else y let[@inline] max x y = if x >= y then x else y
@ -196,7 +197,19 @@ module Make(Var: VAR)
pp_q_dbg (base e) pp_q_dbg (eps_factor e) pp_q_dbg (base e) pp_q_dbg (eps_factor e)
end end
type var_idx = int type bound = {
b_val: erat;
b_lit: Var.lit;
}
type var_state = {
var: V.t;
mutable value: erat;
idx: int; (* index in {!t.vars} *)
mutable basic_idx: int; (* index of the row in the matrix, if any. -1 otherwise *)
mutable l_bound: bound option;
mutable u_bound: bound option;
}
(** {2 Matrix} (** {2 Matrix}
The matrix [A] from the paper, with m rows and n columns. The matrix [A] from the paper, with m rows and n columns.
@ -221,14 +234,14 @@ module Make(Var: VAR)
val add_column : t -> unit val add_column : t -> unit
(** Add a non-basic variable (only adds a column) *) (** Add a non-basic variable (only adds a column) *)
val add_row_and_column : t -> int val add_row_and_column : t -> f:(row_idx:int -> col_idx:int -> var_state) -> var_state
(** Add a basic variable. returns the row index. *) (** Add a basic variable. *)
val get_row_var_idx : t -> int -> var_idx val get_row_var : t -> int -> var_state
(** Index of the basic variable for row [i] *) (** The basic variable for row [i] *)
val set_row_var_idx : t -> int -> var_idx -> unit val set_row_var : t -> int -> var_state -> unit
(** Set index of the basic variable for row [i] *) (** Set basic variable for row [i] *)
val get : t -> int -> int -> num val get : t -> int -> int -> num
@ -239,7 +252,7 @@ module Make(Var: VAR)
val mult : t -> int -> int -> num -> unit val mult : t -> int -> int -> num -> unit
end = struct end = struct
type row = { type row = {
mutable var_idx: var_idx; mutable vs: var_state;
cols: num Vec.t; cols: num Vec.t;
} }
type t = { type t = {
@ -254,30 +267,41 @@ module Make(Var: VAR)
let[@inline] n_cols self = self.n_cols let[@inline] n_cols self = self.n_cols
let pp out self = let pp out self =
Fmt.fprintf out "{@[<v>matrix[%dx%d]@," (n_rows self) (n_cols self); Fmt.fprintf out "@[<v1>{matrix[%dx%d]@," (n_rows self) (n_cols self);
(* header *)
let ppi out i =
Fmt.string out @@ CCString.pad ~side:`Left 6 @@ Printf.sprintf "v%d" i in
Fmt.fprintf out "{@[<hov2>%9s: %a@]}" "vars"
(Fmt.iter ~sep:(Fmt.return "@ ") ppi) CCInt.(0 -- (n_cols self-1));
Vec.iteri (fun i row -> Vec.iteri (fun i row ->
Fmt.fprintf out "{@[<hov2>r%-3d: %a@]}@," i let hd =
CCString.pad ~side:`Left 6 @@
Printf.sprintf "r%d (v%d)" i row.vs.idx in
Fmt.fprintf out "@,{@[<hov2>%9s: %a@]}" hd
(Fmt.iter ~sep:(Fmt.return "@ ") (pp_q_float 6)) (Vec.to_seq row.cols)) (Fmt.iter ~sep:(Fmt.return "@ ") (pp_q_float 6)) (Vec.to_seq row.cols))
self.rows; self.rows;
Fmt.fprintf out "@]}" Fmt.fprintf out "@;<0 -1>}@]"
let to_string = Fmt.to_string pp let to_string = Fmt.to_string pp
let add_column self = let add_column self =
self.n_cols <- 1 + self.n_cols; self.n_cols <- 1 + self.n_cols;
Vec.iter (fun r -> Vec.push r.cols Q.zero) self.rows Vec.iter (fun r -> Vec.push r.cols Q.zero) self.rows
let add_row_and_column self : int = let add_row_and_column self ~f : var_state =
let n = n_rows self in let n = n_rows self in
let j = n_cols self in let j = n_cols self in
add_column self; add_column self;
let cols = Vec.make (j+1) Q.zero in let cols = Vec.make (j+1) Q.zero in
for _k=0 to j do Vec.push cols Q.zero done; for _k=0 to j do Vec.push cols Q.zero done;
let row = {var_idx=j; cols} in let vs = f ~row_idx:n ~col_idx:j in
let row = {vs; cols} in
Vec.push self.rows row; Vec.push self.rows row;
n vs
let[@inline] get_row_var_idx self i = (Vec.get self.rows i).var_idx let[@inline] get_row_var self i = (Vec.get self.rows i).vs
let[@inline] set_row_var_idx self i n = (Vec.get self.rows i).var_idx <- n let[@inline] set_row_var self i v = (Vec.get self.rows i).vs <- v
let[@inline] get self i j : num = Vec.get (Vec.get self.rows i).cols j let[@inline] get self i j : num = Vec.get (Vec.get self.rows i).cols j
@ -294,27 +318,15 @@ module Make(Var: VAR)
let mult self i j c : unit = let mult self i j c : unit =
let r = Vec.get self.rows i in let r = Vec.get self.rows i in
let n_j = Vec.get r.cols j in let n_j = Vec.get r.cols j in
if Q.sign n_j <> 0 then ( if Q.(n_j <> zero) then (
Vec.set r.cols j Q.(n_j * c) Vec.set r.cols j Q.(n_j * c)
) )
end end
type bound = {
b_val: erat;
b_lit: Var.lit;
}
type var_state = {
var: V.t;
mutable value: erat;
idx: int; (* index in {!t.vars} *)
mutable basic_idx: int; (* index of the row in the matrix, if any. -1 otherwise *)
mutable l_bound: bound option;
mutable u_bound: bound option;
}
module Var_state = struct module Var_state = struct
type t = var_state type t = var_state
let (==) : t -> t -> bool = Containers.Stdlib.(==)
let (!=) : t -> t -> bool = Containers.Stdlib.(!=)
let[@inline] is_basic (self:t) : bool = self.basic_idx >= 0 let[@inline] is_basic (self:t) : bool = self.basic_idx >= 0
let[@inline] is_n_basic (self:t) : bool = self.basic_idx < 0 let[@inline] is_n_basic (self:t) : bool = self.basic_idx < 0
@ -323,8 +335,9 @@ module Make(Var: VAR)
(match self.l_bound with None -> true | Some b -> Erat.(self.value >= b.b_val)) && (match self.l_bound with None -> true | Some b -> Erat.(self.value >= b.b_val)) &&
(match self.u_bound with None -> true | Some b -> Erat.(self.value <= b.b_val)) (match self.u_bound with None -> true | Some b -> Erat.(self.value <= b.b_val))
let pp_name out self = Fmt.fprintf out "v%d" self.idx
let pp out self = let pp out self =
let bnd_status = if in_bounds_ self then "" else "(out of bounds)" in let bnd_status = if in_bounds_ self then "" else "(oob)" in
let pp_bnd what out = function let pp_bnd what out = function
| None -> () | None -> ()
| Some b -> Fmt.fprintf out "@ @[%s %a@]" what Erat.pp b.b_val | Some b -> Fmt.fprintf out "@ @[%s %a@]" what Erat.pp b.b_val
@ -332,10 +345,9 @@ module Make(Var: VAR)
if self.basic_idx < 0 then () else Fmt.int out self.basic_idx if self.basic_idx < 0 then () else Fmt.int out self.basic_idx
in in
Fmt.fprintf out Fmt.fprintf out
"(@[var[%s%a]%s %a@ :value %a%a%a@])" "(@[v%d[%s%a]%s@ :value %a%a%a@])"
(if is_basic self then "B" else "N") pp_basic_idx () self.idx (if is_basic self then "B" else "N") pp_basic_idx ()
bnd_status bnd_status Erat.pp self.value
Var.pp self.var Erat.pp self.value
(pp_bnd ":lower") self.l_bound (pp_bnd ":upper") self.u_bound (pp_bnd ":lower") self.l_bound (pp_bnd ":upper") self.u_bound
end end
@ -362,41 +374,108 @@ module Make(Var: VAR)
(Vec.pp Var_state.pp) self.vars (Vec.pp Var_state.pp) self.vars
Matrix.pp self.matrix Matrix.pp self.matrix
(* for debug purposes *)
let _check_invariants self : unit =
Vec.iteri (fun i v -> assert(v.idx = i)) self.vars;
let n = Vec.size self.vars in
assert (Matrix.n_rows self.matrix = 0 || Matrix.n_cols self.matrix = n);
for i = 0 to Matrix.n_rows self.matrix-1 do
let v = Matrix.get_row_var self.matrix i in
assert (Var_state.is_basic v);
assert (v.basic_idx = i);
assert Q.(Matrix.get self.matrix v.basic_idx v.idx = minus_one);
(* basic vars are only defined in terms of non-basic vars *)
Vec.iteri
(fun j v_j ->
if Var_state.(v != v_j) && Q.(Matrix.get self.matrix i j <> zero) then (
assert (Var_state.is_n_basic v_j)
))
self.vars;
(* sum of each row must be 0 *)
let sum =
Vec.fold
(fun sum v ->
Erat.(sum + Matrix.get self.matrix i v.idx * v.value))
Erat.zero self.vars
in
Log.debugf 50 (fun k->k "row %d: sum %a" i Erat.pp sum);
assert Erat.(sum = zero);
done;
() (* TODO: more *)
(* for internal checking *)
let _check_invariants_internal self =
if false (* FUDGE *) then _check_invariants self
let[@inline] has_var_ (self:t) x : bool = V_map.mem x self.var_tbl let[@inline] has_var_ (self:t) x : bool = V_map.mem x self.var_tbl
let[@inline] find_var_ (self:t) x : var_state = let[@inline] find_var_ (self:t) x : var_state =
try V_map.find x self.var_tbl try V_map.find x self.var_tbl
with Not_found -> Error.errorf "variable is not in the simplex" Var.pp x with Not_found -> Error.errorf "variable is not in the simplex" Var.pp x
(* define [x] as a basic variable *)
let define (self:t) (x:V.t) (le:_ list) : unit = let define (self:t) (x:V.t) (le:_ list) : unit =
assert (not (has_var_ self x)); assert (not (has_var_ self x));
Log.debugf 5 (fun k->k "(@[simplex.define@ %a@ :le %a@])" (* Log.debugf 50 (fun k->k "define-in: %a" pp self); *)
Var.pp x Fmt.(Dump.(list @@ pair pp_q_dbg Var.pp)) le); let le = List.map (fun (q,v) -> q, find_var_ self v) le in
let idx = Vec.size self.vars in
let n = Matrix.add_row_and_column self.matrix in (* initial value for the new variable *)
let vs = { let value =
var=x; value=Erat.zero; List.fold_left
idx; (fun sum (c,v) -> Erat.(sum + c * v.value)) Erat.zero le
basic_idx=n; in
l_bound=None;
u_bound=None; let vs =
} in Matrix.add_row_and_column self.matrix
~f:(fun ~row_idx ~col_idx ->
{
var=x; value;
idx=col_idx;
basic_idx=row_idx;
l_bound=None;
u_bound=None;
})
in
Log.debugf 5 (fun k->k "(@[simplex.define@ @[v%d :var %a@]@ :linexpr %a@])"
vs.idx Var.pp x Fmt.(Dump.(list @@ pair pp_q_dbg Var_state.pp_name)) le);
assert (Var_state.is_basic vs); assert (Var_state.is_basic vs);
assert Var_state.(Matrix.get_row_var self.matrix vs.basic_idx == vs);
Vec.push self.vars vs; Vec.push self.vars vs;
self.var_tbl <- V_map.add x vs self.var_tbl; self.var_tbl <- V_map.add x vs self.var_tbl;
(* set coefficients in the matrix's new row: [-x + le = 0] *) (* set coefficients in the matrix's new row: [-x + le = 0] *)
Matrix.set self.matrix n vs.idx Q.minus_one; Matrix.set self.matrix vs.basic_idx vs.idx Q.minus_one;
List.iter List.iter
(fun (coeff,v2) -> (fun (coeff,v2) ->
let vs2 = find_var_ self v2 in assert Containers.Stdlib.(v2 != vs);
(* FIXME: if [vs2] is basic, instead recurse with vs2's row. if Var_state.is_basic v2 then (
copy coefficients of [vs2]'s row but multiplied with [coeff] . (* [v2] is also basic, so instead of depending on it,
we depend on its definition in terms of non-basic variables. *)
See t4_short *) for k=0 to Matrix.n_cols self.matrix - 1 do
if k <> v2.idx then (
let v2_jk = Matrix.get self.matrix v2.basic_idx k in
if Q.(v2_jk <> zero) then (
let v_k = Vec.get self.vars k in
assert (Var_state.is_n_basic v_k);
Matrix.add self.matrix n vs2.idx coeff; (* [v2 := v2_jk * v_k + …], so [v := … + coeff * v2_jk * v_k] *)
vs.value <- Erat.(vs.value + coeff * vs2.value); (* update value of [v] *) Matrix.add self.matrix vs.basic_idx k Q.(coeff * v2_jk);
);
);
done;
) else (
(* directly add coefficient with non-basic var [v2] *)
Matrix.add self.matrix vs.basic_idx v2.idx coeff;
);
) le; ) le;
Log.debugf 50 (fun k->k "post-define: %a" pp self);
_check_invariants_internal self;
() ()
(* find the state for [x], or add [x] as a non-basic variable *) (* find the state for [x], or add [x] as a non-basic variable *)
@ -423,8 +502,9 @@ module Make(Var: VAR)
let update_n_basic (self:t) (x:var_state) (v:erat) : unit = let update_n_basic (self:t) (x:var_state) (v:erat) : unit =
assert (Var_state.is_n_basic x); assert (Var_state.is_n_basic x);
Log.debugf 50 Log.debugf 50
(fun k->k "(@[simplex.update-n-basic@ %a@ :new-val %a@])" (fun k->k "(@[<hv>simplex.update-n-basic@ %a@ :new-val %a@ :in %a@])"
Var_state.pp x Erat.pp v); Var_state.pp x Erat.pp v pp self);
_check_invariants_internal self;
let m = self.matrix in let m = self.matrix in
let i = x.idx in let i = x.idx in
@ -433,12 +513,15 @@ module Make(Var: VAR)
for j=0 to Matrix.n_rows m - 1 do for j=0 to Matrix.n_rows m - 1 do
let a_ji = Matrix.get m j i in let a_ji = Matrix.get m j i in
let x_j = Vec.get self.vars (Matrix.get_row_var_idx m j) in let x_j = Matrix.get_row_var m j in
assert (Var_state.is_basic x_j); assert (Var_state.is_basic x_j);
(* value of [x_j] by [a_ji * diff] *) (* value of [x_j] by [a_ji * diff] *)
x_j.value <- Erat.(x_j.value + a_ji * diff); let new_val = Erat.(x_j.value + a_ji * diff) in
Log.debugf 50 (fun k->k "new-val %a@ := %a" Var_state.pp x_j Erat.pp new_val);
x_j.value <- new_val;
done; done;
x.value <- v; x.value <- v;
_check_invariants_internal self;
() ()
(* pivot [x_i] (basic) and [x_j] (non-basic), setting value of [x_i] (* pivot [x_i] (basic) and [x_j] (non-basic), setting value of [x_i]
@ -447,6 +530,8 @@ module Make(Var: VAR)
let pivot_and_update (self:t) (x_i:var_state) (x_j:var_state) (v:erat) : unit = let pivot_and_update (self:t) (x_i:var_state) (x_j:var_state) (v:erat) : unit =
assert (Var_state.is_basic x_i); assert (Var_state.is_basic x_i);
assert (Var_state.is_n_basic x_j); assert (Var_state.is_n_basic x_j);
_check_invariants_internal self;
let m = self.matrix in let m = self.matrix in
let a_ij = Matrix.get m x_i.basic_idx x_j.idx in let a_ij = Matrix.get m x_i.basic_idx x_j.idx in
assert (Q.sign a_ij <> 0); assert (Q.sign a_ij <> 0);
@ -456,23 +541,59 @@ module Make(Var: VAR)
for k=0 to Matrix.n_rows m-1 do for k=0 to Matrix.n_rows m-1 do
if k <> x_i.basic_idx then ( if k <> x_i.basic_idx then (
let x_k = Vec.get self.vars (Matrix.get_row_var_idx m k) in let x_k = Matrix.get_row_var m k in
let a_kj = Matrix.get m x_k.basic_idx x_j.idx in let a_kj = Matrix.get m x_k.basic_idx x_j.idx in
x_k.value <- Erat.(x_k.value + a_kj * theta); x_k.value <- Erat.(x_k.value + a_kj * theta);
) )
done; done;
(* now pivot the variables so that [x_j]'s coeff is -1 *) begin
let new_coeff = Q.(minus_one / a_ij) in (* now pivot the variables so that [x_j]'s coeff is -1 and so that
for k=0 to Matrix.n_cols m-1 do other basic variables only depend on non-basic variables. *)
Matrix.mult m x_i.basic_idx k new_coeff; let new_coeff = Q.(minus_one / a_ij) in
done; for k=0 to Matrix.n_cols m-1 do
x_j.basic_idx <- x_i.basic_idx; Matrix.mult m x_i.basic_idx k new_coeff; (* update row of [x_i] *)
x_i.basic_idx <- -1; done;
Matrix.set_row_var_idx m x_j.basic_idx x_j.idx; assert Q.(Matrix.get m x_i.basic_idx x_j.idx = minus_one);
(* make [x_i] non basic, and [x_j] basic *)
x_j.basic_idx <- x_i.basic_idx;
x_i.basic_idx <- -1;
Matrix.set_row_var m x_j.basic_idx x_j;
(* adjust other rows so they don't depend on [x_j] *)
for k=0 to Matrix.n_rows m-1 do
if k <> x_j.basic_idx then (
let x_k = Matrix.get_row_var m k in
assert (Var_state.is_basic x_k);
let c_kj = Matrix.get m k x_j.idx in
if Q.(c_kj <> zero) then (
(* [m[k,j] != 0] indicates that basic variable [x_k] depends on
[x_j], which is about to become basic. To avoid basic-basic
dependency we replace [x_j] by its (new) definition *)
for l=0 to Matrix.n_cols m-1 do
if l<>x_j.idx then (
let c_jl = Matrix.get m x_j.basic_idx l in
(* so:
[x_k := c_kj * x_j + ], we want to eliminate [x_j],
and [x_j := + c_jl * x_l + ].
therefore [x_j := + c_jl * c_kl * x_l] *)
Matrix.add m k l Q.(c_kj * c_jl);
)
done;
Matrix.set m k x_j.idx Q.zero; (* [x_k] doesn't use [x_j] anymore *)
)
)
done;
end;
assert (Var_state.is_basic x_j); assert (Var_state.is_basic x_j);
assert (Var_state.is_n_basic x_i); assert (Var_state.is_n_basic x_i);
(* Log.debugf 50 (fun k->k "post pivot: %a" pp self); *)
_check_invariants_internal self;
() ()
@ -504,8 +625,12 @@ module Make(Var: VAR)
let add_constraint (self:t) (c:Constraint.t) (lit:V.lit) : unit = let add_constraint (self:t) (c:Constraint.t) (lit:V.lit) : unit =
let open Constraint in let open Constraint in
Log.debugf 5 (fun k->k "(@[simplex2.add-constraint@ %a@])" Constraint.pp c);
let vs = find_or_create_n_basic_var_ self c.lhs in let vs = find_or_create_n_basic_var_ self c.lhs in
Log.debugf 5
(fun k->k "(@[simplex2.add-constraint@ :var %a@ :c %a@])"
Var_state.pp_name vs Constraint.pp c);
let is_lower_bnd, new_bnd_val = let is_lower_bnd, new_bnd_val =
match c.op with match c.op with
| Leq -> false, Erat.make_q c.rhs | Leq -> false, Erat.make_q c.rhs
@ -561,7 +686,7 @@ module Make(Var: VAR)
let rec aux i = let rec aux i =
if i >= n then None if i >= n then None
else ( else (
let x_i = Vec.get self.vars (Matrix.get_row_var_idx self.matrix i) in let x_i = Matrix.get_row_var self.matrix i in
let v_i = x_i.value in let v_i = x_i.value in
match x_i.l_bound, x_i.u_bound with match x_i.l_bound, x_i.u_bound with
| Some l, _ when Erat.(l.b_val > v_i) -> Some (x_i, `Lower, l) | Some l, _ when Erat.(l.b_val > v_i) -> Some (x_i, `Lower, l)
@ -614,6 +739,7 @@ module Make(Var: VAR)
let m = self.matrix in let m = self.matrix in
try try
while true do while true do
_check_invariants_internal self;
Log.debugf 50 (fun k->k "(@[simplex2.check.iter@ %a@])" pp self); Log.debugf 50 (fun k->k "(@[simplex2.check.iter@ %a@])" pp self);
(* basic variable that doesn't respect its bound *) (* basic variable that doesn't respect its bound *)
@ -641,6 +767,9 @@ module Make(Var: VAR)
assert (Var_state.is_n_basic x_j); assert (Var_state.is_n_basic x_j);
(* line 9 *) (* line 9 *)
Log.debugf 50
(fun k->k "(@[simplex2.pivot@ :basic %a@ :n-basic %a@])"
Var_state.pp x_i Var_state.pp x_j);
pivot_and_update self x_i x_j bnd.b_val pivot_and_update self x_i x_j bnd.b_val
) else ( ) else (
(* line 10 *) (* line 10 *)
@ -660,6 +789,9 @@ module Make(Var: VAR)
assert (Var_state.is_n_basic x_j); assert (Var_state.is_n_basic x_j);
(* line 14 *) (* line 14 *)
Log.debugf 50
(fun k->k "(@[simplex2.pivot@ :basic %a@ :n-basic %a@])"
Var_state.pp x_i Var_state.pp x_j);
pivot_and_update self x_i x_j bnd.b_val pivot_and_update self x_i x_j bnd.b_val
) )
done; done;
@ -702,9 +834,9 @@ module Make(Var: VAR)
assert (Var_state.in_bounds_ x); assert (Var_state.in_bounds_ x);
let x_val = x.value in let x_val = x.value in
Log.debugf 50 (fun k->k "v.base=%a, v.eps=%a, emax=%a" (*Log.debugf 50 (fun k->k "(@[solve-eps v.base=%a, v.eps=%a, emax=%a@])"
pp_q_dbg x_val.base pp_q_dbg x_val.eps_factor pp_q_dbg x_val.base pp_q_dbg x_val.eps_factor
pp_q_dbg eps); pp_q_dbg eps);*)
(* is lower bound *) (* is lower bound *)
let eps = match x.l_bound with let eps = match x.l_bound with
@ -742,7 +874,6 @@ module Make(Var: VAR)
Q.((b_val.base - x_val.base) / Q.((b_val.base - x_val.base) /
(x_val.eps_factor - b_val.eps_factor)) (x_val.eps_factor - b_val.eps_factor))
in in
Log.debugf 5 (fun k->k "new max=%.5f" @@ Q.to_float new_eps);
Q.min eps new_eps Q.min eps new_eps
| _ -> eps | _ -> eps
in in
@ -828,15 +959,4 @@ module Make(Var: VAR)
end end
*) *)
let _check_invariants self : unit =
Vec.iteri (fun i v -> assert(v.idx = i)) self.vars;
let n = Vec.size self.vars in
assert (Matrix.n_rows self.matrix = 0 || Matrix.n_cols self.matrix = n);
for i = 0 to Matrix.n_rows self.matrix-1 do
let v = Vec.get self.vars (Matrix.get_row_var_idx self.matrix i) in
assert (Var_state.is_basic v);
assert (v.basic_idx = i);
done;
() (* TODO: more *)
end end