(** {1 Fast Simplex for CDCL(T)}

    We follow the paper "Integrating Simplex with DPLL(T )" from
    de Moura and Dutertre.
*)

open CCMonomorphic

module type RATIONAL = Sidekick_arith.RATIONAL
module type VAR = Linear_expr_intf.VAR

(** {2 Basic operator} *)

module Op = struct
  (** Comparison operators usable in a constraint. *)
  type t =
    | Leq
    | Lt
    | Geq
    | Gt

  (** Operator obtained by flipping the direction of the comparison
      while keeping its strictness (e.g. [Leq] becomes [Geq]). *)
  let neg_sign = function
    | Leq -> Geq
    | Lt -> Gt
    | Geq -> Leq
    | Gt -> Lt

  (** Logical negation of the comparison (e.g. [Leq] becomes [Gt]). *)
  let not_ = function
    | Leq -> Gt
    | Lt -> Geq
    | Geq -> Lt
    | Gt -> Leq

  let to_string = function
    | Leq -> "<="
    | Lt -> "<"
    | Geq -> ">="
    | Gt -> ">"

  let pp out op = Fmt.string out (to_string op)
end

module type S = sig
  module V : VAR
  module V_map : CCMap.S with type key = V.t
  module Q : RATIONAL

  type num = Q.t (** Numbers *)

  module Constraint : sig
    type op = Op.t

    (** A constraint is the comparison of a variable to a constant. *)
    type t = {
      op: op;
      lhs: V.t;
      rhs: num;
    }

    val mk : V.t -> op -> num -> t
    val leq : V.t -> num -> t
    val lt : V.t -> num -> t
    val geq : V.t -> num -> t
    val gt : V.t -> num -> t

    val pp : t Fmt.printer
  end

  module Subst : sig
    type t = num V_map.t

    val eval : t -> V.t -> Q.t
    (** Value of the variable in the substitution; variables that are
        not bound evaluate to zero. *)

    val pp : t Fmt.printer
    val to_string : t -> string
  end

  type t

  val create : ?stat:Stat.t -> unit -> t
  (** Create a new simplex. *)

  val push_level : t -> unit
  (** Open a new backtracking level. *)

  val pop_levels : t -> int -> unit
  (** [pop_levels s n] backtracks [n] levels, restoring the bounds that
      were changed in these levels. *)

  val define : t -> V.t -> (num * V.t) list -> unit
  (** Define a basic variable in terms of other variables.
      This is useful to "name" a linear expression and get back a variable
      that can be used in a {!Constraint.t}.
      The variable being defined must be new; the variables of the linear
      expression must already be in the simplex. *)

  type unsat_cert

  module Unsat_cert : sig
    type t = unsat_cert
    val lits : t -> V.lit list (* unsat core *)
    val pp : t Fmt.printer
  end

  exception E_unsat of Unsat_cert.t

  type ev_on_propagate = V.lit -> reason:V.lit list -> unit

  val add_var : t -> V.t -> unit
  (** Make sure the variable exists in the simplex. *)

  val add_constraint :
    on_propagate:ev_on_propagate ->
    t -> Constraint.t -> V.lit -> unit
  (** Add a constraint to the simplex.
      @raise E_unsat if it's immediately obvious that this is not
      satisfiable. *)

  val declare_bound : t -> Constraint.t -> V.lit -> unit
  (** Declare that this constraint exists, so we can possibly propagate it.
      Unlike {!add_constraint} this does {b NOT} assert that the constraint
      is true *)

  val check_exn :
    on_propagate:(V.lit -> reason:V.lit list -> unit) ->
    t -> unit
  (** Check the whole simplex for satisfiability.
      @param on_propagate is called with arguments [lit, reason] whenever
      [reason => lit] is found to be true by the simplex.
      @raise E_unsat if the constraints are not satisfiable. *)

  type result =
    | Sat of Subst.t
    | Unsat of Unsat_cert.t

  val check :
    on_propagate:(V.lit -> reason:V.lit list -> unit) ->
    t -> result
  (** Call {!check_exn} and return a model or a proof of unsat. *)

  (**/**)
  val _check_invariants : t -> unit
  (* check internal invariants *)

  val _check_cert : unsat_cert -> unit
  (**/**)
end

(* TODO(optim): page 14, paragraph 2: we could detect which variables occur in no
   atom after preprocessing; then these variables can be "inlined" (removed
   by Gaussian elimination) as a preprocessing step, and this removes one column
   and maybe one row if it was basic. *)

module Make(Q : RATIONAL)(Var: VAR)
  : S with module V = Var and module Q = Q
