From 34b1aa17994de146d7d205344397b3952f1e18e0 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sun, 21 Mar 2021 01:14:33 -0400 Subject: [PATCH] wip: feat(lra): propagate literals based on implied bounds for basic vars --- src/lra/simplex2.ml | 203 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 170 insertions(+), 33 deletions(-) diff --git a/src/lra/simplex2.ml b/src/lra/simplex2.ml index 2b49b109..e61e8c6e 100644 --- a/src/lra/simplex2.ml +++ b/src/lra/simplex2.ml @@ -216,6 +216,8 @@ module Make(Q : RATIONAL)(Var: VAR) let[@inline] (>=) a b = compare a b >= 0 let[@inline] (=) a b = compare a b = 0 + let plus_inf = make Q.inf Q.zero + let minus_inf = make Q.minus_inf Q.zero let[@inline] min x y = if x <= y then x else y let[@inline] max x y = if x >= y then x else y @@ -242,6 +244,10 @@ module Make(Q : RATIONAL)(Var: VAR) mutable basic_idx: int; (* index of the row in the matrix, if any. -1 otherwise *) mutable l_bound: bound option; mutable u_bound: bound option; + + mutable l_implied: Erat.t; (* implied lower bound for a basic var *) + mutable u_implied: Erat.t; + mutable all_bound_lits : (is_lower * bound) list; (* all known literals on this var *) } @@ -390,6 +396,7 @@ module Make(Q : RATIONAL)(Var: VAR) vars: var_state Vec.t; (* index -> var with this index *) mutable var_tbl: var_state V_map.t; (* var -> its state *) bound_stack: (var_state * [`Upper|`Lower] * bound option) Backtrack_stack.t; + undo_stack: (unit -> unit) Backtrack_stack.t; stat_check: int Stat.counter; stat_unsat: int Stat.counter; stat_define: int Stat.counter; @@ -398,7 +405,8 @@ module Make(Q : RATIONAL)(Var: VAR) let push_level self : unit = Log.debug 10 "(simplex2.push-level)"; - Backtrack_stack.push_level self.bound_stack + Backtrack_stack.push_level self.bound_stack; + Backtrack_stack.push_level self.undo_stack let pop_levels self n : unit = Log.debugf 10 (fun k->k "(simplex2.pop-levels %d)" n); @@ -407,6 +415,8 @@ module Make(Q : RATIONAL)(Var: VAR) match kind with | `Upper -> var.u_bound <- bnd | `Lower -> var.l_bound <- bnd); + Backtrack_stack.pop_levels self.undo_stack n + ~f:(fun f -> f()); () let pp_stats out (self:t) : unit = @@ -459,6 +469,53 @@ module Make(Q : RATIONAL)(Var: VAR) try V_map.find x self.var_tbl with Not_found -> Error.errorf "variable is not in the simplex" Var.pp x + let[@inline] get_bnd_or v (b:bound option) = + match b with None -> v | Some b -> b.b_val + + (* update implied bounds for basic variable [x_i] by looking at each + non-basic variable's bounds *) + let update_implied_bounds_ (self:t) (x_i:var_state) : unit = + assert (Var_state.is_basic x_i); + let m = self.matrix in + + let i_low = ref Erat.zero in + let i_up = ref Erat.zero in + + for j=0 to Matrix.n_cols m-1 do + if j <> x_i.idx then ( + let a_ij: Q.t = Matrix.get m x_i.basic_idx j in + let x_j = Vec.get self.vars j in + + let low_j = get_bnd_or Erat.minus_inf x_j.l_bound in + let up_j = get_bnd_or Erat.plus_inf x_j.u_bound in + + if Q.(a_ij = zero) then() + else if Q.(a_ij > zero) then ( + i_low := Erat.(!i_low + a_ij * low_j); + i_up := Erat.(!i_up + a_ij * up_j); + ) else ( + (* [a_ij < 0] and [x_j < up], + means [-a_ij > 0] and [-x_j > -up] + means [x_i = rest + a_ij x_j > rest + (-a_ij) * (-up)] *) + i_low := Erat.(!i_low + a_ij * up_j); + i_up := Erat.(!i_up + a_ij * low_j); + ) + ) + done; + + let old_i_low = x_i.l_implied in + let old_i_up = x_i.u_implied in + Backtrack_stack.push self.undo_stack + (fun () -> + x_i.l_implied <- old_i_low; + x_i.u_implied <- old_i_up); + x_i.l_implied <- !i_low; + x_i.u_implied <- !i_up; + Log.debugf 50 + (fun k->k"(@[lra.implied-bounds@ :var %a@ :lower %a@ :upper %a@])" + Var_state.pp x_i Erat.pp !i_low Erat.pp !i_up); + () + (* define [x] as a basic variable *) let define (self:t) (x:V.t) (le:_ list) : unit = assert (not (has_var_ self x)); @@ -481,6 +538,8 @@ module Make(Q : RATIONAL)(Var: VAR) basic_idx=row_idx; l_bound=None; u_bound=None; + l_implied=Erat.minus_inf; + u_implied=Erat.plus_inf; all_bound_lits=[]; }) in @@ -520,6 +579,8 @@ module Make(Q : RATIONAL)(Var: VAR) ); ) le; + update_implied_bounds_ self vs; + (* Log.debugf 50 (fun k->k "post-define: %a" pp self); *) _check_invariants_internal self; () @@ -535,6 +596,8 @@ module Make(Q : RATIONAL)(Var: VAR) var=x; l_bound=None; u_bound=None; + l_implied=Erat.minus_inf; + u_implied=Erat.plus_inf; value=Erat.zero; all_bound_lits=[]; } in @@ -606,6 +669,8 @@ module Make(Q : RATIONAL)(Var: VAR) (* make [x_i] non basic, and [x_j] basic *) x_j.basic_idx <- x_i.basic_idx; x_i.basic_idx <- -1; + x_i.l_implied <- Erat.minus_inf; + x_i.u_implied <- Erat.plus_inf; Matrix.set_row_var m x_j.basic_idx x_j; (* adjust other rows so they don't depend on [x_j] *) @@ -635,6 +700,8 @@ module Make(Q : RATIONAL)(Var: VAR) ) ) done; + + update_implied_bounds_ self x_j; end; assert (Var_state.is_basic x_j); @@ -693,6 +760,91 @@ module Make(Q : RATIONAL)(Var: VAR) ignore (find_or_create_n_basic_var_ self v : var_state); () + (* gather all relevant lower (resp. upper) bounds for the definition + of the basic variable [x_i], and collect each relevant variable + with [map] into a list. *) + let gather_bounds_of_row_ (self:t) (x_i:var_state) ~map ~is_lower : _ list * _ = + assert (Var_state.is_basic x_i); + let map_res = ref [] in + let bounds = ref V_map.empty in + Vec.iteri + (fun j x_j -> + if j <> x_i.idx then ( + let c = Matrix.get self.matrix x_i.basic_idx j in + if Q.(c <> zero) then ( + match is_lower, Q.(c > zero) with + | true, true + | false, false -> + begin match x_j.u_bound with + | Some u -> + map_res := (map c x_j u) :: !map_res; + let op = if Q.(u.b_val.eps_factor >= zero) then Op.Leq else Op.Lt in + bounds := V_map.add x_j.var (op, u) !bounds + | None -> assert false (* we could increase [x_j]?! *) + end + | true, false + | false, true -> + begin match x_j.l_bound with + | Some l -> + map_res := (map c x_j l) :: !map_res; + let op = if Q.(l.b_val.eps_factor <= zero) then Op.Geq else Op.Gt in + bounds := V_map.add x_j.var (op, l) !bounds + | None -> assert false (* we could decrease [x_j]?! *) + end + ) + )) + self.vars; + !map_res, !bounds + + (* do propagations for basic var [x_i] based on its implied bounds *) + let propagate_basic_implied_bounds (self:t) ~on_propagate (x_i:var_state) : unit = + assert (Var_state.is_basic x_i); + + let lits_of_row_ ~is_lower : V.lit list = + let l, _ = + gather_bounds_of_row_ self x_i ~is_lower + ~map:(fun _ _ bnd -> bnd.b_lit) in + l + in + + let process_bount_lit (is_lower, bnd): unit = + if is_lower then ( + if Erat.(bnd.b_val < x_i.l_implied) then ( + (* implied lower bound subsumes this lower bound *) + let reason = lits_of_row_ ~is_lower:true in + on_propagate bnd.b_lit ~reason + ); + if Erat.(bnd.b_val > x_i.u_implied) then ( + (* lower bound is higher than implied upper bound *) + match V.not_lit bnd.b_lit with + | Some not_lit -> + Log.debugf 50 + (fun k->k"(@[lra.propagate.not@ :lower-bnd-of %a@ :bnd %a :lit %a@ :u-implied %a@])" + Var_state.pp x_i Erat.pp bnd.b_val V.pp_lit bnd.b_lit Erat.pp x_i.u_implied); + let reason = lits_of_row_ ~is_lower:false in + on_propagate not_lit ~reason + | None -> () + ) + ) else ( + if Erat.(bnd.b_val > x_i.u_implied) then ( + (* implied upper bound subsumes this upper bound *) + let reason = lits_of_row_ ~is_lower:false in + on_propagate bnd.b_lit ~reason + ); + if Erat.(bnd.b_val < x_i.l_implied) then ( + (* upper bound is lower than implied lower bound *) + match V.not_lit bnd.b_lit with + | Some not_lit -> + let reason = lits_of_row_ ~is_lower:true in + on_propagate not_lit ~reason + | None -> () + ) + ) + in + + List.iter process_bount_lit x_i.all_bound_lits; + () + let add_constraint ~on_propagate (self:t) (c:Constraint.t) (lit:V.lit) : unit = let open Constraint in @@ -743,7 +895,17 @@ module Make(Q : RATIONAL)(Var: VAR) if Var_state.is_n_basic vs && Erat.(vs.value < new_bnd.b_val) then ( (* line 5: need to update non-basic variable *) - update_n_basic self vs new_bnd.b_val + update_n_basic self vs new_bnd.b_val; + + (* section 3.2.5: update implied bounds on basic variables that + depend on [vs] *) + for i = 0 to Matrix.n_rows self.matrix -1 do + if Q.(Matrix.get self.matrix i vs.idx <> zero) then ( + let x_i = Matrix.get_row_var self.matrix i in + update_implied_bounds_ self x_i; + propagate_basic_implied_bounds self ~on_propagate x_i; + ); + done; ) end ) else ( @@ -839,8 +1001,6 @@ module Make(Q : RATIONAL)(Var: VAR) | None -> true | Some bnd -> Erat.(x.value > bnd.b_val) - (* TODO: certificate checker *) - (* make a certificate from the row of basic variable [x_i] which is outside one of its bound, and whose row is tight on all non-basic variables. @param is_lower is true if the lower bound is not respected @@ -849,41 +1009,17 @@ module Make(Q : RATIONAL)(Var: VAR) Log.debugf 50 (fun k->k "(@[simplex.cert-of-row[lower: %B]@ x_i=%a@])" is_lower Var_state.pp x_i); assert (Var_state.is_basic x_i); - let le = ref [] in - let bounds = ref V_map.empty in - Vec.iteri - (fun j x_j -> - if j <> x_i.idx then ( - let c = Matrix.get self.matrix x_i.basic_idx j in - if Q.(c <> zero) then ( - le := (c, x_j.var) :: !le; - match is_lower, Q.(c > zero) with - | true, true - | false, false -> - begin match x_j.u_bound with - | Some u -> - let op = if Q.(u.b_val.eps_factor >= zero) then Op.Leq else Op.Lt in - bounds := V_map.add x_j.var (op, u) !bounds - | None -> assert false (* we could increase [x_j]?! *) - end - | true, false - | false, true -> - begin match x_j.l_bound with - | Some l -> - let op = if Q.(l.b_val.eps_factor <= zero) then Op.Geq else Op.Gt in - bounds := V_map.add x_j.var (op, l) !bounds - | None -> assert false (* we could decrease [x_j]?! *) - end - ) - )) - self.vars; + let le, bounds = + gather_bounds_of_row_ self x_i ~is_lower + ~map:(fun c v _ -> c, v.var) + in let op = if is_lower then if Q.(bnd.b_val.eps_factor <= zero) then Op.Geq else Op.Gt else if Q.(bnd.b_val.eps_factor >= zero) then Op.Leq else Op.Lt in - let cert = Unsat_cert.unsat_basic x_i (op, bnd) !le !bounds in + let cert = Unsat_cert.unsat_basic x_i (op, bnd) le bounds in cert (* main satisfiability check. @@ -964,6 +1100,7 @@ module Make(Q : RATIONAL)(Var: VAR) vars=Vec.create(); var_tbl=V_map.empty; bound_stack=Backtrack_stack.create(); + undo_stack=Backtrack_stack.create(); stat_check=Stat.mk_int stat "simplex.check"; stat_unsat=Stat.mk_int stat "simplex.unsat"; stat_define=Stat.mk_int stat "simplex.define";