refactor(LRA): custom iterators in simplex, makes code more readable

This commit is contained in:
Simon Cruanes 2022-01-04 11:09:25 -05:00
parent 2d9f17b5b1
commit 2bce3e6dd9
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
2 changed files with 84 additions and 84 deletions

View file

@ -619,6 +619,8 @@ module Make(A : ARG) : S with module A = A = struct
~f:(fun (n1,n2) -> add_local_eq self si acts n1 n2);
Log.debug 5 "(th-lra: call arith solver)";
(* TODO: jiggle model to reduce the number of variables that
have the same value *)
let model = check_simplex_ self si acts in
Log.debugf 20 (fun k->k "(@[lra.model@ %a@])" SimpSolver.Subst.pp model);
Log.debug 5 "lra: solver returns SAT";

View file

@ -18,13 +18,13 @@ module Op = struct
| Geq
| Gt
let neg_sign = function
let[@inline] neg_sign = function
| Leq -> Geq
| Lt -> Gt
| Geq -> Leq
| Gt -> Lt
let not_ = function
let[@inline] not_ = function
| Leq -> Gt
| Lt -> Geq
| Geq -> Lt
@ -107,7 +107,8 @@ module type S = sig
@raise Unsat if it's immediately obvious that this is not satisfiable. *)
val declare_bound : t -> Constraint.t -> V.lit -> unit
(** Declare that this constraint exists, so we can possibly propagate it.
(** Declare that this constraint exists and map it to a literal,
so we can possibly propagate it later.
Unlike {!add_constraint} this does {b NOT} assert that the constraint
is true *)
@ -271,6 +272,8 @@ module Make(Q : RATIONAL)(Var: VAR)
val n_rows : t -> int
val n_cols : t -> int
val iter_rows : ?skip:int -> t -> (int -> var_state -> unit) -> unit
val iter_cols : ?skip:int -> t -> (int -> unit) -> unit
val add_column : t -> unit
(** Add a non-basic variable (only adds a column) *)
@ -350,6 +353,15 @@ module Make(Q : RATIONAL)(Var: VAR)
let r = Vec.get self.rows i in
Vec.set r.cols j n
let[@inline] iter_rows ?(skip= ~-1) (self:t) f : unit =
Vec.iteri (fun i row ->
if i<>skip then f i row.vs
) self.rows
let[@inline] iter_cols ?(skip= ~-1) (self:t) f : unit =
for i=0 to n_cols self-1 do
if i<>skip then f i
done
(* add [n] to [m_ij] *)
let add self i j n : unit =
let r = Vec.get self.rows i in
@ -434,17 +446,16 @@ module Make(Q : RATIONAL)(Var: VAR)
Vec.iteri (fun i v -> assert(v.idx = i)) self.vars;
let n = Vec.size self.vars in
assert (Matrix.n_rows self.matrix = 0 || Matrix.n_cols self.matrix = n);
for i = 0 to Matrix.n_rows self.matrix-1 do
let v = Matrix.get_row_var self.matrix i in
assert (Var_state.is_basic v);
assert (v.basic_idx = i);
assert Q.(Matrix.get self.matrix v.basic_idx v.idx = minus_one);
Matrix.iter_rows self.matrix begin fun i x_i ->
assert (Var_state.is_basic x_i);
assert (x_i.basic_idx = i);
assert Q.(Matrix.get self.matrix x_i.basic_idx x_i.idx = minus_one);
(* basic vars are only defined in terms of non-basic vars *)
Vec.iteri
(fun j v_j ->
if Var_state.(v != v_j) && Q.(Matrix.get self.matrix i j <> zero) then (
assert (Var_state.is_n_basic v_j)
(fun j x_j ->
if Var_state.(x_i != x_j) && Q.(Matrix.get self.matrix i j <> zero) then (
assert (Var_state.is_n_basic x_j)
))
self.vars;
@ -458,7 +469,7 @@ module Make(Q : RATIONAL)(Var: VAR)
Log.debugf 50 (fun k->k "row %d: sum %a" i Erat.pp sum);
assert Erat.(sum = zero);
done;
end;
()
(* for internal checking *)
@ -482,27 +493,25 @@ module Make(Q : RATIONAL)(Var: VAR)
let i_low = ref Erat.zero in
let i_up = ref Erat.zero in
for j=0 to Matrix.n_cols m-1 do
if j <> x_i.idx then (
let a_ij: Q.t = Matrix.get m x_i.basic_idx j in
let x_j = Vec.get self.vars j in
Matrix.iter_cols m ~skip:x_i.idx begin fun j ->
let a_ij: Q.t = Matrix.get m x_i.basic_idx j in
let x_j = Vec.get self.vars j in
let low_j = get_bnd_or Erat.minus_inf x_j.l_bound in
let up_j = get_bnd_or Erat.plus_inf x_j.u_bound in
let low_j = get_bnd_or Erat.minus_inf x_j.l_bound in
let up_j = get_bnd_or Erat.plus_inf x_j.u_bound in
if Q.(a_ij = zero) then()
else if Q.(a_ij > zero) then (
i_low := Erat.(!i_low + a_ij * low_j);
i_up := Erat.(!i_up + a_ij * up_j);
) else (
(* [a_ij < 0] and [x_j < up],
means [-a_ij > 0] and [-x_j > -up]
means [x_i = rest + a_ij x_j > rest + (-a_ij) * (-up)] *)
i_low := Erat.(!i_low + a_ij * up_j);
i_up := Erat.(!i_up + a_ij * low_j);
)
if Q.(a_ij = zero) then()
else if Q.(a_ij > zero) then (
i_low := Erat.(!i_low + a_ij * low_j);
i_up := Erat.(!i_up + a_ij * up_j);
) else (
(* [a_ij < 0] and [x_j < up],
means [-a_ij > 0] and [-x_j > -up]
means [x_i = rest + a_ij x_j > rest + (-a_ij) * (-up)] *)
i_low := Erat.(!i_low + a_ij * up_j);
i_up := Erat.(!i_up + a_ij * low_j);
)
done;
end;
let old_i_low = x_i.l_implied in
let old_i_up = x_i.u_implied in
@ -562,18 +571,16 @@ module Make(Q : RATIONAL)(Var: VAR)
(* [v2] is also basic, so instead of depending on it,
we depend on its definition in terms of non-basic variables. *)
for k=0 to Matrix.n_cols self.matrix - 1 do
if k <> v2.idx then (
let v2_jk = Matrix.get self.matrix v2.basic_idx k in
if Q.(v2_jk <> zero) then (
let v_k = Vec.get self.vars k in
assert (Var_state.is_n_basic v_k);
Matrix.iter_cols ~skip:v2.idx self.matrix begin fun k ->
let v2_jk = Matrix.get self.matrix v2.basic_idx k in
if Q.(v2_jk <> zero) then (
let x_k = Vec.get self.vars k in
assert (Var_state.is_n_basic x_k);
(* [v2 := v2_jk * v_k + …], so [v := … + coeff * v2_jk * v_k] *)
Matrix.add self.matrix vs.basic_idx k Q.(coeff * v2_jk);
);
(* [v2 := v2_jk * x_k + …], so [v := … + coeff * v2_jk * x_k] *)
Matrix.add self.matrix vs.basic_idx k Q.(coeff * v2_jk);
);
done;
end;
) else (
(* directly add coefficient with non-basic var [v2] *)
Matrix.add self.matrix vs.basic_idx v2.idx coeff;
@ -622,15 +629,14 @@ module Make(Q : RATIONAL)(Var: VAR)
let diff = Erat.(v - x.value) in
for j=0 to Matrix.n_rows m - 1 do
Matrix.iter_rows m begin fun j x_j ->
let a_ji = Matrix.get m j i in
let x_j = Matrix.get_row_var m j in
assert (Var_state.is_basic x_j);
(* value of [x_j] by [a_ji * diff] *)
let new_val = Erat.(x_j.value + a_ji * diff) in
(* Log.debugf 50 (fun k->k "new-val %a@ := %a" Var_state.pp x_j Erat.pp new_val); *)
x_j.value <- new_val;
done;
end;
x.value <- v;
_check_invariants_internal self;
()
@ -650,21 +656,18 @@ module Make(Q : RATIONAL)(Var: VAR)
x_i.value <- v;
x_j.value <- Erat.(x_j.value + theta);
for k=0 to Matrix.n_rows m-1 do
if k <> x_i.basic_idx then (
let x_k = Matrix.get_row_var m k in
let a_kj = Matrix.get m x_k.basic_idx x_j.idx in
x_k.value <- Erat.(x_k.value + a_kj * theta);
)
done;
Matrix.iter_rows m ~skip:x_i.basic_idx begin fun _k x_k ->
let a_kj = Matrix.get m x_k.basic_idx x_j.idx in
x_k.value <- Erat.(x_k.value + a_kj * theta);
end;
begin
(* now pivot the variables so that [x_j]'s coeff is -1 and so that
other basic variables only depend on non-basic variables. *)
let new_coeff = Q.(minus_one / a_ij) in
for k=0 to Matrix.n_cols m-1 do
Matrix.iter_cols m begin fun k ->
Matrix.mult m x_i.basic_idx k new_coeff; (* update row of [x_i] *)
done;
end;
assert Q.(Matrix.get m x_i.basic_idx x_j.idx = minus_one);
(* make [x_i] non basic, and [x_j] basic *)
@ -675,32 +678,29 @@ module Make(Q : RATIONAL)(Var: VAR)
Matrix.set_row_var m x_j.basic_idx x_j;
(* adjust other rows so they don't depend on [x_j] *)
for k=0 to Matrix.n_rows m-1 do
if k <> x_j.basic_idx then (
let x_k = Matrix.get_row_var m k in
assert (Var_state.is_basic x_k);
Matrix.iter_rows ~skip:x_j.basic_idx m begin fun k x_k ->
assert (Var_state.is_basic x_k);
let c_kj = Matrix.get m k x_j.idx in
if Q.(c_kj <> zero) then (
(* [m[k,j] != 0] indicates that basic variable [x_k] depends on
let c_kj = Matrix.get m k x_j.idx in
if Q.(c_kj <> zero) then (
(* [m[k,j] != 0] indicates that basic variable [x_k] depends on
[x_j], which is about to become basic. To avoid basic-basic
dependency we replace [x_j] by its (new) definition *)
for l=0 to Matrix.n_cols m-1 do
if l<>x_j.idx then (
let c_jl = Matrix.get m x_j.basic_idx l in
(* so:
Matrix.iter_cols m begin fun l ->
if l<>x_j.idx then (
let c_jl = Matrix.get m x_j.basic_idx l in
(* so:
[x_k := c_kj * x_j + ], we want to eliminate [x_j],
and [x_j := + c_jl * x_l + ].
therefore [x_j := + c_jl * c_kl * x_l] *)
Matrix.add m k l Q.(c_kj * c_jl);
)
done;
Matrix.add m k l Q.(c_kj * c_jl);
)
end;
Matrix.set m k x_j.idx Q.zero; (* [x_k] doesn't use [x_j] anymore *)
)
Matrix.set m k x_j.idx Q.zero; (* [x_k] doesn't use [x_j] anymore *)
)
done;
end;
update_implied_bounds_ self x_j;
end;
@ -852,13 +852,12 @@ module Make(Q : RATIONAL)(Var: VAR)
(self:t) (vs:var_state) ~on_propagate : unit =
(* section 3.2.5: update implied bounds on basic variables that
depend on [vs] *)
for i = 0 to Matrix.n_rows self.matrix -1 do
Matrix.iter_rows self.matrix begin fun i x_i ->
if Q.(Matrix.get self.matrix i vs.idx <> zero) then (
let x_i = Matrix.get_row_var self.matrix i in
update_implied_bounds_ self x_i;
propagate_basic_implied_bounds self ~on_propagate x_i;
);
done
)
end
let add_constraint ~on_propagate (self:t) (c:Constraint.t) (lit:V.lit) : unit =
let open Constraint in
@ -971,19 +970,18 @@ module Make(Q : RATIONAL)(Var: VAR)
(* try to find basic variable that doesn't respect one of its bounds *)
let find_basic_var_ (self:t) : (var_state * [`Lower | `Upper] * bound) option =
let n = Matrix.n_rows self.matrix in
let rec aux i =
if i >= n then None
else (
let x_i = Matrix.get_row_var self.matrix i in
let exception Found of var_state * [`Lower | `Upper] * bound in
try
Matrix.iter_rows self.matrix begin fun _i x_i ->
let v_i = x_i.value in
match x_i.l_bound, x_i.u_bound with
| Some l, _ when Erat.(l.b_val > v_i) -> Some (x_i, `Lower, l)
| _, Some u when Erat.(u.b_val < v_i) -> Some (x_i, `Upper, u)
| _ -> (aux[@tailcall]) (i+1)
)
in
aux 0
| Some l, _ when Erat.(l.b_val > v_i) -> raise_notrace (Found (x_i, `Lower, l))
| _, Some u when Erat.(u.b_val < v_i) -> raise_notrace (Found (x_i, `Upper, u))
| _ -> ()
end;
None
with Found (v,k,bnd) -> Some (v,k,bnd)
let find_n_basic_var_ (self:t) ~f : var_state option =
let rec aux j =