From 4d9f99e65d2d32946638abd67c5783849eb8ab98 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Mon, 15 Feb 2021 13:29:12 -0500 Subject: [PATCH] fix(simplex2): correct pivot; refactor; better printing --- src/arith/lra/simplex2.ml | 286 +++++++++++++++++++++++++++----------- 1 file changed, 203 insertions(+), 83 deletions(-) diff --git a/src/arith/lra/simplex2.ml b/src/arith/lra/simplex2.ml index 626bc5d7..fbd26f7a 100644 --- a/src/arith/lra/simplex2.ml +++ b/src/arith/lra/simplex2.ml @@ -182,6 +182,7 @@ module Make(Var: VAR) let[@inline] (<=) a b = compare a b <= 0 let[@inline] (>) a b = compare a b > 0 let[@inline] (>=) a b = compare a b >= 0 + let[@inline] (=) a b = compare a b = 0 let[@inline] min x y = if x <= y then x else y let[@inline] max x y = if x >= y then x else y @@ -196,7 +197,19 @@ module Make(Var: VAR) pp_q_dbg (base e) pp_q_dbg (eps_factor e) end - type var_idx = int + type bound = { + b_val: erat; + b_lit: Var.lit; + } + + type var_state = { + var: V.t; + mutable value: erat; + idx: int; (* index in {!t.vars} *) + mutable basic_idx: int; (* index of the row in the matrix, if any. -1 otherwise *) + mutable l_bound: bound option; + mutable u_bound: bound option; + } (** {2 Matrix} The matrix [A] from the paper, with m rows and n columns. @@ -221,14 +234,14 @@ module Make(Var: VAR) val add_column : t -> unit (** Add a non-basic variable (only adds a column) *) - val add_row_and_column : t -> int - (** Add a basic variable. returns the row index. *) + val add_row_and_column : t -> f:(row_idx:int -> col_idx:int -> var_state) -> var_state + (** Add a basic variable. *) - val get_row_var_idx : t -> int -> var_idx - (** Index of the basic variable for row [i] *) + val get_row_var : t -> int -> var_state + (** The basic variable for row [i] *) - val set_row_var_idx : t -> int -> var_idx -> unit - (** Set index of the basic variable for row [i] *) + val set_row_var : t -> int -> var_state -> unit + (** Set basic variable for row [i] *) val get : t -> int -> int -> num @@ -239,7 +252,7 @@ module Make(Var: VAR) val mult : t -> int -> int -> num -> unit end = struct type row = { - mutable var_idx: var_idx; + mutable vs: var_state; cols: num Vec.t; } type t = { @@ -254,30 +267,41 @@ module Make(Var: VAR) let[@inline] n_cols self = self.n_cols let pp out self = - Fmt.fprintf out "{@[matrix[%dx%d]@," (n_rows self) (n_cols self); + Fmt.fprintf out "@[{matrix[%dx%d]@," (n_rows self) (n_cols self); + + (* header *) + let ppi out i = + Fmt.string out @@ CCString.pad ~side:`Left 6 @@ Printf.sprintf "v%d" i in + Fmt.fprintf out "{@[%9s: %a@]}" "vars" + (Fmt.iter ~sep:(Fmt.return "@ ") ppi) CCInt.(0 -- (n_cols self-1)); + Vec.iteri (fun i row -> - Fmt.fprintf out "{@[r%-3d: %a@]}@," i + let hd = + CCString.pad ~side:`Left 6 @@ + Printf.sprintf "r%d (v%d)" i row.vs.idx in + Fmt.fprintf out "@,{@[%9s: %a@]}" hd (Fmt.iter ~sep:(Fmt.return "@ ") (pp_q_float 6)) (Vec.to_seq row.cols)) self.rows; - Fmt.fprintf out "@]}" + Fmt.fprintf out "@;<0 -1>}@]" let to_string = Fmt.to_string pp let add_column self = self.n_cols <- 1 + self.n_cols; Vec.iter (fun r -> Vec.push r.cols Q.zero) self.rows - let add_row_and_column self : int = + let add_row_and_column self ~f : var_state = let n = n_rows self in let j = n_cols self in add_column self; let cols = Vec.make (j+1) Q.zero in for _k=0 to j do Vec.push cols Q.zero done; - let row = {var_idx=j; cols} in + let vs = f ~row_idx:n ~col_idx:j in + let row = {vs; cols} in Vec.push self.rows row; - n + vs - let[@inline] get_row_var_idx self i = (Vec.get self.rows i).var_idx - let[@inline] set_row_var_idx self i n = (Vec.get self.rows i).var_idx <- n + let[@inline] get_row_var self i = (Vec.get self.rows i).vs + let[@inline] set_row_var self i v = (Vec.get self.rows i).vs <- v let[@inline] get self i j : num = Vec.get (Vec.get self.rows i).cols j @@ -294,27 +318,15 @@ module Make(Var: VAR) let mult self i j c : unit = let r = Vec.get self.rows i in let n_j = Vec.get r.cols j in - if Q.sign n_j <> 0 then ( + if Q.(n_j <> zero) then ( Vec.set r.cols j Q.(n_j * c) ) end - type bound = { - b_val: erat; - b_lit: Var.lit; - } - - type var_state = { - var: V.t; - mutable value: erat; - idx: int; (* index in {!t.vars} *) - mutable basic_idx: int; (* index of the row in the matrix, if any. -1 otherwise *) - mutable l_bound: bound option; - mutable u_bound: bound option; - } - module Var_state = struct type t = var_state + let (==) : t -> t -> bool = Containers.Stdlib.(==) + let (!=) : t -> t -> bool = Containers.Stdlib.(!=) let[@inline] is_basic (self:t) : bool = self.basic_idx >= 0 let[@inline] is_n_basic (self:t) : bool = self.basic_idx < 0 @@ -323,8 +335,9 @@ module Make(Var: VAR) (match self.l_bound with None -> true | Some b -> Erat.(self.value >= b.b_val)) && (match self.u_bound with None -> true | Some b -> Erat.(self.value <= b.b_val)) + let pp_name out self = Fmt.fprintf out "v%d" self.idx let pp out self = - let bnd_status = if in_bounds_ self then "" else "(out of bounds)" in + let bnd_status = if in_bounds_ self then "" else "(oob)" in let pp_bnd what out = function | None -> () | Some b -> Fmt.fprintf out "@ @[%s %a@]" what Erat.pp b.b_val @@ -332,10 +345,9 @@ module Make(Var: VAR) if self.basic_idx < 0 then () else Fmt.int out self.basic_idx in Fmt.fprintf out - "(@[var[%s%a]%s %a@ :value %a%a%a@])" - (if is_basic self then "B" else "N") pp_basic_idx () - bnd_status - Var.pp self.var Erat.pp self.value + "(@[v%d[%s%a]%s@ :value %a%a%a@])" + self.idx (if is_basic self then "B" else "N") pp_basic_idx () + bnd_status Erat.pp self.value (pp_bnd ":lower") self.l_bound (pp_bnd ":upper") self.u_bound end @@ -362,41 +374,108 @@ module Make(Var: VAR) (Vec.pp Var_state.pp) self.vars Matrix.pp self.matrix + (* for debug purposes *) + let _check_invariants self : unit = + Vec.iteri (fun i v -> assert(v.idx = i)) self.vars; + let n = Vec.size self.vars in + assert (Matrix.n_rows self.matrix = 0 || Matrix.n_cols self.matrix = n); + for i = 0 to Matrix.n_rows self.matrix-1 do + let v = Matrix.get_row_var self.matrix i in + assert (Var_state.is_basic v); + assert (v.basic_idx = i); + assert Q.(Matrix.get self.matrix v.basic_idx v.idx = minus_one); + + (* basic vars are only defined in terms of non-basic vars *) + Vec.iteri + (fun j v_j -> + if Var_state.(v != v_j) && Q.(Matrix.get self.matrix i j <> zero) then ( + assert (Var_state.is_n_basic v_j) + )) + self.vars; + + (* sum of each row must be 0 *) + let sum = + Vec.fold + (fun sum v -> + Erat.(sum + Matrix.get self.matrix i v.idx * v.value)) + Erat.zero self.vars + in + Log.debugf 50 (fun k->k "row %d: sum %a" i Erat.pp sum); + assert Erat.(sum = zero); + + done; + () (* TODO: more *) + + (* for internal checking *) + let _check_invariants_internal self = + if false (* FUDGE *) then _check_invariants self + let[@inline] has_var_ (self:t) x : bool = V_map.mem x self.var_tbl let[@inline] find_var_ (self:t) x : var_state = try V_map.find x self.var_tbl with Not_found -> Error.errorf "variable is not in the simplex" Var.pp x + (* define [x] as a basic variable *) let define (self:t) (x:V.t) (le:_ list) : unit = assert (not (has_var_ self x)); - Log.debugf 5 (fun k->k "(@[simplex.define@ %a@ :le %a@])" - Var.pp x Fmt.(Dump.(list @@ pair pp_q_dbg Var.pp)) le); - let idx = Vec.size self.vars in - let n = Matrix.add_row_and_column self.matrix in - let vs = { - var=x; value=Erat.zero; - idx; - basic_idx=n; - l_bound=None; - u_bound=None; - } in + (* Log.debugf 50 (fun k->k "define-in: %a" pp self); *) + let le = List.map (fun (q,v) -> q, find_var_ self v) le in + + (* initial value for the new variable *) + let value = + List.fold_left + (fun sum (c,v) -> Erat.(sum + c * v.value)) Erat.zero le + in + + let vs = + Matrix.add_row_and_column self.matrix + ~f:(fun ~row_idx ~col_idx -> + { + var=x; value; + idx=col_idx; + basic_idx=row_idx; + l_bound=None; + u_bound=None; + }) + in + Log.debugf 5 (fun k->k "(@[simplex.define@ @[v%d :var %a@]@ :linexpr %a@])" + vs.idx Var.pp x Fmt.(Dump.(list @@ pair pp_q_dbg Var_state.pp_name)) le); + assert (Var_state.is_basic vs); + assert Var_state.(Matrix.get_row_var self.matrix vs.basic_idx == vs); Vec.push self.vars vs; self.var_tbl <- V_map.add x vs self.var_tbl; + (* set coefficients in the matrix's new row: [-x + le = 0] *) - Matrix.set self.matrix n vs.idx Q.minus_one; + Matrix.set self.matrix vs.basic_idx vs.idx Q.minus_one; List.iter (fun (coeff,v2) -> - let vs2 = find_var_ self v2 in + assert Containers.Stdlib.(v2 != vs); - (* FIXME: if [vs2] is basic, instead recurse with vs2's row. - copy coefficients of [vs2]'s row but multiplied with [coeff] . + if Var_state.is_basic v2 then ( + (* [v2] is also basic, so instead of depending on it, + we depend on its definition in terms of non-basic variables. *) - See t4_short *) + for k=0 to Matrix.n_cols self.matrix - 1 do + if k <> v2.idx then ( + let v2_jk = Matrix.get self.matrix v2.basic_idx k in + if Q.(v2_jk <> zero) then ( + let v_k = Vec.get self.vars k in + assert (Var_state.is_n_basic v_k); - Matrix.add self.matrix n vs2.idx coeff; - vs.value <- Erat.(vs.value + coeff * vs2.value); (* update value of [v] *) + (* [v2 := v2_jk * v_k + …], so [v := … + coeff * v2_jk * v_k] *) + Matrix.add self.matrix vs.basic_idx k Q.(coeff * v2_jk); + ); + ); + done; + ) else ( + (* directly add coefficient with non-basic var [v2] *) + Matrix.add self.matrix vs.basic_idx v2.idx coeff; + ); ) le; + + Log.debugf 50 (fun k->k "post-define: %a" pp self); + _check_invariants_internal self; () (* find the state for [x], or add [x] as a non-basic variable *) @@ -423,8 +502,9 @@ module Make(Var: VAR) let update_n_basic (self:t) (x:var_state) (v:erat) : unit = assert (Var_state.is_n_basic x); Log.debugf 50 - (fun k->k "(@[simplex.update-n-basic@ %a@ :new-val %a@])" - Var_state.pp x Erat.pp v); + (fun k->k "(@[simplex.update-n-basic@ %a@ :new-val %a@ :in %a@])" + Var_state.pp x Erat.pp v pp self); + _check_invariants_internal self; let m = self.matrix in let i = x.idx in @@ -433,12 +513,15 @@ module Make(Var: VAR) for j=0 to Matrix.n_rows m - 1 do let a_ji = Matrix.get m j i in - let x_j = Vec.get self.vars (Matrix.get_row_var_idx m j) in + let x_j = Matrix.get_row_var m j in assert (Var_state.is_basic x_j); (* value of [x_j] by [a_ji * diff] *) - x_j.value <- Erat.(x_j.value + a_ji * diff); + let new_val = Erat.(x_j.value + a_ji * diff) in + Log.debugf 50 (fun k->k "new-val %a@ := %a" Var_state.pp x_j Erat.pp new_val); + x_j.value <- new_val; done; x.value <- v; + _check_invariants_internal self; () (* pivot [x_i] (basic) and [x_j] (non-basic), setting value of [x_i] @@ -447,6 +530,8 @@ module Make(Var: VAR) let pivot_and_update (self:t) (x_i:var_state) (x_j:var_state) (v:erat) : unit = assert (Var_state.is_basic x_i); assert (Var_state.is_n_basic x_j); + _check_invariants_internal self; + let m = self.matrix in let a_ij = Matrix.get m x_i.basic_idx x_j.idx in assert (Q.sign a_ij <> 0); @@ -456,23 +541,59 @@ module Make(Var: VAR) for k=0 to Matrix.n_rows m-1 do if k <> x_i.basic_idx then ( - let x_k = Vec.get self.vars (Matrix.get_row_var_idx m k) in + let x_k = Matrix.get_row_var m k in let a_kj = Matrix.get m x_k.basic_idx x_j.idx in x_k.value <- Erat.(x_k.value + a_kj * theta); ) done; - (* now pivot the variables so that [x_j]'s coeff is -1 *) - let new_coeff = Q.(minus_one / a_ij) in - for k=0 to Matrix.n_cols m-1 do - Matrix.mult m x_i.basic_idx k new_coeff; - done; - x_j.basic_idx <- x_i.basic_idx; - x_i.basic_idx <- -1; - Matrix.set_row_var_idx m x_j.basic_idx x_j.idx; + begin + (* now pivot the variables so that [x_j]'s coeff is -1 and so that + other basic variables only depend on non-basic variables. *) + let new_coeff = Q.(minus_one / a_ij) in + for k=0 to Matrix.n_cols m-1 do + Matrix.mult m x_i.basic_idx k new_coeff; (* update row of [x_i] *) + done; + assert Q.(Matrix.get m x_i.basic_idx x_j.idx = minus_one); + + (* make [x_i] non basic, and [x_j] basic *) + x_j.basic_idx <- x_i.basic_idx; + x_i.basic_idx <- -1; + Matrix.set_row_var m x_j.basic_idx x_j; + + (* adjust other rows so they don't depend on [x_j] *) + for k=0 to Matrix.n_rows m-1 do + if k <> x_j.basic_idx then ( + let x_k = Matrix.get_row_var m k in + assert (Var_state.is_basic x_k); + + let c_kj = Matrix.get m k x_j.idx in + if Q.(c_kj <> zero) then ( + (* [m[k,j] != 0] indicates that basic variable [x_k] depends on + [x_j], which is about to become basic. To avoid basic-basic + dependency we replace [x_j] by its (new) definition *) + + for l=0 to Matrix.n_cols m-1 do + if l<>x_j.idx then ( + let c_jl = Matrix.get m x_j.basic_idx l in + (* so: + [x_k := c_kj * x_j + …], we want to eliminate [x_j], + and [x_j := … + c_jl * x_l + …]. + therefore [x_j := … + c_jl * c_kl * x_l] *) + Matrix.add m k l Q.(c_kj * c_jl); + ) + done; + + Matrix.set m k x_j.idx Q.zero; (* [x_k] doesn't use [x_j] anymore *) + ) + ) + done; + end; assert (Var_state.is_basic x_j); assert (Var_state.is_n_basic x_i); + (* Log.debugf 50 (fun k->k "post pivot: %a" pp self); *) + _check_invariants_internal self; () @@ -504,8 +625,12 @@ module Make(Var: VAR) let add_constraint (self:t) (c:Constraint.t) (lit:V.lit) : unit = let open Constraint in - Log.debugf 5 (fun k->k "(@[simplex2.add-constraint@ %a@])" Constraint.pp c); + let vs = find_or_create_n_basic_var_ self c.lhs in + Log.debugf 5 + (fun k->k "(@[simplex2.add-constraint@ :var %a@ :c %a@])" + Var_state.pp_name vs Constraint.pp c); + let is_lower_bnd, new_bnd_val = match c.op with | Leq -> false, Erat.make_q c.rhs @@ -561,7 +686,7 @@ module Make(Var: VAR) let rec aux i = if i >= n then None else ( - let x_i = Vec.get self.vars (Matrix.get_row_var_idx self.matrix i) in + let x_i = Matrix.get_row_var self.matrix i in let v_i = x_i.value in match x_i.l_bound, x_i.u_bound with | Some l, _ when Erat.(l.b_val > v_i) -> Some (x_i, `Lower, l) @@ -614,6 +739,7 @@ module Make(Var: VAR) let m = self.matrix in try while true do + _check_invariants_internal self; Log.debugf 50 (fun k->k "(@[simplex2.check.iter@ %a@])" pp self); (* basic variable that doesn't respect its bound *) @@ -641,6 +767,9 @@ module Make(Var: VAR) assert (Var_state.is_n_basic x_j); (* line 9 *) + Log.debugf 50 + (fun k->k "(@[simplex2.pivot@ :basic %a@ :n-basic %a@])" + Var_state.pp x_i Var_state.pp x_j); pivot_and_update self x_i x_j bnd.b_val ) else ( (* line 10 *) @@ -660,6 +789,9 @@ module Make(Var: VAR) assert (Var_state.is_n_basic x_j); (* line 14 *) + Log.debugf 50 + (fun k->k "(@[simplex2.pivot@ :basic %a@ :n-basic %a@])" + Var_state.pp x_i Var_state.pp x_j); pivot_and_update self x_i x_j bnd.b_val ) done; @@ -702,9 +834,9 @@ module Make(Var: VAR) assert (Var_state.in_bounds_ x); let x_val = x.value in - Log.debugf 50 (fun k->k "v.base=%a, v.eps=%a, emax=%a" + (*Log.debugf 50 (fun k->k "(@[solve-eps v.base=%a, v.eps=%a, emax=%a@])" pp_q_dbg x_val.base pp_q_dbg x_val.eps_factor - pp_q_dbg eps); + pp_q_dbg eps);*) (* is lower bound *) let eps = match x.l_bound with @@ -742,7 +874,6 @@ module Make(Var: VAR) Q.((b_val.base - x_val.base) / (x_val.eps_factor - b_val.eps_factor)) in - Log.debugf 5 (fun k->k "new max=%.5f" @@ Q.to_float new_eps); Q.min eps new_eps | _ -> eps in @@ -828,15 +959,4 @@ module Make(Var: VAR) end *) - let _check_invariants self : unit = - Vec.iteri (fun i v -> assert(v.idx = i)) self.vars; - let n = Vec.size self.vars in - assert (Matrix.n_rows self.matrix = 0 || Matrix.n_cols self.matrix = n); - for i = 0 to Matrix.n_rows self.matrix-1 do - let v = Vec.get self.vars (Matrix.get_row_var_idx self.matrix i) in - assert (Var_state.is_basic v); - assert (v.basic_idx = i); - done; - () (* TODO: more *) - end