diff --git a/src/lra/sidekick_arith_lra.ml b/src/lra/sidekick_arith_lra.ml index fa334ace..ee90bef4 100644 --- a/src/lra/sidekick_arith_lra.ml +++ b/src/lra/sidekick_arith_lra.ml @@ -619,6 +619,8 @@ module Make(A : ARG) : S with module A = A = struct ~f:(fun (n1,n2) -> add_local_eq self si acts n1 n2); Log.debug 5 "(th-lra: call arith solver)"; + (* TODO: jiggle model to reduce the number of variables that + have the same value *) let model = check_simplex_ self si acts in Log.debugf 20 (fun k->k "(@[lra.model@ %a@])" SimpSolver.Subst.pp model); Log.debug 5 "lra: solver returns SAT"; diff --git a/src/lra/simplex2.ml b/src/lra/simplex2.ml index c59e4742..69f4c072 100644 --- a/src/lra/simplex2.ml +++ b/src/lra/simplex2.ml @@ -18,13 +18,13 @@ module Op = struct | Geq | Gt - let neg_sign = function + let[@inline] neg_sign = function | Leq -> Geq | Lt -> Gt | Geq -> Leq | Gt -> Lt - let not_ = function + let[@inline] not_ = function | Leq -> Gt | Lt -> Geq | Geq -> Lt @@ -107,7 +107,8 @@ module type S = sig @raise Unsat if it's immediately obvious that this is not satisfiable. *) val declare_bound : t -> Constraint.t -> V.lit -> unit - (** Declare that this constraint exists, so we can possibly propagate it. + (** Declare that this constraint exists and map it to a literal, + so we can possibly propagate it later. Unlike {!add_constraint} this does {b NOT} assert that the constraint is true *) @@ -271,6 +272,8 @@ module Make(Q : RATIONAL)(Var: VAR) val n_rows : t -> int val n_cols : t -> int + val iter_rows : ?skip:int -> t -> (int -> var_state -> unit) -> unit + val iter_cols : ?skip:int -> t -> (int -> unit) -> unit val add_column : t -> unit (** Add a non-basic variable (only adds a column) *) @@ -350,6 +353,15 @@ module Make(Q : RATIONAL)(Var: VAR) let r = Vec.get self.rows i in Vec.set r.cols j n + let[@inline] iter_rows ?(skip= ~-1) (self:t) f : unit = + Vec.iteri (fun i row -> + if i<>skip then f i row.vs + ) self.rows + let[@inline] iter_cols ?(skip= ~-1) (self:t) f : unit = + for i=0 to n_cols self-1 do + if i<>skip then f i + done + (* add [n] to [m_ij] *) let add self i j n : unit = let r = Vec.get self.rows i in @@ -434,17 +446,16 @@ module Make(Q : RATIONAL)(Var: VAR) Vec.iteri (fun i v -> assert(v.idx = i)) self.vars; let n = Vec.size self.vars in assert (Matrix.n_rows self.matrix = 0 || Matrix.n_cols self.matrix = n); - for i = 0 to Matrix.n_rows self.matrix-1 do - let v = Matrix.get_row_var self.matrix i in - assert (Var_state.is_basic v); - assert (v.basic_idx = i); - assert Q.(Matrix.get self.matrix v.basic_idx v.idx = minus_one); + Matrix.iter_rows self.matrix begin fun i x_i -> + assert (Var_state.is_basic x_i); + assert (x_i.basic_idx = i); + assert Q.(Matrix.get self.matrix x_i.basic_idx x_i.idx = minus_one); (* basic vars are only defined in terms of non-basic vars *) Vec.iteri - (fun j v_j -> - if Var_state.(v != v_j) && Q.(Matrix.get self.matrix i j <> zero) then ( - assert (Var_state.is_n_basic v_j) + (fun j x_j -> + if Var_state.(x_i != x_j) && Q.(Matrix.get self.matrix i j <> zero) then ( + assert (Var_state.is_n_basic x_j) )) self.vars; @@ -458,7 +469,7 @@ module Make(Q : RATIONAL)(Var: VAR) Log.debugf 50 (fun k->k "row %d: sum %a" i Erat.pp sum); assert Erat.(sum = zero); - done; + end; () (* for internal checking *) @@ -482,27 +493,25 @@ module Make(Q : RATIONAL)(Var: VAR) let i_low = ref Erat.zero in let i_up = ref Erat.zero in - for j=0 to Matrix.n_cols m-1 do - if j <> x_i.idx then ( - let a_ij: Q.t = Matrix.get m x_i.basic_idx j in - let x_j = Vec.get self.vars j in + Matrix.iter_cols m ~skip:x_i.idx begin fun j -> + let a_ij: Q.t = Matrix.get m x_i.basic_idx j in + let x_j = Vec.get self.vars j in - let low_j = get_bnd_or Erat.minus_inf x_j.l_bound in - let up_j = get_bnd_or Erat.plus_inf x_j.u_bound in + let low_j = get_bnd_or Erat.minus_inf x_j.l_bound in + let up_j = get_bnd_or Erat.plus_inf x_j.u_bound in - if Q.(a_ij = zero) then() - else if Q.(a_ij > zero) then ( - i_low := Erat.(!i_low + a_ij * low_j); - i_up := Erat.(!i_up + a_ij * up_j); - ) else ( - (* [a_ij < 0] and [x_j < up], - means [-a_ij > 0] and [-x_j > -up] - means [x_i = rest + a_ij x_j > rest + (-a_ij) * (-up)] *) - i_low := Erat.(!i_low + a_ij * up_j); - i_up := Erat.(!i_up + a_ij * low_j); - ) + if Q.(a_ij = zero) then() + else if Q.(a_ij > zero) then ( + i_low := Erat.(!i_low + a_ij * low_j); + i_up := Erat.(!i_up + a_ij * up_j); + ) else ( + (* [a_ij < 0] and [x_j < up], + means [-a_ij > 0] and [-x_j > -up] + means [x_i = rest + a_ij x_j > rest + (-a_ij) * (-up)] *) + i_low := Erat.(!i_low + a_ij * up_j); + i_up := Erat.(!i_up + a_ij * low_j); ) - done; + end; let old_i_low = x_i.l_implied in let old_i_up = x_i.u_implied in @@ -562,18 +571,16 @@ module Make(Q : RATIONAL)(Var: VAR) (* [v2] is also basic, so instead of depending on it, we depend on its definition in terms of non-basic variables. *) - for k=0 to Matrix.n_cols self.matrix - 1 do - if k <> v2.idx then ( - let v2_jk = Matrix.get self.matrix v2.basic_idx k in - if Q.(v2_jk <> zero) then ( - let v_k = Vec.get self.vars k in - assert (Var_state.is_n_basic v_k); + Matrix.iter_cols ~skip:v2.idx self.matrix begin fun k -> + let v2_jk = Matrix.get self.matrix v2.basic_idx k in + if Q.(v2_jk <> zero) then ( + let x_k = Vec.get self.vars k in + assert (Var_state.is_n_basic x_k); - (* [v2 := v2_jk * v_k + …], so [v := … + coeff * v2_jk * v_k] *) - Matrix.add self.matrix vs.basic_idx k Q.(coeff * v2_jk); - ); + (* [v2 := v2_jk * x_k + …], so [v := … + coeff * v2_jk * x_k] *) + Matrix.add self.matrix vs.basic_idx k Q.(coeff * v2_jk); ); - done; + end; ) else ( (* directly add coefficient with non-basic var [v2] *) Matrix.add self.matrix vs.basic_idx v2.idx coeff; @@ -622,15 +629,14 @@ module Make(Q : RATIONAL)(Var: VAR) let diff = Erat.(v - x.value) in - for j=0 to Matrix.n_rows m - 1 do + Matrix.iter_rows m begin fun j x_j -> let a_ji = Matrix.get m j i in - let x_j = Matrix.get_row_var m j in assert (Var_state.is_basic x_j); (* value of [x_j] by [a_ji * diff] *) let new_val = Erat.(x_j.value + a_ji * diff) in (* Log.debugf 50 (fun k->k "new-val %a@ := %a" Var_state.pp x_j Erat.pp new_val); *) x_j.value <- new_val; - done; + end; x.value <- v; _check_invariants_internal self; () @@ -650,21 +656,18 @@ module Make(Q : RATIONAL)(Var: VAR) x_i.value <- v; x_j.value <- Erat.(x_j.value + theta); - for k=0 to Matrix.n_rows m-1 do - if k <> x_i.basic_idx then ( - let x_k = Matrix.get_row_var m k in - let a_kj = Matrix.get m x_k.basic_idx x_j.idx in - x_k.value <- Erat.(x_k.value + a_kj * theta); - ) - done; + Matrix.iter_rows m ~skip:x_i.basic_idx begin fun _k x_k -> + let a_kj = Matrix.get m x_k.basic_idx x_j.idx in + x_k.value <- Erat.(x_k.value + a_kj * theta); + end; begin (* now pivot the variables so that [x_j]'s coeff is -1 and so that other basic variables only depend on non-basic variables. *) let new_coeff = Q.(minus_one / a_ij) in - for k=0 to Matrix.n_cols m-1 do + Matrix.iter_cols m begin fun k -> Matrix.mult m x_i.basic_idx k new_coeff; (* update row of [x_i] *) - done; + end; assert Q.(Matrix.get m x_i.basic_idx x_j.idx = minus_one); (* make [x_i] non basic, and [x_j] basic *) @@ -675,32 +678,29 @@ module Make(Q : RATIONAL)(Var: VAR) Matrix.set_row_var m x_j.basic_idx x_j; (* adjust other rows so they don't depend on [x_j] *) - for k=0 to Matrix.n_rows m-1 do - if k <> x_j.basic_idx then ( - let x_k = Matrix.get_row_var m k in - assert (Var_state.is_basic x_k); + Matrix.iter_rows ~skip:x_j.basic_idx m begin fun k x_k -> + assert (Var_state.is_basic x_k); - let c_kj = Matrix.get m k x_j.idx in - if Q.(c_kj <> zero) then ( - (* [m[k,j] != 0] indicates that basic variable [x_k] depends on + let c_kj = Matrix.get m k x_j.idx in + if Q.(c_kj <> zero) then ( + (* [m[k,j] != 0] indicates that basic variable [x_k] depends on [x_j], which is about to become basic. To avoid basic-basic dependency we replace [x_j] by its (new) definition *) - for l=0 to Matrix.n_cols m-1 do - if l<>x_j.idx then ( - let c_jl = Matrix.get m x_j.basic_idx l in - (* so: + Matrix.iter_cols m begin fun l -> + if l<>x_j.idx then ( + let c_jl = Matrix.get m x_j.basic_idx l in + (* so: [x_k := c_kj * x_j + …], we want to eliminate [x_j], and [x_j := … + c_jl * x_l + …]. therefore [x_j := … + c_jl * c_kl * x_l] *) - Matrix.add m k l Q.(c_kj * c_jl); - ) - done; + Matrix.add m k l Q.(c_kj * c_jl); + ) + end; - Matrix.set m k x_j.idx Q.zero; (* [x_k] doesn't use [x_j] anymore *) - ) + Matrix.set m k x_j.idx Q.zero; (* [x_k] doesn't use [x_j] anymore *) ) - done; + end; update_implied_bounds_ self x_j; end; @@ -852,13 +852,12 @@ module Make(Q : RATIONAL)(Var: VAR) (self:t) (vs:var_state) ~on_propagate : unit = (* section 3.2.5: update implied bounds on basic variables that depend on [vs] *) - for i = 0 to Matrix.n_rows self.matrix -1 do + Matrix.iter_rows self.matrix begin fun i x_i -> if Q.(Matrix.get self.matrix i vs.idx <> zero) then ( - let x_i = Matrix.get_row_var self.matrix i in update_implied_bounds_ self x_i; propagate_basic_implied_bounds self ~on_propagate x_i; - ); - done + ) + end let add_constraint ~on_propagate (self:t) (c:Constraint.t) (lit:V.lit) : unit = let open Constraint in @@ -971,19 +970,18 @@ module Make(Q : RATIONAL)(Var: VAR) (* try to find basic variable that doesn't respect one of its bounds *) let find_basic_var_ (self:t) : (var_state * [`Lower | `Upper] * bound) option = - let n = Matrix.n_rows self.matrix in - let rec aux i = - if i >= n then None - else ( - let x_i = Matrix.get_row_var self.matrix i in + let exception Found of var_state * [`Lower | `Upper] * bound in + + try + Matrix.iter_rows self.matrix begin fun _i x_i -> let v_i = x_i.value in match x_i.l_bound, x_i.u_bound with - | Some l, _ when Erat.(l.b_val > v_i) -> Some (x_i, `Lower, l) - | _, Some u when Erat.(u.b_val < v_i) -> Some (x_i, `Upper, u) - | _ -> (aux[@tailcall]) (i+1) - ) - in - aux 0 + | Some l, _ when Erat.(l.b_val > v_i) -> raise_notrace (Found (x_i, `Lower, l)) + | _, Some u when Erat.(u.b_val < v_i) -> raise_notrace (Found (x_i, `Upper, u)) + | _ -> () + end; + None + with Found (v,k,bnd) -> Some (v,k,bnd) let find_n_basic_var_ (self:t) ~f : var_state option = let rec aux j =