wip: feat(lra): propagate literals based on implied bounds for basic vars

This commit is contained in:
Simon Cruanes 2021-03-21 01:14:33 -04:00
parent be1c1573b1
commit 34b1aa1799
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4

View file

@ -216,6 +216,8 @@ module Make(Q : RATIONAL)(Var: VAR)
let[@inline] (>=) a b = compare a b >= 0 let[@inline] (>=) a b = compare a b >= 0
let[@inline] (=) a b = compare a b = 0 let[@inline] (=) a b = compare a b = 0
let plus_inf = make Q.inf Q.zero
let minus_inf = make Q.minus_inf Q.zero
let[@inline] min x y = if x <= y then x else y let[@inline] min x y = if x <= y then x else y
let[@inline] max x y = if x >= y then x else y let[@inline] max x y = if x >= y then x else y
@ -242,6 +244,10 @@ module Make(Q : RATIONAL)(Var: VAR)
mutable basic_idx: int; (* index of the row in the matrix, if any. -1 otherwise *) mutable basic_idx: int; (* index of the row in the matrix, if any. -1 otherwise *)
mutable l_bound: bound option; mutable l_bound: bound option;
mutable u_bound: bound option; mutable u_bound: bound option;
mutable l_implied: Erat.t; (* implied lower bound for a basic var *)
mutable u_implied: Erat.t;
mutable all_bound_lits : (is_lower * bound) list; (* all known literals on this var *) mutable all_bound_lits : (is_lower * bound) list; (* all known literals on this var *)
} }
@ -390,6 +396,7 @@ module Make(Q : RATIONAL)(Var: VAR)
vars: var_state Vec.t; (* index -> var with this index *) vars: var_state Vec.t; (* index -> var with this index *)
mutable var_tbl: var_state V_map.t; (* var -> its state *) mutable var_tbl: var_state V_map.t; (* var -> its state *)
bound_stack: (var_state * [`Upper|`Lower] * bound option) Backtrack_stack.t; bound_stack: (var_state * [`Upper|`Lower] * bound option) Backtrack_stack.t;
undo_stack: (unit -> unit) Backtrack_stack.t;
stat_check: int Stat.counter; stat_check: int Stat.counter;
stat_unsat: int Stat.counter; stat_unsat: int Stat.counter;
stat_define: int Stat.counter; stat_define: int Stat.counter;
@ -398,7 +405,8 @@ module Make(Q : RATIONAL)(Var: VAR)
let push_level self : unit = let push_level self : unit =
Log.debug 10 "(simplex2.push-level)"; Log.debug 10 "(simplex2.push-level)";
Backtrack_stack.push_level self.bound_stack Backtrack_stack.push_level self.bound_stack;
Backtrack_stack.push_level self.undo_stack
let pop_levels self n : unit = let pop_levels self n : unit =
Log.debugf 10 (fun k->k "(simplex2.pop-levels %d)" n); Log.debugf 10 (fun k->k "(simplex2.pop-levels %d)" n);
@ -407,6 +415,8 @@ module Make(Q : RATIONAL)(Var: VAR)
match kind with match kind with
| `Upper -> var.u_bound <- bnd | `Upper -> var.u_bound <- bnd
| `Lower -> var.l_bound <- bnd); | `Lower -> var.l_bound <- bnd);
Backtrack_stack.pop_levels self.undo_stack n
~f:(fun f -> f());
() ()
let pp_stats out (self:t) : unit = let pp_stats out (self:t) : unit =
@ -459,6 +469,53 @@ module Make(Q : RATIONAL)(Var: VAR)
try V_map.find x self.var_tbl try V_map.find x self.var_tbl
with Not_found -> Error.errorf "variable is not in the simplex" Var.pp x with Not_found -> Error.errorf "variable is not in the simplex" Var.pp x
let[@inline] get_bnd_or v (b:bound option) =
match b with None -> v | Some b -> b.b_val
(* update implied bounds for basic variable [x_i] by looking at each
non-basic variable's bounds *)
let update_implied_bounds_ (self:t) (x_i:var_state) : unit =
assert (Var_state.is_basic x_i);
let m = self.matrix in
let i_low = ref Erat.zero in
let i_up = ref Erat.zero in
for j=0 to Matrix.n_cols m-1 do
if j <> x_i.idx then (
let a_ij: Q.t = Matrix.get m x_i.basic_idx j in
let x_j = Vec.get self.vars j in
let low_j = get_bnd_or Erat.minus_inf x_j.l_bound in
let up_j = get_bnd_or Erat.plus_inf x_j.u_bound in
if Q.(a_ij = zero) then()
else if Q.(a_ij > zero) then (
i_low := Erat.(!i_low + a_ij * low_j);
i_up := Erat.(!i_up + a_ij * up_j);
) else (
(* [a_ij < 0] and [x_j < up],
means [-a_ij > 0] and [-x_j > -up]
means [x_i = rest + a_ij x_j > rest + (-a_ij) * (-up)] *)
i_low := Erat.(!i_low + a_ij * up_j);
i_up := Erat.(!i_up + a_ij * low_j);
)
)
done;
let old_i_low = x_i.l_implied in
let old_i_up = x_i.u_implied in
Backtrack_stack.push self.undo_stack
(fun () ->
x_i.l_implied <- old_i_low;
x_i.u_implied <- old_i_up);
x_i.l_implied <- !i_low;
x_i.u_implied <- !i_up;
Log.debugf 50
(fun k->k"(@[lra.implied-bounds@ :var %a@ :lower %a@ :upper %a@])"
Var_state.pp x_i Erat.pp !i_low Erat.pp !i_up);
()
(* define [x] as a basic variable *) (* define [x] as a basic variable *)
let define (self:t) (x:V.t) (le:_ list) : unit = let define (self:t) (x:V.t) (le:_ list) : unit =
assert (not (has_var_ self x)); assert (not (has_var_ self x));
@ -481,6 +538,8 @@ module Make(Q : RATIONAL)(Var: VAR)
basic_idx=row_idx; basic_idx=row_idx;
l_bound=None; l_bound=None;
u_bound=None; u_bound=None;
l_implied=Erat.minus_inf;
u_implied=Erat.plus_inf;
all_bound_lits=[]; all_bound_lits=[];
}) })
in in
@ -520,6 +579,8 @@ module Make(Q : RATIONAL)(Var: VAR)
); );
) le; ) le;
update_implied_bounds_ self vs;
(* Log.debugf 50 (fun k->k "post-define: %a" pp self); *) (* Log.debugf 50 (fun k->k "post-define: %a" pp self); *)
_check_invariants_internal self; _check_invariants_internal self;
() ()
@ -535,6 +596,8 @@ module Make(Q : RATIONAL)(Var: VAR)
var=x; var=x;
l_bound=None; l_bound=None;
u_bound=None; u_bound=None;
l_implied=Erat.minus_inf;
u_implied=Erat.plus_inf;
value=Erat.zero; value=Erat.zero;
all_bound_lits=[]; all_bound_lits=[];
} in } in
@ -606,6 +669,8 @@ module Make(Q : RATIONAL)(Var: VAR)
(* make [x_i] non basic, and [x_j] basic *) (* make [x_i] non basic, and [x_j] basic *)
x_j.basic_idx <- x_i.basic_idx; x_j.basic_idx <- x_i.basic_idx;
x_i.basic_idx <- -1; x_i.basic_idx <- -1;
x_i.l_implied <- Erat.minus_inf;
x_i.u_implied <- Erat.plus_inf;
Matrix.set_row_var m x_j.basic_idx x_j; Matrix.set_row_var m x_j.basic_idx x_j;
(* adjust other rows so they don't depend on [x_j] *) (* adjust other rows so they don't depend on [x_j] *)
@ -635,6 +700,8 @@ module Make(Q : RATIONAL)(Var: VAR)
) )
) )
done; done;
update_implied_bounds_ self x_j;
end; end;
assert (Var_state.is_basic x_j); assert (Var_state.is_basic x_j);
@ -693,6 +760,91 @@ module Make(Q : RATIONAL)(Var: VAR)
ignore (find_or_create_n_basic_var_ self v : var_state); ignore (find_or_create_n_basic_var_ self v : var_state);
() ()
(* gather all relevant lower (resp. upper) bounds for the definition
of the basic variable [x_i], and collect each relevant variable
with [map] into a list. *)
let gather_bounds_of_row_ (self:t) (x_i:var_state) ~map ~is_lower : _ list * _ =
assert (Var_state.is_basic x_i);
let map_res = ref [] in
let bounds = ref V_map.empty in
Vec.iteri
(fun j x_j ->
if j <> x_i.idx then (
let c = Matrix.get self.matrix x_i.basic_idx j in
if Q.(c <> zero) then (
match is_lower, Q.(c > zero) with
| true, true
| false, false ->
begin match x_j.u_bound with
| Some u ->
map_res := (map c x_j u) :: !map_res;
let op = if Q.(u.b_val.eps_factor >= zero) then Op.Leq else Op.Lt in
bounds := V_map.add x_j.var (op, u) !bounds
| None -> assert false (* we could increase [x_j]?! *)
end
| true, false
| false, true ->
begin match x_j.l_bound with
| Some l ->
map_res := (map c x_j l) :: !map_res;
let op = if Q.(l.b_val.eps_factor <= zero) then Op.Geq else Op.Gt in
bounds := V_map.add x_j.var (op, l) !bounds
| None -> assert false (* we could decrease [x_j]?! *)
end
)
))
self.vars;
!map_res, !bounds
(* do propagations for basic var [x_i] based on its implied bounds *)
let propagate_basic_implied_bounds (self:t) ~on_propagate (x_i:var_state) : unit =
assert (Var_state.is_basic x_i);
let lits_of_row_ ~is_lower : V.lit list =
let l, _ =
gather_bounds_of_row_ self x_i ~is_lower
~map:(fun _ _ bnd -> bnd.b_lit) in
l
in
let process_bount_lit (is_lower, bnd): unit =
if is_lower then (
if Erat.(bnd.b_val < x_i.l_implied) then (
(* implied lower bound subsumes this lower bound *)
let reason = lits_of_row_ ~is_lower:true in
on_propagate bnd.b_lit ~reason
);
if Erat.(bnd.b_val > x_i.u_implied) then (
(* lower bound is higher than implied upper bound *)
match V.not_lit bnd.b_lit with
| Some not_lit ->
Log.debugf 50
(fun k->k"(@[lra.propagate.not@ :lower-bnd-of %a@ :bnd %a :lit %a@ :u-implied %a@])"
Var_state.pp x_i Erat.pp bnd.b_val V.pp_lit bnd.b_lit Erat.pp x_i.u_implied);
let reason = lits_of_row_ ~is_lower:false in
on_propagate not_lit ~reason
| None -> ()
)
) else (
if Erat.(bnd.b_val > x_i.u_implied) then (
(* implied upper bound subsumes this upper bound *)
let reason = lits_of_row_ ~is_lower:false in
on_propagate bnd.b_lit ~reason
);
if Erat.(bnd.b_val < x_i.l_implied) then (
(* upper bound is lower than implied lower bound *)
match V.not_lit bnd.b_lit with
| Some not_lit ->
let reason = lits_of_row_ ~is_lower:true in
on_propagate not_lit ~reason
| None -> ()
)
)
in
List.iter process_bount_lit x_i.all_bound_lits;
()
let add_constraint ~on_propagate (self:t) (c:Constraint.t) (lit:V.lit) : unit = let add_constraint ~on_propagate (self:t) (c:Constraint.t) (lit:V.lit) : unit =
let open Constraint in let open Constraint in
@ -743,7 +895,17 @@ module Make(Q : RATIONAL)(Var: VAR)
if Var_state.is_n_basic vs && if Var_state.is_n_basic vs &&
Erat.(vs.value < new_bnd.b_val) then ( Erat.(vs.value < new_bnd.b_val) then (
(* line 5: need to update non-basic variable *) (* line 5: need to update non-basic variable *)
update_n_basic self vs new_bnd.b_val update_n_basic self vs new_bnd.b_val;
(* section 3.2.5: update implied bounds on basic variables that
depend on [vs] *)
for i = 0 to Matrix.n_rows self.matrix -1 do
if Q.(Matrix.get self.matrix i vs.idx <> zero) then (
let x_i = Matrix.get_row_var self.matrix i in
update_implied_bounds_ self x_i;
propagate_basic_implied_bounds self ~on_propagate x_i;
);
done;
) )
end end
) else ( ) else (
@ -839,8 +1001,6 @@ module Make(Q : RATIONAL)(Var: VAR)
| None -> true | None -> true
| Some bnd -> Erat.(x.value > bnd.b_val) | Some bnd -> Erat.(x.value > bnd.b_val)
(* TODO: certificate checker *)
(* make a certificate from the row of basic variable [x_i] which is outside (* make a certificate from the row of basic variable [x_i] which is outside
one of its bound, and whose row is tight on all non-basic variables. one of its bound, and whose row is tight on all non-basic variables.
@param is_lower is true if the lower bound is not respected @param is_lower is true if the lower bound is not respected
@ -849,41 +1009,17 @@ module Make(Q : RATIONAL)(Var: VAR)
Log.debugf 50 (fun k->k "(@[simplex.cert-of-row[lower: %B]@ x_i=%a@])" Log.debugf 50 (fun k->k "(@[simplex.cert-of-row[lower: %B]@ x_i=%a@])"
is_lower Var_state.pp x_i); is_lower Var_state.pp x_i);
assert (Var_state.is_basic x_i); assert (Var_state.is_basic x_i);
let le = ref [] in let le, bounds =
let bounds = ref V_map.empty in gather_bounds_of_row_ self x_i ~is_lower
Vec.iteri ~map:(fun c v _ -> c, v.var)
(fun j x_j -> in
if j <> x_i.idx then (
let c = Matrix.get self.matrix x_i.basic_idx j in
if Q.(c <> zero) then (
le := (c, x_j.var) :: !le;
match is_lower, Q.(c > zero) with
| true, true
| false, false ->
begin match x_j.u_bound with
| Some u ->
let op = if Q.(u.b_val.eps_factor >= zero) then Op.Leq else Op.Lt in
bounds := V_map.add x_j.var (op, u) !bounds
| None -> assert false (* we could increase [x_j]?! *)
end
| true, false
| false, true ->
begin match x_j.l_bound with
| Some l ->
let op = if Q.(l.b_val.eps_factor <= zero) then Op.Geq else Op.Gt in
bounds := V_map.add x_j.var (op, l) !bounds
| None -> assert false (* we could decrease [x_j]?! *)
end
)
))
self.vars;
let op = let op =
if is_lower then if Q.(bnd.b_val.eps_factor <= zero) then Op.Geq else Op.Gt if is_lower then if Q.(bnd.b_val.eps_factor <= zero) then Op.Geq else Op.Gt
else if Q.(bnd.b_val.eps_factor >= zero) then Op.Leq else Op.Lt else if Q.(bnd.b_val.eps_factor >= zero) then Op.Leq else Op.Lt
in in
let cert = Unsat_cert.unsat_basic x_i (op, bnd) !le !bounds in let cert = Unsat_cert.unsat_basic x_i (op, bnd) le bounds in
cert cert
(* main satisfiability check. (* main satisfiability check.
@ -964,6 +1100,7 @@ module Make(Q : RATIONAL)(Var: VAR)
vars=Vec.create(); vars=Vec.create();
var_tbl=V_map.empty; var_tbl=V_map.empty;
bound_stack=Backtrack_stack.create(); bound_stack=Backtrack_stack.create();
undo_stack=Backtrack_stack.create();
stat_check=Stat.mk_int stat "simplex.check"; stat_check=Stat.mk_int stat "simplex.check";
stat_unsat=Stat.mk_int stat "simplex.unsat"; stat_unsat=Stat.mk_int stat "simplex.unsat";
stat_define=Stat.mk_int stat "simplex.define"; stat_define=Stat.mk_int stat "simplex.define";