= struct
  module V = Var
  module V_map = CCMap.Make(Var)
  module Q = Q

  type num = Q.t
  let pp_q_dbg = Q.pp_approx 1

  module Constraint = struct
    type op = Op.t

    (** A constraint is the comparison of a variable to a constant. *)
    type t = {
      op: op;
      lhs: V.t;
      rhs: num;
    }

    let pp out (self:t) =
      Fmt.fprintf out "(@[%a %s@ %a@])" V.pp self.lhs
        (Op.to_string self.op) pp_q_dbg self.rhs

    let mk lhs op rhs : t = {lhs;op;rhs}
    let leq lhs rhs = {lhs;rhs;op=Op.Leq}
    let lt lhs rhs = {lhs;rhs;op=Op.Lt}
    let geq lhs rhs = {lhs;rhs;op=Op.Geq}
    let gt lhs rhs = {lhs;rhs;op=Op.Gt}
  end

  module Subst = struct
    type t = num V_map.t

    (* unbound variables evaluate to zero *)
    let eval self t = try V_map.find t self with Not_found -> Q.zero

    let pp out (self:t) : unit =
      let pp_pair out (v,n) =
        Fmt.fprintf out "(@[%a := %a@])" V.pp v pp_q_dbg n in
      Fmt.fprintf out "{@[%a@]}" (Fmt.iter pp_pair) (V_map.to_iter self)

    let to_string = Fmt.to_string pp
  end

  (** An extended rational, used to introduce ε so we can use strict
      comparisons in an algorithm designed to handle >= only.

      In a way, an extended rational is a tuple `(base, factor)`
      ordered lexicographically. *)
  type erat = {
    base: num; (** reference number *)
    eps_factor: num; (** coefficient for epsilon, the infinitesimal *)
  }

  (** {2 Epsilon-rationals, used for strict bounds} *)
  module Erat = struct
    type t = erat

    let zero : t = {base=Q.zero; eps_factor=Q.zero}

    (* the epsilon coefficient is dropped when [base] is not a real number
       (per [Q.is_real]) *)
    let[@inline] make base eps_factor : t =
      if Q.is_real base then {base; eps_factor} else {base; eps_factor=Q.zero}
    let[@inline] make_q x = make x Q.zero
    let[@inline] base t = t.base
    let[@inline] eps_factor t = t.eps_factor

    (* NOTE: the order of the definitions below matters, as each operator
       shadows the previous one from this point on. *)

    let[@inline] ( * ) k e = make Q.(k * e.base) Q.(k * e.eps_factor)
    let[@inline] ( / ) e k = make Q.(e.base / k) Q.(e.eps_factor / k)
    let[@inline] (+) e1 e2 = make Q.(e1.base + e2.base) Q.(e1.eps_factor + e2.eps_factor)
    let[@inline] (-) e1 e2 = make Q.(e1.base - e2.base) Q.(e1.eps_factor - e2.eps_factor)

    (* lexicographic comparison on [(base, eps_factor)] *)
    let[@inline] compare e1 e2 = match Q.compare e1.base e2.base with
      | 0 -> Q.compare e1.eps_factor e2.eps_factor
      | x -> x

    let[@inline] (<) a b = compare a b < 0
    let[@inline] (<=) a b = compare a b <= 0
    let[@inline] (>) a b = compare a b > 0
    let[@inline] (>=) a b = compare a b >= 0
    let[@inline] (=) a b = compare a b = 0

    let plus_inf = make Q.infinity Q.zero
    let minus_inf = make Q.minus_infinity Q.zero

    let[@inline] min x y = if x <= y then x else y
    let[@inline] max x y = if x >= y then x else y

    (* rational obtained by giving the concrete value [epsilon] to ε *)
    let[@inline] evaluate (epsilon:Q.t) (e:t) : Q.t =
      Q.(e.base + epsilon * e.eps_factor)

    let pp out e : unit =
      if Q.equal Q.zero (eps_factor e)
      then pp_q_dbg out (base e)
      else
        Fmt.fprintf out "(@[%a + @<1>ε * %a@])"
          pp_q_dbg (base e) pp_q_dbg (eps_factor e)
  end

  (** A bound on a variable, together with the literal that justifies it. *)
  type bound = {
    b_val: erat;
    b_lit: Var.lit;
  }

  type is_lower = bool

  type var_state = {
    var: V.t;
    mutable value: erat;
    idx: int; (* index in {!t.vars} *)
    mutable basic_idx: int; (* index of the row in the matrix, if any. -1 otherwise *)
    mutable l_bound: bound option;
    mutable u_bound: bound option;
    mutable l_implied: Erat.t; (* implied lower bound for a basic var *)
    mutable u_implied: Erat.t;

    mutable all_bound_lits : (is_lower * bound) list; (* all known literals on this var *)
  }

  (** {2 Matrix}
      The matrix [A] from the paper, with m rows and n columns.
      - m is the number of basic variables (defined in terms of non-basic variables)
      - n is the total number of variables, basic and non-basic.

      The invariant that the simplex maintains at all times is that [Ax = 0],
      where [x] is the vector of values of all variables (basic and non-basic).
  *)

  module Matrix : sig
    type t

    val create : unit -> t

    val pp : t Fmt.printer
    val to_string : t -> string

    val n_rows : t -> int
    val n_cols : t -> int

    val add_column : t -> unit
    (** Add a non-basic variable (only adds a column) *)

    val add_row_and_column : t -> f:(row_idx:int -> col_idx:int -> var_state) -> var_state
    (** Add a basic variable. *)

    val get_row_var : t -> int -> var_state
    (** The basic variable for row [i] *)

    val set_row_var : t -> int -> var_state -> unit
    (** Set basic variable for row [i] *)

    val get : t -> int -> int -> num

    val set : t -> int -> int -> num -> unit

    val add : t -> int -> int -> num -> unit

    val mult : t -> int -> int -> num -> unit
  end = struct
    type row = {
      mutable vs: var_state;
      cols: num Vec.t;
    }

    type t = {
      mutable n_cols: int;
      rows: row Vec.t
    }

    let create() : t =
      {n_cols=0; rows=Vec.create()}

    let[@inline] n_rows self = Vec.size self.rows
    let[@inline] n_cols self = self.n_cols

    let pp out self =
      Fmt.fprintf out "@[{matrix[%dx%d]@," (n_rows self) (n_cols self);

      (* header *)
      let ppi out i =
        Fmt.string out @@ CCString.pad ~side:`Left 6 @@ Printf.sprintf "v%d" i in
      Fmt.fprintf out "{@[%9s: %a@]}" "vars"
        (Fmt.iter ~sep:(Fmt.return "@ ") ppi) CCInt.(0 -- (n_cols self-1));

      Vec.iteri (fun i row ->
          let hd =
            CCString.pad ~side:`Left 6 @@
            Printf.sprintf "r%d (v%d)" i row.vs.idx in
          Fmt.fprintf out "@,{@[%9s: %a@]}" hd
            (Fmt.iter ~sep:(Fmt.return "@ ") (Q.pp_approx 6)) (Vec.to_seq row.cols))
        self.rows;
      Fmt.fprintf out "@;<0 -1>}@]"

    let to_string = Fmt.to_string pp

    (* one more column: every existing row gets an additional zero *)
    let add_column self =
      self.n_cols <- 1 + self.n_cols;
      Vec.iter (fun r -> Vec.push r.cols Q.zero) self.rows

    let add_row_and_column self ~f : var_state =
      let n = n_rows self in
      let j = n_cols self in
      add_column self;
      (* NOTE(review): this relies on [Vec.make cap x] returning an {b empty}
         vector with reserved capacity [cap], so that the pushes below produce
         a row of exactly [j+1] zeros. Confirm against the [Vec] module. *)
      let cols = Vec.make (j+1) Q.zero in
      for _k=0 to j do Vec.push cols Q.zero done;
      let vs = f ~row_idx:n ~col_idx:j in
      let row = {vs; cols} in
      Vec.push self.rows row;
      vs

    let[@inline] get_row_var self i = (Vec.get self.rows i).vs
    let[@inline] set_row_var self i v = (Vec.get self.rows i).vs <- v

    let[@inline] get self i j : num = Vec.get (Vec.get self.rows i).cols j

    let[@inline] set self i j n : unit =
      let r = Vec.get self.rows i in
      Vec.set r.cols j n

    (* add [n] to [m_ij] *)
    let add self i j n : unit =
      let r = Vec.get self.rows i in
      Vec.set r.cols j Q.(Vec.get r.cols j + n)

    (* multiply [m_ij] by [c] *)
    let mult self i j c : unit =
      let r = Vec.get self.rows i in
      let n_j = Vec.get r.cols j in
      (* zero entries are left untouched *)
      if Q.(n_j <> zero) then (
        Vec.set r.cols j Q.(n_j * c)
      )
  end

  module Var_state = struct
    type t = var_state

    (* physical (dis)equality on variable states *)
    let (==) : t -> t -> bool = Containers.Stdlib.(==)
    let (!=) : t -> t -> bool = Containers.Stdlib.(!=)

    let[@inline] is_basic (self:t) : bool = self.basic_idx >= 0
    let[@inline] is_n_basic (self:t) : bool = self.basic_idx < 0

    (* is the current value within the (explicit) bounds? *)
    let in_bounds_ self =
      (match self.l_bound with None -> true | Some b -> Erat.(self.value >= b.b_val)) &&
      (match self.u_bound with None -> true | Some b -> Erat.(self.value <= b.b_val))

    let pp_name out self =
      Fmt.fprintf out "v%d" self.idx

    let pp out self =
      let bnd_status = if in_bounds_ self then "" else "(oob)" in
      let pp_bnd what out = function
        | None -> ()
        | Some b -> Fmt.fprintf out "@ @[%s %a@]" what Erat.pp b.b_val
      and pp_basic_idx out () =
        if self.basic_idx < 0 then () else Fmt.int out self.basic_idx
      in
      Fmt.fprintf out "(@[v%d[%s%a]%s@ :value %a%a%a@])"
        self.idx (if is_basic self then "B" else "N") pp_basic_idx ()
        bnd_status Erat.pp self.value
        (pp_bnd ":lower") self.l_bound (pp_bnd ":upper") self.u_bound
  end

  type t = {
    matrix: Matrix.t;
    vars: var_state Vec.t; (* index -> var with this index *)
    mutable var_tbl: var_state V_map.t; (* var -> its state *)
    bound_stack: (var_state * [`Upper|`Lower] * bound option) Backtrack_stack.t;
    undo_stack: (unit -> unit) Backtrack_stack.t;
    stat_check: int Stat.counter;
    stat_unsat: int Stat.counter;
    stat_define: int Stat.counter;
    stat_propagate: int Stat.counter;
  }

  let push_level self : unit =
    Log.debug 10 "(simplex2.push-level)";
    Backtrack_stack.push_level self.bound_stack;
    Backtrack_stack.push_level self.undo_stack

  let pop_levels self n : unit =
    Log.debugf 10 (fun k->k "(simplex2.pop-levels %d)" n);
    (* restore the bounds that were overwritten in the popped levels *)
    Backtrack_stack.pop_levels self.bound_stack n
      ~f:(fun (var, kind, bnd) ->
          match kind with
          | `Upper -> var.u_bound <- bnd
          | `Lower -> var.l_bound <- bnd);
    (* run the other registered undo actions *)
    Backtrack_stack.pop_levels self.undo_stack n
      ~f:(fun f -> f());
    ()

  let pp_stats out (self:t) : unit =
    Fmt.fprintf out "(@[simplex@ :n-vars %d@ :n-rows %d@])"
      (Vec.size self.vars) (Matrix.n_rows self.matrix)

  let pp out (self:t) : unit =
    Fmt.fprintf out "(@[simplex@ @[<1>:vars@ [@[%a@]]@]@ @[<1>:matrix@ %a@]@])"
      (Vec.pp Var_state.pp) self.vars
      Matrix.pp self.matrix

  (* for debug purposes *)
  let _check_invariants self : unit =
    Vec.iteri (fun i v -> assert(v.idx = i)) self.vars;
    let n = Vec.size self.vars in
    assert (Matrix.n_rows self.matrix = 0 || Matrix.n_cols self.matrix = n);
    for i = 0 to Matrix.n_rows self.matrix-1 do
      let v = Matrix.get_row_var self.matrix i in
      assert (Var_state.is_basic v);
      assert (v.basic_idx = i);
      assert Q.(Matrix.get self.matrix v.basic_idx v.idx = minus_one);

      (* basic vars are only defined in terms of non-basic vars *)
      Vec.iteri
        (fun j v_j ->
           if Var_state.(v != v_j) && Q.(Matrix.get self.matrix i j <> zero)
           then (
             assert (Var_state.is_n_basic v_j)
           ))
        self.vars;

      (* sum of each row must be 0 *)
      let sum =
        Vec.fold
          (fun sum v -> Erat.(sum + Matrix.get self.matrix i v.idx * v.value))
          Erat.zero self.vars
      in
      Log.debugf 50 (fun k->k "row %d: sum %a" i Erat.pp sum);
      assert Erat.(sum = zero);
    done;
    ()

  (* for internal checking *)
  let[@inline] _check_invariants_internal self =
    if false (* FUDGE *) then _check_invariants self

  let[@inline] has_var_ (self:t) x : bool = V_map.mem x self.var_tbl

  (* state of variable [x]; fails with [Error.errorf] if [x] is unknown *)
  let[@inline] find_var_ (self:t) x : var_state =
    try V_map.find x self.var_tbl
    with Not_found ->
      (* the format must contain [%a], otherwise [Var.pp x] is silently
         ignored and the variable is never printed *)
      Error.errorf "variable %a is not in the simplex" Var.pp x

  (* value of the bound [b], or [v] if there is no bound *)
  let[@inline] get_bnd_or v (b:bound option) =
    match b with None -> v | Some b -> b.b_val

  (* update implied bounds for basic variable [x_i] by looking at each
     non-basic variable's bounds *)
  let update_implied_bounds_ (self:t) (x_i:var_state) : unit =
    assert (Var_state.is_basic x_i);
    let m = self.matrix in

    let i_low = ref Erat.zero in
    let i_up = ref Erat.zero in

    for j=0 to Matrix.n_cols m-1 do
      if j <> x_i.idx then (
        let a_ij: Q.t = Matrix.get m x_i.basic_idx j in
        let x_j = Vec.get self.vars j in
        let low_j = get_bnd_or Erat.minus_inf x_j.l_bound in
        let up_j = get_bnd_or Erat.plus_inf x_j.u_bound in
        if Q.(a_ij = zero) then ()
        else if Q.(a_ij > zero) then (
          i_low := Erat.(!i_low + a_ij * low_j);
          i_up := Erat.(!i_up + a_ij * up_j);
        ) else (
          (* [a_ij < 0] and [x_j < up],
             means [-a_ij > 0] and [-x_j > -up]
             means [x_i = rest + a_ij x_j > rest + (-a_ij) * (-up)] *)
          i_low := Erat.(!i_low + a_ij * up_j);
          i_up := Erat.(!i_up + a_ij * low_j);
        )
      )
    done;
    (* the previous implied bounds are restored upon backtracking *)
    let old_i_low = x_i.l_implied in
    let old_i_up = x_i.u_implied in
    Backtrack_stack.push self.undo_stack
      (fun () -> x_i.l_implied <- old_i_low; x_i.u_implied <- old_i_up);
    x_i.l_implied <- !i_low;
    x_i.u_implied <- !i_up;
    Log.debugf 50
      (fun k->k"(@[lra.implied-bounds@ :var %a@ :lower %a@ :upper %a@])"
          Var_state.pp x_i Erat.pp !i_low Erat.pp !i_up);
    ()

  (* define [x] as a basic variable *)
  let define (self:t) (x:V.t) (le:_ list) : unit =
    assert (not (has_var_ self x));
    Stat.incr self.stat_define;
    (* Log.debugf 50 (fun k->k "define-in: %a" pp self); *)
    let le = List.map (fun (q,v) -> q, find_var_ self v) le in

    (* initial value for the new variable *)
    let value =
      List.fold_left
        (fun sum (c,v) -> Erat.(sum + c * v.value)) Erat.zero le
    in

    let vs =
      Matrix.add_row_and_column self.matrix
        ~f:(fun ~row_idx ~col_idx -> {
              var=x; value;
              idx=col_idx;
              basic_idx=row_idx;
              l_bound=None;
              u_bound=None;
              l_implied=Erat.minus_inf;
              u_implied=Erat.plus_inf;
              all_bound_lits=[];
            })
    in
    Log.debugf 5
      (fun k->k "(@[simplex.define@ @[v%d :var %a@]@ :linexpr %a@])"
          vs.idx Var.pp x Fmt.(Dump.(list @@ pair pp_q_dbg Var_state.pp_name)) le);

    assert (Var_state.is_basic vs);
    assert Var_state.(Matrix.get_row_var self.matrix vs.basic_idx == vs);
    Vec.push self.vars vs;
    self.var_tbl <- V_map.add x vs self.var_tbl;

    (* set coefficients in the matrix's new row: [-x + le = 0] *)
    Matrix.set self.matrix vs.basic_idx vs.idx Q.minus_one;
    List.iter
      (fun (coeff,v2) ->
         assert Containers.Stdlib.(v2 != vs);

         if Var_state.is_basic v2 then (
           (* [v2] is also basic, so instead of depending on it,
              we depend on its definition in terms of non-basic variables. *)
           for k=0 to Matrix.n_cols self.matrix - 1 do
             if k <> v2.idx then (
               let v2_jk = Matrix.get self.matrix v2.basic_idx k in
               if Q.(v2_jk <> zero) then (
                 let v_k = Vec.get self.vars k in
                 assert (Var_state.is_n_basic v_k);

                 (* [v2 := v2_jk * v_k + …], so [v := … + coeff * v2_jk * v_k] *)
                 Matrix.add self.matrix vs.basic_idx k Q.(coeff * v2_jk);
               );
             );
           done;
         ) else (
           (* directly add coefficient with non-basic var [v2] *)
           Matrix.add self.matrix vs.basic_idx v2.idx coeff;
         );
      ) le;

    update_implied_bounds_ self vs;

    (* Log.debugf 50 (fun k->k "post-define: %a" pp self); *)
    _check_invariants_internal self;
    ()

  (* find the state for [x], or add [x] as a non-basic variable *)
  let find_or_create_n_basic_var_ (self:t) (x:V.t) : var_state =
    try V_map.find x self.var_tbl
    with Not_found ->
      Matrix.add_column self.matrix;
      let vs = {
        idx=Vec.size self.vars;
        basic_idx= -1;
        var=x;
        l_bound=None;
        u_bound=None;
        l_implied=Erat.minus_inf;
        u_implied=Erat.plus_inf;
        value=Erat.zero;
        all_bound_lits=[];
      } in
      assert (Var_state.is_n_basic vs);
      self.var_tbl <- V_map.add x vs self.var_tbl;
      Vec.push self.vars vs;
      vs

  (* update the simplex so that non-basic [x] is now assigned value [n].
     See page 14, figure 3.1.
  *)
  let update_n_basic (self:t) (x:var_state) (v:erat) : unit =
    assert (Var_state.is_n_basic x);
    Log.debugf 50
      (fun k->k "(@[simplex.update-n-basic@ %a@ :new-val %a@])"
          Var_state.pp x Erat.pp v);
    _check_invariants_internal self;

    let m = self.matrix in
    let i = x.idx in

    let diff = Erat.(v - x.value) in

    for j=0 to Matrix.n_rows m - 1 do
      let a_ji = Matrix.get m j i in
      let x_j = Matrix.get_row_var m j in
      assert (Var_state.is_basic x_j);
      (* value of [x_j] changes by [a_ji * diff] *)
      let new_val = Erat.(x_j.value + a_ji * diff) in
      (* Log.debugf 50 (fun k->k "new-val %a@ := %a" Var_state.pp x_j Erat.pp new_val); *)
      x_j.value <- new_val;
    done;
    x.value <- v;
    _check_invariants_internal self;
    ()

  (* pivot [x_i] (basic) and [x_j] (non-basic), setting value of [x_i]
     to [v] at the same time.
     See page 14, figure 3.1 *)
  let pivot_and_update (self:t) (x_i:var_state) (x_j:var_state) (v:erat) : unit =
    assert (Var_state.is_basic x_i);
    assert (Var_state.is_n_basic x_j);
    _check_invariants_internal self;

    let m = self.matrix in
    let a_ij = Matrix.get m x_i.basic_idx x_j.idx in
    assert (Q.sign a_ij <> 0);
    (* [theta] is how much [x_j] must move so that [x_i] reaches [v] *)
    let theta = Erat.((v - x_i.value) / a_ij) in
    x_i.value <- v;
    x_j.value <- Erat.(x_j.value + theta);

    (* the other basic variables that depend on [x_j] move accordingly *)
    for k=0 to Matrix.n_rows m-1 do
      if k <> x_i.basic_idx then (
        let x_k = Matrix.get_row_var m k in
        let a_kj = Matrix.get m x_k.basic_idx x_j.idx in
        x_k.value <- Erat.(x_k.value + a_kj * theta);
      )
    done;

    begin
      (* now pivot the variables so that [x_j]'s coeff is -1 and so that
         other basic variables only depend on non-basic variables. *)
      let new_coeff = Q.(minus_one / a_ij) in
      for k=0 to Matrix.n_cols m-1 do
        Matrix.mult m x_i.basic_idx k new_coeff; (* update row of [x_i] *)
      done;
      assert Q.(Matrix.get m x_i.basic_idx x_j.idx = minus_one);

      (* make [x_i] non basic, and [x_j] basic *)
      x_j.basic_idx <- x_i.basic_idx;
      x_i.basic_idx <- -1;
      x_i.l_implied <- Erat.minus_inf;
      x_i.u_implied <- Erat.plus_inf;
      Matrix.set_row_var m x_j.basic_idx x_j;

      (* adjust other rows so they don't depend on [x_j] *)
      for k=0 to Matrix.n_rows m-1 do
        if k <> x_j.basic_idx then (
          let x_k = Matrix.get_row_var m k in
          assert (Var_state.is_basic x_k);

          let c_kj = Matrix.get m k x_j.idx in
          if Q.(c_kj <> zero) then (
            (* [m[k,j] != 0] indicates that basic variable [x_k] depends on
               [x_j], which is about to become basic. To avoid basic-basic
               dependency we replace [x_j] by its (new) definition *)

            for l=0 to Matrix.n_cols m-1 do
              if l<>x_j.idx then (
                let c_jl = Matrix.get m x_j.basic_idx l in
                (* so:
                   [x_k := c_kj * x_j + …], we want to eliminate [x_j],
                   and [x_j := … + c_jl * x_l + …].
                   therefore [x_k := … + c_kj * c_jl * x_l] *)
                Matrix.add m k l Q.(c_kj * c_jl);
              )
            done;

            Matrix.set m k x_j.idx Q.zero; (* [x_k] doesn't use [x_j] anymore *)
          )
        )
      done;

      update_implied_bounds_ self x_j;
    end;

    assert (Var_state.is_basic x_j);
    assert (Var_state.is_n_basic x_i);
    (* Log.debugf 50 (fun k->k "post pivot: %a" pp self); *)
    _check_invariants_internal self;

    ()

  (** Certificate of unsatisfiability: either two contradictory bounds on
      the same variable, or a basic variable whose bound contradicts the
      bounds of the variables in its definition. *)
  type unsat_cert =
    | E_bounds of {
        x: var_state;
        lower: bound;
        upper: bound;
      }
    | E_unsat_basic of {
        x: var_state;
        x_bound: (Op.t * bound);
        le: (num * V.t) list; (* definition of the basic var *)
        bounds: (Op.t * bound) V_map.t; (* bound for each variable in [le] *)
      }

  module Unsat_cert = struct
    type t = unsat_cert

    (* literals of all the bounds involved in the certificate *)
    let lits = function
      | E_bounds b -> [b.lower.b_lit; b.upper.b_lit]
      | E_unsat_basic {x_bound=(_,x_bnd); bounds; x=_; le=_;} ->
        V_map.fold (fun _ (_,bnd) l -> bnd.b_lit :: l) bounds [x_bnd.b_lit]

    let pp out (self:t) =
      match self with
      | E_bounds {x;lower;upper} ->
        Fmt.fprintf out "(@[cert.unsat-bounds@ %a@ :lower %a@ :upper %a@])"
          Var_state.pp x Erat.pp lower.b_val Erat.pp upper.b_val
      | E_unsat_basic {x; x_bound; le; bounds} ->
        let pp_bnd out (v,(op,bnd)) =
          Fmt.fprintf out "(@[%a %s %a@])" Var.pp v (Op.to_string op) Erat.pp bnd.b_val
        in
        Fmt.fprintf out
          "(@[cert.unsat-basic@ %a@ @[:bound %a@] @[:le %a@]@ @[:le-bounds@ %a@]@])"
          Var_state.pp x pp_bnd (x.var,x_bound)
          Fmt.(Dump.list pp_bnd) (V_map.to_list bounds)
          Fmt.(Dump.list (Dump.pair pp_q_dbg V.pp)) le

    let bounds x ~lower ~upper : t = E_bounds {x; lower; upper}
    let unsat_basic x x_bound le bounds : t =
      E_unsat_basic {x; x_bound; le; bounds}
  end

  exception E_unsat of Unsat_cert.t

  type ev_on_propagate = V.lit -> reason:V.lit list -> unit

  let add_var self (v:V.t) : unit =
    ignore (find_or_create_n_basic_var_ self v : var_state);
    ()

  (* gather all relevant lower (resp. upper) bounds for the definition
     of the basic variable [x_i], and collect each relevant variable
     with [map] into a list.
     @param get_lower if true, means we select lower bounds of variables
     with positive coeffs, and upper bounds of variables with negative coeffs. *)
  let gather_bounds_of_row_ (self:t) (x_i:var_state) ~map ~get_lower : _ list * _ =
    assert (Var_state.is_basic x_i);
    let map_res = ref [] in
    let bounds = ref V_map.empty in
    Vec.iteri
      (fun j x_j ->
         if j <> x_i.idx then (
           let c = Matrix.get self.matrix x_i.basic_idx j in
           if Q.(c <> zero) then (
             match get_lower, Q.(c > zero) with
             | true, true
             | false, false ->
               begin match x_j.l_bound with
                 | Some l ->
                   map_res := (map c x_j l) :: !map_res;
                   let op = if Q.(l.b_val.eps_factor <= zero) then Op.Geq else Op.Gt in
                   bounds := V_map.add x_j.var (op, l) !bounds
                 | None -> assert false (* we could decrease [x_j]?! *)
               end
             | true, false
             | false, true ->
               begin match x_j.u_bound with
                 | Some u ->
                   map_res := (map c x_j u) :: !map_res;
                   let op = if Q.(u.b_val.eps_factor >= zero) then Op.Leq else Op.Lt in
                   bounds := V_map.add x_j.var (op, u) !bounds
                 | None -> assert false (* we could increase [x_j]?! *)
               end
           )
         ))
      self.vars;
    !map_res, !bounds

  (* do propagations for basic var [x_i] based on its implied bounds *)
  let propagate_basic_implied_bounds (self:t) ~on_propagate (x_i:var_state) : unit =
    assert (Var_state.is_basic x_i);

    (* literals of the bounds that justify the implied lower (resp. upper)
       bound of [x_i] *)
    let lits_of_row_ ~get_lower : V.lit list =
      let l, _ =
        gather_bounds_of_row_ self x_i ~get_lower
          ~map:(fun _ _ bnd -> bnd.b_lit) in
      l
    in

    let process_bound_lit (is_lower, bnd): unit =
      if is_lower then (
        if Erat.(bnd.b_val < x_i.l_implied) then (
          (* implied lower bound subsumes this lower bound *)
          let reason = lits_of_row_ ~get_lower:true in
          on_propagate bnd.b_lit ~reason
        );
        if Erat.(bnd.b_val > x_i.u_implied) then (
          (* lower bound is higher than implied upper bound *)
          match V.not_lit bnd.b_lit with
          | Some not_lit ->
            Log.debugf 50
              (fun k->k"(@[lra.propagate.not@ :lower-bnd-of %a@ :bnd %a :lit %a@ :u-implied %a@])"
                  Var_state.pp x_i Erat.pp bnd.b_val V.pp_lit bnd.b_lit Erat.pp x_i.u_implied);
            let reason = lits_of_row_ ~get_lower:false in
            on_propagate not_lit ~reason
          | None -> ()
        )
      ) else (
        if Erat.(bnd.b_val > x_i.u_implied) then (
          (* implied upper bound subsumes this upper bound *)
          let reason = lits_of_row_ ~get_lower:false in
          on_propagate bnd.b_lit ~reason
        );
        if Erat.(bnd.b_val < x_i.l_implied) then (
          (* upper bound is lower than implied lower bound *)
          match V.not_lit bnd.b_lit with
          | Some not_lit ->
            let reason = lits_of_row_ ~get_lower:true in
            on_propagate not_lit ~reason
          | None -> ()
        )
      )
    in

    List.iter process_bound_lit x_i.all_bound_lits;
    ()

  let update_implied_bounds_for_nbasic_ (self:t) (vs:var_state) ~on_propagate : unit =
    (* section 3.2.5: update implied bounds on basic variables that depend
       on [vs] *)
    for i = 0 to Matrix.n_rows self.matrix -1 do
      if Q.(Matrix.get self.matrix i vs.idx <> zero) then (
        let x_i = Matrix.get_row_var self.matrix i in
        update_implied_bounds_ self x_i;
        propagate_basic_implied_bounds self ~on_propagate x_i;
      );
    done

  (* translate the constraint [c] into a bound on [c.lhs]: returns whether
     this is a lower bound, and the bound's value. Strict comparisons are
     turned into non-strict ones using an epsilon coefficient of +1 or -1.
     Shared by {!add_constraint} and {!declare_bound}. *)
  let bound_of_constraint_ (c:Constraint.t) : is_lower * erat =
    match c.op with
    | Op.Leq -> false, Erat.make_q c.rhs
    | Op.Lt -> false, Erat.make c.rhs Q.minus_one
    | Op.Geq -> true, Erat.make_q c.rhs
    | Op.Gt -> true, Erat.make c.rhs Q.one

  (* assert that [c] holds, justified by [lit].
     Raises [E_unsat] if the new bound directly contradicts the opposite
     bound of the same variable. *)
  let add_constraint ~on_propagate (self:t) (c:Constraint.t) (lit:V.lit) : unit =
    let vs = find_or_create_n_basic_var_ self c.lhs in
    Log.debugf 5
      (fun k->k "(@[simplex2.add-constraint@ :var %a@ :c %a@])"
          Var_state.pp_name vs Constraint.pp c);

    let is_lower_bnd, new_bnd_val = bound_of_constraint_ c in
    let new_bnd = {b_val=new_bnd_val; b_lit=lit} in
    if is_lower_bnd then (
      begin match vs.l_bound, vs.u_bound with
        | _, Some upper when Erat.(new_bnd.b_val > upper.b_val) ->
          (* [b_val <= x <= upper], but [b_val > upper] *)
          let cert = Unsat_cert.bounds vs ~lower:new_bnd ~upper in
          Stat.incr self.stat_unsat;
          raise (E_unsat cert)
        | Some lower, _ when Erat.(lower.b_val >= new_bnd.b_val) ->
          () (* subsumed by existing constraint, do nothing *)
        | _ ->
          (* save current bound *)
          Backtrack_stack.push self.bound_stack (vs, `Lower, vs.l_bound);
          vs.l_bound <- Some new_bnd;

          (* propagate *)
          List.iter
            (fun (is_lower', bnd) ->
               if is_lower' && Erat.(bnd.b_val < new_bnd.b_val) then (
                 (* subsumed *)
                 Stat.incr self.stat_propagate;
                 on_propagate (bnd.b_lit) ~reason:[new_bnd.b_lit];
               ) else if not is_lower' && Erat.(bnd.b_val < new_bnd.b_val) then (
                 (* incompatible upper bound *)
                 match V.not_lit bnd.b_lit with
                 | Some l' ->
                   Stat.incr self.stat_propagate;
                   on_propagate l' ~reason:[new_bnd.b_lit];
                 | None -> ()
               )
            ) vs.all_bound_lits;

          if Var_state.is_n_basic vs &&
             Erat.(vs.value < new_bnd.b_val) then (
            (* line 5: need to update non-basic variable *)
            update_n_basic self vs new_bnd.b_val;
            update_implied_bounds_for_nbasic_ self vs ~on_propagate;
          )
      end
    ) else (
      begin match vs.l_bound, vs.u_bound with
        | Some lower, _ when Erat.(new_bnd.b_val < lower.b_val) ->
          (* [lower <= x <= b_val], but [b_val < lower] *)
          let cert = Unsat_cert.bounds vs ~lower ~upper:new_bnd in
          Stat.incr self.stat_unsat;
          raise (E_unsat cert)
        | _, Some upper when Erat.(upper.b_val <= new_bnd.b_val) ->
          () (* subsumed *)
        | _ ->
          (* save current bound *)
          Backtrack_stack.push self.bound_stack (vs, `Upper, vs.u_bound);
          vs.u_bound <- Some new_bnd;

          (* propagate *)
          List.iter
            (fun (is_lower', bnd) ->
               if not is_lower' && Erat.(bnd.b_val > new_bnd.b_val) then (
                 (* subsumed *)
                 Stat.incr self.stat_propagate;
                 on_propagate (bnd.b_lit) ~reason:[new_bnd.b_lit];
               ) else if is_lower' && Erat.(bnd.b_val > new_bnd.b_val) then (
                 (* incompatible lower bound *)
                 match V.not_lit bnd.b_lit with
                 | Some l' ->
                   Stat.incr self.stat_propagate;
                   on_propagate l' ~reason:[new_bnd.b_lit];
                 | None -> ()
               )
            ) vs.all_bound_lits;

          if Var_state.is_n_basic vs &&
             Erat.(vs.value > new_bnd.b_val) then (
            (* line 5: need to update non-basic variable *)
            update_n_basic self vs new_bnd.b_val;
            update_implied_bounds_for_nbasic_ self vs ~on_propagate;
          )
      end
    )

  (* remember that [lit] stands for the bound [c], without asserting it,
     so that it can be propagated later *)
  let declare_bound (self:t) (c:Constraint.t) (lit:V.lit) : unit =
    let vs = find_or_create_n_basic_var_ self c.lhs in
    Log.debugf 10
      (fun k->k "(@[simplex2.declare-bound@ %a@ :lit %a@])"
          Constraint.pp c V.pp_lit lit);

    let is_lower_bnd, new_bnd_val = bound_of_constraint_ c in
    let new_bnd = is_lower_bnd, {b_val=new_bnd_val; b_lit=lit} in
    vs.all_bound_lits <- new_bnd :: vs.all_bound_lits

  (* try to find basic variable that doesn't respect one of its bounds *)
  let find_basic_var_ (self:t) : (var_state * [`Lower | `Upper] * bound) option =
    let n = Matrix.n_rows self.matrix in
    let rec aux i =
      if i >= n then None
      else (
        let x_i = Matrix.get_row_var self.matrix i in
        let v_i = x_i.value in
        match x_i.l_bound, x_i.u_bound with
        | Some l, _ when Erat.(l.b_val > v_i) -> Some (x_i, `Lower, l)
        | _, Some u when Erat.(u.b_val < v_i) -> Some (x_i, `Upper, u)
        | _ -> (aux[@tailcall]) (i+1)
      )
    in
    aux 0

  (* first non-basic variable (by increasing index) that satisfies [f] *)
  let find_n_basic_var_ (self:t) ~f : var_state option =
    let rec aux j =
      if j >= Vec.size self.vars then None
      else (
        let x_j = Vec.get self.vars j in
        if Var_state.is_n_basic x_j && f x_j then Some x_j
        else aux (j+1)
      )
    in
    aux 0

  (* true if [x.value < x.u_bound] *)
  let has_upper_slack (x:var_state) : bool =
    match x.u_bound with
    | None -> true
    | Some bnd -> Erat.(x.value < bnd.b_val)

  (* true if [x.value > x.l_bound] *)
  let has_lower_slack (x:var_state) : bool =
    match x.l_bound with
    | None -> true
    | Some bnd -> Erat.(x.value > bnd.b_val)

  (* make a certificate from the row of basic variable [x_i] which is outside
     one of its bound, and whose row is tight on all non-basic variables.
     @param is_lower is true if the lower bound is not respected
     (i.e. [x_i] is too small) *)
  let cert_of_row_ (self:t) (x_i:var_state) (bnd:bound) ~is_lower : unsat_cert =
    Log.debugf 50 (fun k->k "(@[simplex.cert-of-row[lower: %B]@ x_i=%a@])"
                      is_lower Var_state.pp x_i);
    assert (Var_state.is_basic x_i);
    let le, bounds =
      (* we get explanations for an (implied) upper bound
         if [bnd] is a lower bound, and conversely *)
      gather_bounds_of_row_ self x_i ~get_lower:(not is_lower)
        ~map:(fun c v _ -> c, v.var)
    in

    let op =
      if is_lower
      then if Q.(bnd.b_val.eps_factor <= zero) then Op.Geq else Op.Gt
      else if Q.(bnd.b_val.eps_factor >= zero) then Op.Leq else Op.Lt
    in

    let cert = Unsat_cert.unsat_basic x_i (op, bnd) le bounds in
    cert

  (* main satisfiability check.
     page 15, figure 3.2.
     Note that [on_propagate] is currently not used by the check. *)
  let check_exn ~on_propagate:_ (self:t) : unit =
    let exception Stop in
    Log.debugf 20 (fun k->k "(@[simplex2.check@ %a@])" pp_stats self);
    Stat.incr self.stat_check;

    let m = self.matrix in
    try
      while true do
        _check_invariants_internal self;
        (* Log.debugf 50 (fun k->k "(@[simplex2.check.iter@ %a@])" pp self); *)

        (* basic variable that doesn't respect its bound *)
        let x_i, is_lower, bnd = match find_basic_var_ self with
          | Some (x, `Lower, bnd) -> x, true, bnd
          | Some (x, `Upper, bnd) -> x, false, bnd
          | None -> raise_notrace Stop (* line 4: we are done *)
        in

        if is_lower then (
          (* line 5 *)
          let x_j =
            match
              find_n_basic_var_ self
                ~f:(fun x_j ->
                    let a_ij = Matrix.get m x_i.basic_idx x_j.idx in
                    (Q.sign a_ij > 0 && has_upper_slack x_j) ||
                    (Q.sign a_ij < 0 && has_lower_slack x_j))
            with
            | Some x -> x
            | None ->
              (* no variable can be moved to fix [x_i]: conflict *)
              let cert = cert_of_row_ self x_i bnd ~is_lower:true in
              Stat.incr self.stat_unsat;
              raise (E_unsat cert)
          in
          assert (Var_state.is_n_basic x_j);

          (* line 9 *)
          Log.debugf 50
            (fun k->k "(@[simplex2.pivot@ :basic %a@ :n-basic %a@])"
                Var_state.pp x_i Var_state.pp x_j);
          pivot_and_update self x_i x_j bnd.b_val
        ) else (
          (* line 10 *)
          let x_j =
            match
              find_n_basic_var_ self
                ~f:(fun x_j ->
                    let a_ij = Matrix.get m x_i.basic_idx x_j.idx in
                    (Q.sign a_ij < 0 && has_upper_slack x_j) ||
                    (Q.sign a_ij > 0 && has_lower_slack x_j))
            with
            | Some x -> x
            | None ->
              (* no variable can be moved to fix [x_i]: conflict *)
              let cert = cert_of_row_ self x_i bnd ~is_lower:false in
              Stat.incr self.stat_unsat;
              raise (E_unsat cert)
          in
          assert (Var_state.is_n_basic x_j);

          (* line 14 *)
          Log.debugf 50
            (fun k->k "(@[simplex2.pivot@ :basic %a@ :n-basic %a@])"
                Var_state.pp x_i Var_state.pp x_j);
          pivot_and_update self x_i x_j bnd.b_val
        )
      done;
    with Stop ->
      Log.debugf 50 (fun k->k "(@[simplex2.check.done@])");
      ()

  let create ?(stat=Stat.global) () : t =
    let self = {
      matrix=Matrix.create();
      vars=Vec.create();
      var_tbl=V_map.empty;
      bound_stack=Backtrack_stack.create();
      undo_stack=Backtrack_stack.create();
      stat_check=Stat.mk_int stat "simplex.check";
      stat_unsat=Stat.mk_int stat "simplex.conflicts";
      stat_define=Stat.mk_int stat "simplex.define";
      stat_propagate=Stat.mk_int stat "simplex.propagate";
    } in
    self

  type result =
    | Sat of Subst.t
    | Unsat of unsat_cert

  (* starting value for epsilon: 1/1024 *)
  let default_eps =
    let denom = 1 lsl 10 in
    Q.(one / of_int denom)

  (* Find an epsilon that is small enough for finding a solution, yet
     it must be positive.

     {!Erat.t} values are used to turn strict bounds ([X > 0]) into
     non-strict bounds ([X >= 0 + ε]), because the simplex algorithm
     only deals with non-strict bounds.
     When a solution is found, we need to turn {!Erat.t} into {!Q.t} by
     finding a rational value that is small enough that it will fit into
     all the intervals of [t]. This rational will be the actual value of [ε].
  *)
  let solve_epsilon (self:t) : Q.t =
    let eps =
      Iter.fold
        (fun eps x ->
           assert Q.(eps >= zero);
           assert (Var_state.in_bounds_ x);

           let x_val = x.value in

           (*Log.debugf 50 (fun k->k "(@[solve-eps v.base=%a, v.eps=%a, emax=%a@])"
                             pp_q_dbg x_val.base pp_q_dbg x_val.eps_factor
                             pp_q_dbg eps);*)

           (* is lower bound *)
           let eps = match x.l_bound with
             | Some {b_val;_}
               when Q.(Erat.evaluate eps b_val > Erat.evaluate eps x_val) ->
               assert (Erat.(x.value >= b_val));
               assert (Q.(b_val.eps_factor > x.value.eps_factor));
               (* current epsilon is too big. we need to make it smaller
                  than [x.value - b_val].
                  - [b_val.base + eps * b_val.factor
                      <= x.base + eps * x.factor]
                  - [eps * (b_val.factor - x.factor) <= x.base - b_val.base]
                  - [eps <= (x.base - b_val.base) / (b_val.factor - x.factor)]
               *)
               let new_eps =
                 Q.((x_val.base - b_val.base) /
                    (b_val.eps_factor - x_val.eps_factor))
               in
               Q.min eps new_eps
             | _ -> eps
           in
           (* upper bound *)
           let eps = match x.u_bound with
             | Some { b_val; _}
               when Q.(Erat.evaluate eps b_val < Erat.evaluate eps x_val) ->
               assert (Erat.(x.value <= b_val));
               (* current epsilon is too big. we need to make it smaller
                  than [b_val - x.value].
                  - [x.base + eps * x.factor
                      <= b_val.base + eps * b_val.factor]
                  - [eps * (x.factor - b_val.factor) <= b_val.base - x.base]
                  - [eps <= (b_val.base - x.base) / (x.factor - b_val.factor)]
               *)
               let new_eps =
                 Q.((b_val.base - x_val.base) /
                    (x_val.eps_factor - b_val.eps_factor))
               in
               Q.min eps new_eps
             | _ -> eps
           in
           eps)
        default_eps (Vec.to_seq self.vars)
    in
    (* never return more than 1 *)
    if Q.(eps >= one) then Q.one else eps

  (* build a rational model from the current assignment, by picking a
     concrete value for epsilon *)
  let model_ self =
    let eps = solve_epsilon self in
    Log.debugf 50 (fun k->k "(@[simplex2.model@ :epsilon-val %a@])" pp_q_dbg eps);
    let subst =
      Vec.to_seq self.vars
      |> Iter.fold
        (fun subst x ->
           let {base;eps_factor} = x.value in
           let v = Q.(base + eps * eps_factor) in
           V_map.add x.var v subst)
        V_map.empty
    in
    Log.debugf 10 (fun k->k "(@[simplex2.model@ %a@])" Subst.pp subst);
    subst

  let check ~on_propagate (self:t) : result =
    try
      check_exn ~on_propagate self;
      let m = model_ self in
      Sat m
    with E_unsat c -> Unsat c

  (* validate an unsat certificate; fails with [Error.errorf] if the
     certificate does not actually describe a contradiction *)
  let _check_cert (cert:unsat_cert) : unit =
    match cert with
    | E_bounds {x=_; lower; upper} ->
      (* unsat means [lower > upper] *)
      if not Erat.(lower.b_val > upper.b_val) then (
        Error.errorf "invalid simplex cert: %a" Unsat_cert.pp cert
      );
    | E_unsat_basic { x=_; x_bound; le; bounds } ->
      (* sum of [bounds], weighted by coefficient from [le] *)
      let is_lower, x_b = match x_bound with
        | (Leq | Lt), b -> false, b.b_val
        | (Geq|Gt), b -> true, b.b_val
      in
      let sum =
        List.fold_left
          (fun sum (c, x) ->
             match V_map.find x bounds with
             | exception Not_found ->
               Error.errorf "invalid simplex cert:@ %a@ var %a has no bound"
                 Unsat_cert.pp cert Var.pp x
             | Op.(Geq | Gt), _ when CCBool.equal is_lower Q.(c > zero) ->
               Error.errorf
                 "invalid simplex cert:@ %a@ variable %a has coeff of the wrong sign %a"
                 Unsat_cert.pp cert Var.pp x pp_q_dbg c
             | Op.(Lt | Leq), _ when CCBool.equal is_lower Q.(c < zero) ->
               Error.errorf
                 "invalid simplex cert:@ %a@ variable %a has coeff of the wrong sign %a"
                 Unsat_cert.pp cert Var.pp x pp_q_dbg c
             | _, b -> Erat.(sum + c * b.b_val))
          Erat.zero le
      in
      (* unsat if lower bound [x_b] is > [sum], which is an upper bound *)
      let ok =
        if is_lower then Erat.(x_b > sum)
        else Erat.(x_b < sum)
      in
      if not ok then (
        Error.errorf "invalid simplex cert:@ %a@ sum of linexpr is %a"
          Unsat_cert.pp cert Erat.pp sum
      )
end