mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-06 11:15:43 -05:00
refactor(LRA): custom iterators in simplex, makes code more readable
This commit is contained in:
parent
2d9f17b5b1
commit
2bce3e6dd9
2 changed files with 84 additions and 84 deletions
|
|
@ -619,6 +619,8 @@ module Make(A : ARG) : S with module A = A = struct
|
||||||
~f:(fun (n1,n2) -> add_local_eq self si acts n1 n2);
|
~f:(fun (n1,n2) -> add_local_eq self si acts n1 n2);
|
||||||
|
|
||||||
Log.debug 5 "(th-lra: call arith solver)";
|
Log.debug 5 "(th-lra: call arith solver)";
|
||||||
|
(* TODO: jiggle model to reduce the number of variables that
|
||||||
|
have the same value *)
|
||||||
let model = check_simplex_ self si acts in
|
let model = check_simplex_ self si acts in
|
||||||
Log.debugf 20 (fun k->k "(@[lra.model@ %a@])" SimpSolver.Subst.pp model);
|
Log.debugf 20 (fun k->k "(@[lra.model@ %a@])" SimpSolver.Subst.pp model);
|
||||||
Log.debug 5 "lra: solver returns SAT";
|
Log.debug 5 "lra: solver returns SAT";
|
||||||
|
|
|
||||||
|
|
@ -18,13 +18,13 @@ module Op = struct
|
||||||
| Geq
|
| Geq
|
||||||
| Gt
|
| Gt
|
||||||
|
|
||||||
let neg_sign = function
|
let[@inline] neg_sign = function
|
||||||
| Leq -> Geq
|
| Leq -> Geq
|
||||||
| Lt -> Gt
|
| Lt -> Gt
|
||||||
| Geq -> Leq
|
| Geq -> Leq
|
||||||
| Gt -> Lt
|
| Gt -> Lt
|
||||||
|
|
||||||
let not_ = function
|
let[@inline] not_ = function
|
||||||
| Leq -> Gt
|
| Leq -> Gt
|
||||||
| Lt -> Geq
|
| Lt -> Geq
|
||||||
| Geq -> Lt
|
| Geq -> Lt
|
||||||
|
|
@ -107,7 +107,8 @@ module type S = sig
|
||||||
@raise Unsat if it's immediately obvious that this is not satisfiable. *)
|
@raise Unsat if it's immediately obvious that this is not satisfiable. *)
|
||||||
|
|
||||||
val declare_bound : t -> Constraint.t -> V.lit -> unit
|
val declare_bound : t -> Constraint.t -> V.lit -> unit
|
||||||
(** Declare that this constraint exists, so we can possibly propagate it.
|
(** Declare that this constraint exists and map it to a literal,
|
||||||
|
so we can possibly propagate it later.
|
||||||
Unlike {!add_constraint} this does {b NOT} assert that the constraint
|
Unlike {!add_constraint} this does {b NOT} assert that the constraint
|
||||||
is true *)
|
is true *)
|
||||||
|
|
||||||
|
|
@ -271,6 +272,8 @@ module Make(Q : RATIONAL)(Var: VAR)
|
||||||
|
|
||||||
val n_rows : t -> int
|
val n_rows : t -> int
|
||||||
val n_cols : t -> int
|
val n_cols : t -> int
|
||||||
|
val iter_rows : ?skip:int -> t -> (int -> var_state -> unit) -> unit
|
||||||
|
val iter_cols : ?skip:int -> t -> (int -> unit) -> unit
|
||||||
|
|
||||||
val add_column : t -> unit
|
val add_column : t -> unit
|
||||||
(** Add a non-basic variable (only adds a column) *)
|
(** Add a non-basic variable (only adds a column) *)
|
||||||
|
|
@ -350,6 +353,15 @@ module Make(Q : RATIONAL)(Var: VAR)
|
||||||
let r = Vec.get self.rows i in
|
let r = Vec.get self.rows i in
|
||||||
Vec.set r.cols j n
|
Vec.set r.cols j n
|
||||||
|
|
||||||
|
let[@inline] iter_rows ?(skip= ~-1) (self:t) f : unit =
|
||||||
|
Vec.iteri (fun i row ->
|
||||||
|
if i<>skip then f i row.vs
|
||||||
|
) self.rows
|
||||||
|
let[@inline] iter_cols ?(skip= ~-1) (self:t) f : unit =
|
||||||
|
for i=0 to n_cols self-1 do
|
||||||
|
if i<>skip then f i
|
||||||
|
done
|
||||||
|
|
||||||
(* add [n] to [m_ij] *)
|
(* add [n] to [m_ij] *)
|
||||||
let add self i j n : unit =
|
let add self i j n : unit =
|
||||||
let r = Vec.get self.rows i in
|
let r = Vec.get self.rows i in
|
||||||
|
|
@ -434,17 +446,16 @@ module Make(Q : RATIONAL)(Var: VAR)
|
||||||
Vec.iteri (fun i v -> assert(v.idx = i)) self.vars;
|
Vec.iteri (fun i v -> assert(v.idx = i)) self.vars;
|
||||||
let n = Vec.size self.vars in
|
let n = Vec.size self.vars in
|
||||||
assert (Matrix.n_rows self.matrix = 0 || Matrix.n_cols self.matrix = n);
|
assert (Matrix.n_rows self.matrix = 0 || Matrix.n_cols self.matrix = n);
|
||||||
for i = 0 to Matrix.n_rows self.matrix-1 do
|
Matrix.iter_rows self.matrix begin fun i x_i ->
|
||||||
let v = Matrix.get_row_var self.matrix i in
|
assert (Var_state.is_basic x_i);
|
||||||
assert (Var_state.is_basic v);
|
assert (x_i.basic_idx = i);
|
||||||
assert (v.basic_idx = i);
|
assert Q.(Matrix.get self.matrix x_i.basic_idx x_i.idx = minus_one);
|
||||||
assert Q.(Matrix.get self.matrix v.basic_idx v.idx = minus_one);
|
|
||||||
|
|
||||||
(* basic vars are only defined in terms of non-basic vars *)
|
(* basic vars are only defined in terms of non-basic vars *)
|
||||||
Vec.iteri
|
Vec.iteri
|
||||||
(fun j v_j ->
|
(fun j x_j ->
|
||||||
if Var_state.(v != v_j) && Q.(Matrix.get self.matrix i j <> zero) then (
|
if Var_state.(x_i != x_j) && Q.(Matrix.get self.matrix i j <> zero) then (
|
||||||
assert (Var_state.is_n_basic v_j)
|
assert (Var_state.is_n_basic x_j)
|
||||||
))
|
))
|
||||||
self.vars;
|
self.vars;
|
||||||
|
|
||||||
|
|
@ -458,7 +469,7 @@ module Make(Q : RATIONAL)(Var: VAR)
|
||||||
Log.debugf 50 (fun k->k "row %d: sum %a" i Erat.pp sum);
|
Log.debugf 50 (fun k->k "row %d: sum %a" i Erat.pp sum);
|
||||||
assert Erat.(sum = zero);
|
assert Erat.(sum = zero);
|
||||||
|
|
||||||
done;
|
end;
|
||||||
()
|
()
|
||||||
|
|
||||||
(* for internal checking *)
|
(* for internal checking *)
|
||||||
|
|
@ -482,27 +493,25 @@ module Make(Q : RATIONAL)(Var: VAR)
|
||||||
let i_low = ref Erat.zero in
|
let i_low = ref Erat.zero in
|
||||||
let i_up = ref Erat.zero in
|
let i_up = ref Erat.zero in
|
||||||
|
|
||||||
for j=0 to Matrix.n_cols m-1 do
|
Matrix.iter_cols m ~skip:x_i.idx begin fun j ->
|
||||||
if j <> x_i.idx then (
|
let a_ij: Q.t = Matrix.get m x_i.basic_idx j in
|
||||||
let a_ij: Q.t = Matrix.get m x_i.basic_idx j in
|
let x_j = Vec.get self.vars j in
|
||||||
let x_j = Vec.get self.vars j in
|
|
||||||
|
|
||||||
let low_j = get_bnd_or Erat.minus_inf x_j.l_bound in
|
let low_j = get_bnd_or Erat.minus_inf x_j.l_bound in
|
||||||
let up_j = get_bnd_or Erat.plus_inf x_j.u_bound in
|
let up_j = get_bnd_or Erat.plus_inf x_j.u_bound in
|
||||||
|
|
||||||
if Q.(a_ij = zero) then()
|
if Q.(a_ij = zero) then()
|
||||||
else if Q.(a_ij > zero) then (
|
else if Q.(a_ij > zero) then (
|
||||||
i_low := Erat.(!i_low + a_ij * low_j);
|
i_low := Erat.(!i_low + a_ij * low_j);
|
||||||
i_up := Erat.(!i_up + a_ij * up_j);
|
i_up := Erat.(!i_up + a_ij * up_j);
|
||||||
) else (
|
) else (
|
||||||
(* [a_ij < 0] and [x_j < up],
|
(* [a_ij < 0] and [x_j < up],
|
||||||
means [-a_ij > 0] and [-x_j > -up]
|
means [-a_ij > 0] and [-x_j > -up]
|
||||||
means [x_i = rest + a_ij x_j > rest + (-a_ij) * (-up)] *)
|
means [x_i = rest + a_ij x_j > rest + (-a_ij) * (-up)] *)
|
||||||
i_low := Erat.(!i_low + a_ij * up_j);
|
i_low := Erat.(!i_low + a_ij * up_j);
|
||||||
i_up := Erat.(!i_up + a_ij * low_j);
|
i_up := Erat.(!i_up + a_ij * low_j);
|
||||||
)
|
|
||||||
)
|
)
|
||||||
done;
|
end;
|
||||||
|
|
||||||
let old_i_low = x_i.l_implied in
|
let old_i_low = x_i.l_implied in
|
||||||
let old_i_up = x_i.u_implied in
|
let old_i_up = x_i.u_implied in
|
||||||
|
|
@ -562,18 +571,16 @@ module Make(Q : RATIONAL)(Var: VAR)
|
||||||
(* [v2] is also basic, so instead of depending on it,
|
(* [v2] is also basic, so instead of depending on it,
|
||||||
we depend on its definition in terms of non-basic variables. *)
|
we depend on its definition in terms of non-basic variables. *)
|
||||||
|
|
||||||
for k=0 to Matrix.n_cols self.matrix - 1 do
|
Matrix.iter_cols ~skip:v2.idx self.matrix begin fun k ->
|
||||||
if k <> v2.idx then (
|
let v2_jk = Matrix.get self.matrix v2.basic_idx k in
|
||||||
let v2_jk = Matrix.get self.matrix v2.basic_idx k in
|
if Q.(v2_jk <> zero) then (
|
||||||
if Q.(v2_jk <> zero) then (
|
let x_k = Vec.get self.vars k in
|
||||||
let v_k = Vec.get self.vars k in
|
assert (Var_state.is_n_basic x_k);
|
||||||
assert (Var_state.is_n_basic v_k);
|
|
||||||
|
|
||||||
(* [v2 := v2_jk * v_k + …], so [v := … + coeff * v2_jk * v_k] *)
|
(* [v2 := v2_jk * x_k + …], so [v := … + coeff * v2_jk * x_k] *)
|
||||||
Matrix.add self.matrix vs.basic_idx k Q.(coeff * v2_jk);
|
Matrix.add self.matrix vs.basic_idx k Q.(coeff * v2_jk);
|
||||||
);
|
|
||||||
);
|
);
|
||||||
done;
|
end;
|
||||||
) else (
|
) else (
|
||||||
(* directly add coefficient with non-basic var [v2] *)
|
(* directly add coefficient with non-basic var [v2] *)
|
||||||
Matrix.add self.matrix vs.basic_idx v2.idx coeff;
|
Matrix.add self.matrix vs.basic_idx v2.idx coeff;
|
||||||
|
|
@ -622,15 +629,14 @@ module Make(Q : RATIONAL)(Var: VAR)
|
||||||
|
|
||||||
let diff = Erat.(v - x.value) in
|
let diff = Erat.(v - x.value) in
|
||||||
|
|
||||||
for j=0 to Matrix.n_rows m - 1 do
|
Matrix.iter_rows m begin fun j x_j ->
|
||||||
let a_ji = Matrix.get m j i in
|
let a_ji = Matrix.get m j i in
|
||||||
let x_j = Matrix.get_row_var m j in
|
|
||||||
assert (Var_state.is_basic x_j);
|
assert (Var_state.is_basic x_j);
|
||||||
(* value of [x_j] by [a_ji * diff] *)
|
(* value of [x_j] by [a_ji * diff] *)
|
||||||
let new_val = Erat.(x_j.value + a_ji * diff) in
|
let new_val = Erat.(x_j.value + a_ji * diff) in
|
||||||
(* Log.debugf 50 (fun k->k "new-val %a@ := %a" Var_state.pp x_j Erat.pp new_val); *)
|
(* Log.debugf 50 (fun k->k "new-val %a@ := %a" Var_state.pp x_j Erat.pp new_val); *)
|
||||||
x_j.value <- new_val;
|
x_j.value <- new_val;
|
||||||
done;
|
end;
|
||||||
x.value <- v;
|
x.value <- v;
|
||||||
_check_invariants_internal self;
|
_check_invariants_internal self;
|
||||||
()
|
()
|
||||||
|
|
@ -650,21 +656,18 @@ module Make(Q : RATIONAL)(Var: VAR)
|
||||||
x_i.value <- v;
|
x_i.value <- v;
|
||||||
x_j.value <- Erat.(x_j.value + theta);
|
x_j.value <- Erat.(x_j.value + theta);
|
||||||
|
|
||||||
for k=0 to Matrix.n_rows m-1 do
|
Matrix.iter_rows m ~skip:x_i.basic_idx begin fun _k x_k ->
|
||||||
if k <> x_i.basic_idx then (
|
let a_kj = Matrix.get m x_k.basic_idx x_j.idx in
|
||||||
let x_k = Matrix.get_row_var m k in
|
x_k.value <- Erat.(x_k.value + a_kj * theta);
|
||||||
let a_kj = Matrix.get m x_k.basic_idx x_j.idx in
|
end;
|
||||||
x_k.value <- Erat.(x_k.value + a_kj * theta);
|
|
||||||
)
|
|
||||||
done;
|
|
||||||
|
|
||||||
begin
|
begin
|
||||||
(* now pivot the variables so that [x_j]'s coeff is -1 and so that
|
(* now pivot the variables so that [x_j]'s coeff is -1 and so that
|
||||||
other basic variables only depend on non-basic variables. *)
|
other basic variables only depend on non-basic variables. *)
|
||||||
let new_coeff = Q.(minus_one / a_ij) in
|
let new_coeff = Q.(minus_one / a_ij) in
|
||||||
for k=0 to Matrix.n_cols m-1 do
|
Matrix.iter_cols m begin fun k ->
|
||||||
Matrix.mult m x_i.basic_idx k new_coeff; (* update row of [x_i] *)
|
Matrix.mult m x_i.basic_idx k new_coeff; (* update row of [x_i] *)
|
||||||
done;
|
end;
|
||||||
assert Q.(Matrix.get m x_i.basic_idx x_j.idx = minus_one);
|
assert Q.(Matrix.get m x_i.basic_idx x_j.idx = minus_one);
|
||||||
|
|
||||||
(* make [x_i] non basic, and [x_j] basic *)
|
(* make [x_i] non basic, and [x_j] basic *)
|
||||||
|
|
@ -675,32 +678,29 @@ module Make(Q : RATIONAL)(Var: VAR)
|
||||||
Matrix.set_row_var m x_j.basic_idx x_j;
|
Matrix.set_row_var m x_j.basic_idx x_j;
|
||||||
|
|
||||||
(* adjust other rows so they don't depend on [x_j] *)
|
(* adjust other rows so they don't depend on [x_j] *)
|
||||||
for k=0 to Matrix.n_rows m-1 do
|
Matrix.iter_rows ~skip:x_j.basic_idx m begin fun k x_k ->
|
||||||
if k <> x_j.basic_idx then (
|
assert (Var_state.is_basic x_k);
|
||||||
let x_k = Matrix.get_row_var m k in
|
|
||||||
assert (Var_state.is_basic x_k);
|
|
||||||
|
|
||||||
let c_kj = Matrix.get m k x_j.idx in
|
let c_kj = Matrix.get m k x_j.idx in
|
||||||
if Q.(c_kj <> zero) then (
|
if Q.(c_kj <> zero) then (
|
||||||
(* [m[k,j] != 0] indicates that basic variable [x_k] depends on
|
(* [m[k,j] != 0] indicates that basic variable [x_k] depends on
|
||||||
[x_j], which is about to become basic. To avoid basic-basic
|
[x_j], which is about to become basic. To avoid basic-basic
|
||||||
dependency we replace [x_j] by its (new) definition *)
|
dependency we replace [x_j] by its (new) definition *)
|
||||||
|
|
||||||
for l=0 to Matrix.n_cols m-1 do
|
Matrix.iter_cols m begin fun l ->
|
||||||
if l<>x_j.idx then (
|
if l<>x_j.idx then (
|
||||||
let c_jl = Matrix.get m x_j.basic_idx l in
|
let c_jl = Matrix.get m x_j.basic_idx l in
|
||||||
(* so:
|
(* so:
|
||||||
[x_k := c_kj * x_j + …], we want to eliminate [x_j],
|
[x_k := c_kj * x_j + …], we want to eliminate [x_j],
|
||||||
and [x_j := … + c_jl * x_l + …].
|
and [x_j := … + c_jl * x_l + …].
|
||||||
therefore [x_j := … + c_jl * c_kl * x_l] *)
|
therefore [x_j := … + c_jl * c_kl * x_l] *)
|
||||||
Matrix.add m k l Q.(c_kj * c_jl);
|
Matrix.add m k l Q.(c_kj * c_jl);
|
||||||
)
|
)
|
||||||
done;
|
end;
|
||||||
|
|
||||||
Matrix.set m k x_j.idx Q.zero; (* [x_k] doesn't use [x_j] anymore *)
|
Matrix.set m k x_j.idx Q.zero; (* [x_k] doesn't use [x_j] anymore *)
|
||||||
)
|
|
||||||
)
|
)
|
||||||
done;
|
end;
|
||||||
|
|
||||||
update_implied_bounds_ self x_j;
|
update_implied_bounds_ self x_j;
|
||||||
end;
|
end;
|
||||||
|
|
@ -852,13 +852,12 @@ module Make(Q : RATIONAL)(Var: VAR)
|
||||||
(self:t) (vs:var_state) ~on_propagate : unit =
|
(self:t) (vs:var_state) ~on_propagate : unit =
|
||||||
(* section 3.2.5: update implied bounds on basic variables that
|
(* section 3.2.5: update implied bounds on basic variables that
|
||||||
depend on [vs] *)
|
depend on [vs] *)
|
||||||
for i = 0 to Matrix.n_rows self.matrix -1 do
|
Matrix.iter_rows self.matrix begin fun i x_i ->
|
||||||
if Q.(Matrix.get self.matrix i vs.idx <> zero) then (
|
if Q.(Matrix.get self.matrix i vs.idx <> zero) then (
|
||||||
let x_i = Matrix.get_row_var self.matrix i in
|
|
||||||
update_implied_bounds_ self x_i;
|
update_implied_bounds_ self x_i;
|
||||||
propagate_basic_implied_bounds self ~on_propagate x_i;
|
propagate_basic_implied_bounds self ~on_propagate x_i;
|
||||||
);
|
)
|
||||||
done
|
end
|
||||||
|
|
||||||
let add_constraint ~on_propagate (self:t) (c:Constraint.t) (lit:V.lit) : unit =
|
let add_constraint ~on_propagate (self:t) (c:Constraint.t) (lit:V.lit) : unit =
|
||||||
let open Constraint in
|
let open Constraint in
|
||||||
|
|
@ -971,19 +970,18 @@ module Make(Q : RATIONAL)(Var: VAR)
|
||||||
|
|
||||||
(* try to find basic variable that doesn't respect one of its bounds *)
|
(* try to find basic variable that doesn't respect one of its bounds *)
|
||||||
let find_basic_var_ (self:t) : (var_state * [`Lower | `Upper] * bound) option =
|
let find_basic_var_ (self:t) : (var_state * [`Lower | `Upper] * bound) option =
|
||||||
let n = Matrix.n_rows self.matrix in
|
let exception Found of var_state * [`Lower | `Upper] * bound in
|
||||||
let rec aux i =
|
|
||||||
if i >= n then None
|
try
|
||||||
else (
|
Matrix.iter_rows self.matrix begin fun _i x_i ->
|
||||||
let x_i = Matrix.get_row_var self.matrix i in
|
|
||||||
let v_i = x_i.value in
|
let v_i = x_i.value in
|
||||||
match x_i.l_bound, x_i.u_bound with
|
match x_i.l_bound, x_i.u_bound with
|
||||||
| Some l, _ when Erat.(l.b_val > v_i) -> Some (x_i, `Lower, l)
|
| Some l, _ when Erat.(l.b_val > v_i) -> raise_notrace (Found (x_i, `Lower, l))
|
||||||
| _, Some u when Erat.(u.b_val < v_i) -> Some (x_i, `Upper, u)
|
| _, Some u when Erat.(u.b_val < v_i) -> raise_notrace (Found (x_i, `Upper, u))
|
||||||
| _ -> (aux[@tailcall]) (i+1)
|
| _ -> ()
|
||||||
)
|
end;
|
||||||
in
|
None
|
||||||
aux 0
|
with Found (v,k,bnd) -> Some (v,k,bnd)
|
||||||
|
|
||||||
let find_n_basic_var_ (self:t) ~f : var_state option =
|
let find_n_basic_var_ (self:t) ~f : var_state option =
|
||||||
let rec aux j =
|
let rec aux j =
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue