feat(simplex2): build proper certificates

This commit is contained in:
Simon Cruanes 2021-02-15 16:18:40 -05:00
parent f0dd1b08e8
commit d6f0fa0ffc

View file

@ -17,6 +17,12 @@ module Op = struct
| Geq | Geq
| Gt | Gt
let neg_sign = function
| Leq -> Geq
| Lt -> Gt
| Geq -> Leq
| Gt -> Lt
let to_string = function let to_string = function
| Leq -> "<=" | Leq -> "<="
| Lt -> "<" | Lt -> "<"
@ -35,12 +41,13 @@ module type S = sig
type op = Op.t type op = Op.t
(** A constraint is the comparison of a variable to a constant. *) (** A constraint is the comparison of a variable to a constant. *)
type t = private { type t = {
op: op; op: op;
lhs: V.t; lhs: V.t;
rhs: num; rhs: num;
} }
val mk : V.t -> op -> num -> t
val leq : V.t -> num -> t val leq : V.t -> num -> t
val lt : V.t -> num -> t val lt : V.t -> num -> t
val geq : V.t -> num -> t val geq : V.t -> num -> t
@ -69,13 +76,11 @@ module type S = sig
This is useful to "name" a linear expression and get back a variable This is useful to "name" a linear expression and get back a variable
that can be used in a {!Constraint.t} *) that can be used in a {!Constraint.t} *)
type unsat_cert = { type unsat_cert
cert_bounds: (num * V.lit) list;
cert_defs: (V.t * (num * V.t) list) list; (* definitions used *)
}
module Unsat_cert : sig module Unsat_cert : sig
type t = unsat_cert type t = unsat_cert
val lits : t -> V.lit list (* unsat core *)
val pp : t Fmt.printer val pp : t Fmt.printer
end end
@ -135,6 +140,7 @@ module Make(Var: VAR)
Fmt.fprintf out "(@[%a %s@ %a@])" V.pp self.lhs Fmt.fprintf out "(@[%a %s@ %a@])" V.pp self.lhs
(Op.to_string self.op) pp_q_dbg self.rhs (Op.to_string self.op) pp_q_dbg self.rhs
let mk lhs op rhs : t = {lhs;op;rhs}
let leq lhs rhs = {lhs;rhs;op=Op.Leq} let leq lhs rhs = {lhs;rhs;op=Op.Leq}
let lt lhs rhs = {lhs;rhs;op=Op.Lt} let lt lhs rhs = {lhs;rhs;op=Op.Lt}
let geq lhs rhs = {lhs;rhs;op=Op.Geq} let geq lhs rhs = {lhs;rhs;op=Op.Geq}
@ -369,6 +375,10 @@ module Make(Var: VAR)
| `Lower -> var.l_bound <- bnd); | `Lower -> var.l_bound <- bnd);
() ()
let pp_stats out (self:t) : unit =
Fmt.fprintf out "(@[simplex@ :n-vars %d@ :n-rows %d@])"
(Vec.size self.vars) (Matrix.n_rows self.matrix)
let pp out (self:t) : unit = let pp out (self:t) : unit =
Fmt.fprintf out "(@[simplex@ @[<1>:vars@ [@[<hov>%a@]]@]@ @[<1>:matrix@ %a@]@])" Fmt.fprintf out "(@[simplex@ @[<1>:vars@ [@[<hov>%a@]]@]@ @[<1>:matrix@ %a@]@])"
(Vec.pp Var_state.pp) self.vars (Vec.pp Var_state.pp) self.vars
@ -474,7 +484,7 @@ module Make(Var: VAR)
); );
) le; ) le;
Log.debugf 50 (fun k->k "post-define: %a" pp self); (* Log.debugf 50 (fun k->k "post-define: %a" pp self); *)
_check_invariants_internal self; _check_invariants_internal self;
() ()
@ -502,8 +512,8 @@ module Make(Var: VAR)
let update_n_basic (self:t) (x:var_state) (v:erat) : unit = let update_n_basic (self:t) (x:var_state) (v:erat) : unit =
assert (Var_state.is_n_basic x); assert (Var_state.is_n_basic x);
Log.debugf 50 Log.debugf 50
(fun k->k "(@[<hv>simplex.update-n-basic@ %a@ :new-val %a@ :in %a@])" (fun k->k "(@[<hv>simplex.update-n-basic@ %a@ :new-val %a@])"
Var_state.pp x Erat.pp v pp self); Var_state.pp x Erat.pp v);
_check_invariants_internal self; _check_invariants_internal self;
let m = self.matrix in let m = self.matrix in
@ -597,24 +607,42 @@ module Make(Var: VAR)
() ()
type unsat_cert = { type unsat_cert =
cert_bounds: (num * V.lit) list; | E_bounds of {
cert_defs: (V.t * (num * V.t) list) list; (* definitions used *) x: var_state;
} lower: bound;
upper: bound;
}
| E_unsat_basic of {
x: var_state;
le: (num * V.t) list; (* definition of the basic var *)
bounds: (Op.t * bound) V_map.t; (* bound for each variable in [le] *)
}
module Unsat_cert = struct module Unsat_cert = struct
type t = unsat_cert type t = unsat_cert
let pp out (self:t) = let lits = function
let pp_bnd out (n,lit) = | E_bounds b -> [b.lower.b_lit; b.upper.b_lit]
Fmt.fprintf out "(@[%a@ coeff %a@])" V.pp_lit lit pp_q_dbg n | E_unsat_basic b ->
in V_map.fold (fun _ (_,bnd) l -> bnd.b_lit :: l) b.bounds []
Fmt.fprintf out "(@[cert@ :bounds %a@ :defs %a@])"
Fmt.(Dump.list pp_bnd) self.cert_bounds
Fmt.(Dump.list (Dump.pair V.pp (Dump.list (Dump.pair pp_q_dbg V.pp)))) self.cert_defs
let mk ~defs ~bounds : t = let pp out (self:t) =
{ cert_defs=defs; cert_bounds=bounds } match self with
| E_bounds {x;lower;upper} ->
Fmt.fprintf out "(@[unsat-bounds@ %a@ :lower %a@ :upper %a@])"
Var_state.pp x Erat.pp lower.b_val Erat.pp upper.b_val
| E_unsat_basic {x; le; bounds} ->
let pp_bnd out (v,(op,bnd)) =
Fmt.fprintf out "(@[%a %s %a@])" Var.pp v (Op.to_string op) Erat.pp bnd.b_val
in
Fmt.fprintf out "(@[cert@ %a :bounds %a@ :defs %a@])"
Var_state.pp x
Fmt.(Dump.list pp_bnd) (V_map.to_list bounds)
Fmt.(Dump.list (Dump.pair pp_q_dbg V.pp)) le
let bounds x ~lower ~upper : t = E_bounds {x; lower; upper}
let unsat_basic x le bounds : t = E_unsat_basic {x; le; bounds}
end end
exception E_unsat of Unsat_cert.t exception E_unsat of Unsat_cert.t
@ -643,8 +671,7 @@ module Make(Var: VAR)
begin match vs.l_bound, vs.u_bound with begin match vs.l_bound, vs.u_bound with
| _, Some upper when Erat.(new_bnd.b_val > upper.b_val) -> | _, Some upper when Erat.(new_bnd.b_val > upper.b_val) ->
(* [b_val <= x <= upper], but [b_val > upper] *) (* [b_val <= x <= upper], but [b_val > upper] *)
let cert = Unsat_cert.mk ~defs:[] let cert = Unsat_cert.bounds vs ~lower:new_bnd ~upper in
~bounds:[(Q.one, upper.b_lit); (Q.one, lit)] in
raise (E_unsat cert) raise (E_unsat cert)
| Some lower, _ when Erat.(lower.b_val >= new_bnd.b_val) -> | Some lower, _ when Erat.(lower.b_val >= new_bnd.b_val) ->
() (* subsumed by existing constraint, do nothing *) () (* subsumed by existing constraint, do nothing *)
@ -663,8 +690,7 @@ module Make(Var: VAR)
begin match vs.l_bound, vs.u_bound with begin match vs.l_bound, vs.u_bound with
| Some lower, _ when Erat.(new_bnd.b_val < lower.b_val) -> | Some lower, _ when Erat.(new_bnd.b_val < lower.b_val) ->
(* [lower <= x <= b_val], but [b_val < lower] *) (* [lower <= x <= b_val], but [b_val < lower] *)
let cert = Unsat_cert.mk ~defs:[] let cert = Unsat_cert.bounds vs ~lower ~upper:new_bnd in
~bounds:[(Q.one, lower.b_lit); (Q.one, lit)] in
raise (E_unsat cert) raise (E_unsat cert)
| _, Some upper when Erat.(upper.b_val <= new_bnd.b_val) -> | _, Some upper when Erat.(upper.b_val <= new_bnd.b_val) ->
() (* subsumed *) () (* subsumed *)
@ -719,22 +745,52 @@ module Make(Var: VAR)
| None -> true | None -> true
| Some bnd -> Erat.(x.value > bnd.b_val) | Some bnd -> Erat.(x.value > bnd.b_val)
(* TODO: certificate checker *)
(* make a certificate from the row of basic variable [x_i] which is outside (* make a certificate from the row of basic variable [x_i] which is outside
one of its bound, and whose row is tight on all non-basic variables *) one of its bound, and whose row is tight on all non-basic variables.
@param is_lower is true if the lower bound is not respected
(i.e. [x_i] is too small) *)
let cert_of_row_ (self:t) (x_i:var_state) ~is_lower : unsat_cert = let cert_of_row_ (self:t) (x_i:var_state) ~is_lower : unsat_cert =
Log.debugf 50 (fun k->k "(@[simplex.cert-of-row[lower: %B]@ x_i=%a@ %a@])" Log.debugf 50 (fun k->k "(@[simplex.cert-of-row[lower: %B]@ x_i=%a@])"
is_lower Var_state.pp x_i pp self); is_lower Var_state.pp x_i);
assert (Var_state.is_basic x_i); assert (Var_state.is_basic x_i);
(* TODO: store initial definition for each matrix row *) let le = ref [] in
let defs = [] in let bounds = ref V_map.empty in
let bounds = [] in (* TODO: use all bounds in the row *) Vec.iteri
Unsat_cert.mk ~defs ~bounds (fun j x_j ->
if j <> x_i.idx then (
let c = Matrix.get self.matrix x_i.basic_idx j in
if Q.(c <> zero) then (
le := (c, x_j.var) :: !le;
match is_lower, Q.(c > zero) with
| true, true
| false, false ->
begin match x_j.u_bound with
| Some u ->
let op = if Q.(u.b_val.eps_factor >= zero) then Op.Leq else Op.Lt in
bounds := V_map.add x_j.var (op, u) !bounds
| None -> assert false (* we could increase [x_j]?! *)
end
| true, false
| false, true ->
begin match x_j.l_bound with
| Some l ->
let op = if Q.(l.b_val.eps_factor <= zero) then Op.Geq else Op.Gt in
bounds := V_map.add x_j.var (op, l) !bounds
| None -> assert false (* we could decrease [x_j]?! *)
end
)
))
self.vars;
let cert = Unsat_cert.unsat_basic x_i !le !bounds in
cert
(* main satisfiability check. (* main satisfiability check.
page 15, figure 3.2 *) page 15, figure 3.2 *)
let check_exn (self:t) : unit = let check_exn (self:t) : unit =
let exception Stop in let exception Stop in
Log.debugf 20 (fun k->k "(@[simplex2.check@ %a@])" pp self); Log.debugf 20 (fun k->k "(@[simplex2.check@ %a@])" pp_stats self);
let m = self.matrix in let m = self.matrix in
try try