From f226c6b82035b97dec0ef0ad271351b5b1ff77a7 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Fri, 12 Feb 2021 19:42:16 -0500 Subject: [PATCH] fix(lra): many fixes in simplex; some fixme/todo --- src/arith/lra/simplex2.ml | 184 +++++++++++++++++++++++++++++--------- 1 file changed, 144 insertions(+), 40 deletions(-) diff --git a/src/arith/lra/simplex2.ml b/src/arith/lra/simplex2.ml index b35df7b1..626bc5d7 100644 --- a/src/arith/lra/simplex2.ml +++ b/src/arith/lra/simplex2.ml @@ -5,6 +5,8 @@ de Moura and Dutertre. *) +open CCMonomorphic + module type VAR = Linear_expr_intf.VAR (** {2 Basic operator} *) @@ -79,6 +81,9 @@ module type S = sig exception E_unsat of Unsat_cert.t + val add_var : t -> V.t -> unit + (** Make sure the variable exists in the simplex. *) + val add_constraint : t -> Constraint.t -> V.lit -> unit (** Add a constraint to the simplex. @raise Unsat if it's immediately obvious that this is not satisfiable. *) @@ -93,6 +98,11 @@ module type S = sig val check : t -> result (** Call {!check_exn} and return a model or a proof of unsat. *) + + (**/**) + val _check_invariants : t -> unit + (* check internal invariants *) + (**/**) end (* TODO(optim): page 14, paragraph 2: we could detect which variables occur in no @@ -108,6 +118,9 @@ module Make(Var: VAR) type num = Q.t (** Numbers *) + let pp_q_float n out q = Fmt.fprintf out "%*.1f" n (Q.to_float q) + let pp_q_dbg = pp_q_float 1 + module Constraint = struct type op = Op.t @@ -120,7 +133,7 @@ module Make(Var: VAR) let pp out (self:t) = Fmt.fprintf out "(@[%a %s@ %a@])" V.pp self.lhs - (Op.to_string self.op) Q.pp_print self.rhs + (Op.to_string self.op) pp_q_dbg self.rhs let leq lhs rhs = {lhs;rhs;op=Op.Leq} let lt lhs rhs = {lhs;rhs;op=Op.Lt} @@ -132,7 +145,7 @@ module Make(Var: VAR) type t = num V_map.t let pp out (self:t) : unit = let pp_pair out (v,n) = - Fmt.fprintf out "(@[%a := %a@])" V.pp v Q.pp_print n in + Fmt.fprintf out "(@[%a := %a@])" V.pp v pp_q_dbg n in Fmt.fprintf out "{@[%a@]}" (Fmt.iter pp_pair) (V_map.to_iter self) let to_string = Fmt.to_string pp end @@ -177,10 +190,10 @@ module Make(Var: VAR) let pp out e : unit = if Q.equal Q.zero (eps_factor e) - then Q.pp_print out (base e) + then pp_q_dbg out (base e) else Fmt.fprintf out "(@[%a + @<1>ε * %a@])" - Q.pp_print (base e) Q.pp_print (eps_factor e) + pp_q_dbg (base e) pp_q_dbg (eps_factor e) end type var_idx = int @@ -230,33 +243,36 @@ module Make(Var: VAR) cols: num Vec.t; } type t = { + mutable n_cols: int; rows: row Vec.t } - let create() : t = {rows=Vec.create()} + let create() : t = + {n_cols=0; rows=Vec.create()} let[@inline] n_rows self = Vec.size self.rows - let n_cols self = - if n_rows self=0 then 0 - else Vec.size (Vec.get self.rows 0).cols + let[@inline] n_cols self = self.n_cols let pp out self = - Fmt.fprintf out "{@["; + Fmt.fprintf out "{@[matrix[%dx%d]@," (n_rows self) (n_cols self); Vec.iteri (fun i row -> - Fmt.fprintf out "@[%-5d: %a@]@," i - (Fmt.iter ~sep:(Fmt.return "@ ") Q.pp_print) (Vec.to_seq row.cols)) + Fmt.fprintf out "{@[r%-3d: %a@]}@," i + (Fmt.iter ~sep:(Fmt.return "@ ") (pp_q_float 6)) (Vec.to_seq row.cols)) self.rows; Fmt.fprintf out "@]}" let to_string = Fmt.to_string pp let add_column self = + self.n_cols <- 1 + self.n_cols; Vec.iter (fun r -> Vec.push r.cols Q.zero) self.rows let add_row_and_column self : int = let n = n_rows self in let j = n_cols self in add_column self; - let row = {var_idx=j; cols=Vec.make (j+1) Q.zero} in + let cols = Vec.make (j+1) Q.zero in + for _k=0 to j do Vec.push cols Q.zero done; + let row = {var_idx=j; cols} in Vec.push self.rows row; n @@ -302,11 +318,25 @@ module Make(Var: VAR) let[@inline] is_basic (self:t) : bool = self.basic_idx >= 0 let[@inline] is_n_basic (self:t) : bool = self.basic_idx < 0 + + let in_bounds_ self = + (match self.l_bound with None -> true | Some b -> Erat.(self.value >= b.b_val)) && + (match self.u_bound with None -> true | Some b -> Erat.(self.value <= b.b_val)) + let pp out self = - Fmt.fprintf out "(@[var %a@ :basic %B@ :value %a@ :lbound %a@ :ubound %a@])" - Var.pp self.var (is_basic self) Erat.pp self.value - Fmt.(Dump.option (map (fun b->b.b_val) Erat.pp)) self.l_bound - Fmt.(Dump.option (map (fun b->b.b_val) Erat.pp)) self.u_bound + let bnd_status = if in_bounds_ self then "" else "(out of bounds)" in + let pp_bnd what out = function + | None -> () + | Some b -> Fmt.fprintf out "@ @[%s %a@]" what Erat.pp b.b_val + and pp_basic_idx out () = + if self.basic_idx < 0 then () else Fmt.int out self.basic_idx + in + Fmt.fprintf out + "(@[var[%s%a]%s %a@ :value %a%a%a@])" + (if is_basic self then "B" else "N") pp_basic_idx () + bnd_status + Var.pp self.var Erat.pp self.value + (pp_bnd ":lower") self.l_bound (pp_bnd ":upper") self.u_bound end type t = { @@ -328,7 +358,7 @@ module Make(Var: VAR) () let pp out (self:t) : unit = - Fmt.fprintf out "(@[simplex@ :vars %a@ :matrix %a@])" + Fmt.fprintf out "(@[simplex@ @[<1>:vars@ [@[%a@]]@]@ @[<1>:matrix@ %a@]@])" (Vec.pp Var_state.pp) self.vars Matrix.pp self.matrix @@ -340,15 +370,17 @@ module Make(Var: VAR) let define (self:t) (x:V.t) (le:_ list) : unit = assert (not (has_var_ self x)); Log.debugf 5 (fun k->k "(@[simplex.define@ %a@ :le %a@])" - Var.pp x Fmt.(Dump.(list @@ pair Q.pp_print Var.pp)) le); + Var.pp x Fmt.(Dump.(list @@ pair pp_q_dbg Var.pp)) le); + let idx = Vec.size self.vars in let n = Matrix.add_row_and_column self.matrix in let vs = { var=x; value=Erat.zero; - idx=Vec.size self.vars; + idx; basic_idx=n; l_bound=None; u_bound=None; } in + assert (Var_state.is_basic vs); Vec.push self.vars vs; self.var_tbl <- V_map.add x vs self.var_tbl; (* set coefficients in the matrix's new row: [-x + le = 0] *) @@ -356,7 +388,14 @@ module Make(Var: VAR) List.iter (fun (coeff,v2) -> let vs2 = find_var_ self v2 in + + (* FIXME: if [vs2] is basic, instead recurse with vs2's row. + copy coefficients of [vs2]'s row but multiplied with [coeff] . + + See t4_short *) + Matrix.add self.matrix n vs2.idx coeff; + vs.value <- Erat.(vs.value + coeff * vs2.value); (* update value of [v] *) ) le; () @@ -383,6 +422,10 @@ module Make(Var: VAR) *) let update_n_basic (self:t) (x:var_state) (v:erat) : unit = assert (Var_state.is_n_basic x); + Log.debugf 50 + (fun k->k "(@[simplex.update-n-basic@ %a@ :new-val %a@])" + Var_state.pp x Erat.pp v); + let m = self.matrix in let i = x.idx in @@ -422,7 +465,7 @@ module Make(Var: VAR) (* now pivot the variables so that [x_j]'s coeff is -1 *) let new_coeff = Q.(minus_one / a_ij) in for k=0 to Matrix.n_cols m-1 do - Matrix.mult m x_i.idx k new_coeff; + Matrix.mult m x_i.basic_idx k new_coeff; done; x_j.basic_idx <- x_i.basic_idx; x_i.basic_idx <- -1; @@ -443,11 +486,11 @@ module Make(Var: VAR) let pp out (self:t) = let pp_bnd out (n,lit) = - Fmt.fprintf out "(@[%a@ coeff %a@])" V.pp_lit lit Q.pp_print n + Fmt.fprintf out "(@[%a@ coeff %a@])" V.pp_lit lit pp_q_dbg n in Fmt.fprintf out "(@[cert@ :bounds %a@ :defs %a@])" Fmt.(Dump.list pp_bnd) self.cert_bounds - Fmt.(Dump.list (Dump.pair V.pp (Dump.list (Dump.pair Q.pp_print V.pp)))) self.cert_defs + Fmt.(Dump.list (Dump.pair V.pp (Dump.list (Dump.pair pp_q_dbg V.pp)))) self.cert_defs let mk ~defs ~bounds : t = { cert_defs=defs; cert_bounds=bounds } @@ -455,8 +498,13 @@ module Make(Var: VAR) exception E_unsat of Unsat_cert.t + let add_var self (v:V.t) : unit = + ignore (find_or_create_n_basic_var_ self v : var_state); + () + let add_constraint (self:t) (c:Constraint.t) (lit:V.lit) : unit = let open Constraint in + Log.debugf 5 (fun k->k "(@[simplex2.add-constraint@ %a@])" Constraint.pp c); let vs = find_or_create_n_basic_var_ self c.lhs in let is_lower_bnd, new_bnd_val = match c.op with @@ -561,12 +609,12 @@ module Make(Var: VAR) page 15, figure 3.2 *) let check_exn (self:t) : unit = let exception Stop in - Log.debugf 5 (fun k->k "(@[simplex2.check@ %a@])" Matrix.pp self.matrix); + Log.debugf 20 (fun k->k "(@[simplex2.check@ %a@])" pp self); let m = self.matrix in try while true do - Log.debugf 50 (fun k->k "(@[simplex2.check.iter@ %a@])" Matrix.pp self.matrix); + Log.debugf 50 (fun k->k "(@[simplex2.check.iter@ %a@])" pp self); (* basic variable that doesn't respect its bound *) let x_i, is_lower, bnd = match find_basic_var_ self with @@ -615,7 +663,9 @@ module Make(Var: VAR) pivot_and_update self x_i x_j bnd.b_val ) done; - with Stop -> () + with Stop -> + Log.debugf 50 (fun k->k "(@[simplex2.check.done@])"); + () let create () : t = let self = { @@ -630,6 +680,10 @@ module Make(Var: VAR) | Sat of Subst.t | Unsat of unsat_cert + let default_eps = + let denom = 1 lsl 10 in + Q.(one / of_int denom) + (* Find an epsilon that is small enough for finding a solution, yet it must be positive. @@ -641,29 +695,65 @@ module Make(Var: VAR) all the intervals of [t]. This rational will be the actual value of [ε]. *) let solve_epsilon (self:t) : Q.t = - let emax = + let eps = Iter.fold - (fun emax x -> - let {base=v; eps_factor=e_v} = x.value in - (* lower bound *) - let emax = match x.l_bound with - | Some {b_val={base=low;eps_factor=e_low};_} when Q.(e_v < e_low) -> - Q.min emax Q.((low - v) / (e_v - e_low)) - | _ -> emax + (fun eps x -> + assert Q.(eps >= zero); + assert (Var_state.in_bounds_ x); + + let x_val = x.value in + Log.debugf 50 (fun k->k "v.base=%a, v.eps=%a, emax=%a" + pp_q_dbg x_val.base pp_q_dbg x_val.eps_factor + pp_q_dbg eps); + + (* is lower bound *) + let eps = match x.l_bound with + | Some {b_val;_} + when Q.(Erat.evaluate eps b_val > Erat.evaluate eps x_val) -> + assert (Erat.(x.value >= b_val)); + assert (Q.(b_val.eps_factor > x.value.eps_factor)); + (* current epsilon is too big. we need to make it smaller + than [x.value - b_val]. + - [b_val.base + eps * b_val.factor + <= x.base + eps * x.factor] + - [eps * (b_val.factor - x.factor) <= x.base - b_val.base] + - [eps <= (x.base - b_val.base) / (b_val.factor - x.factor)] + *) + let new_eps = + Q.((x_val.base - b_val.base) / + (b_val.eps_factor - x_val.eps_factor)) + in + Q.min eps new_eps + | _ -> eps in (* upper bound *) - let emax = match x.u_bound with - | Some { b_val={base=upp;eps_factor=e_upp}; _} when Q.(e_v > e_upp) -> - min emax Q.((upp - v) / (e_v - e_upp)) - | _ -> emax + let eps = match x.u_bound with + | Some { b_val; _} + when Q.(Erat.evaluate eps b_val < Erat.evaluate eps x_val) -> + assert (Erat.(x.value <= b_val)); + (* current epsilon is too big. we need to make it smaller + than [b_val - x.value]. + - [x.base + eps * x.factor + <= b_val.base + eps * b_val.factor] + - [eps * (x.factor - b_val.factor) <= b_val.base - x.base] + - [eps <= (b_val.base - x.base) / (x.factor - b_val.factor)] + *) + let new_eps = + Q.((b_val.base - x_val.base) / + (x_val.eps_factor - b_val.eps_factor)) + in + Log.debugf 5 (fun k->k "new max=%.5f" @@ Q.to_float new_eps); + Q.min eps new_eps + | _ -> eps in - emax) - Q.inf (Vec.to_seq self.vars) + eps) + default_eps (Vec.to_seq self.vars) in - if Q.compare emax Q.one >= 0 then Q.one else emax + if Q.(eps >= one) then Q.one else eps let model_ self = let eps = solve_epsilon self in + Log.debugf 50 (fun k->k "(@[simplex2.model@ :epsilon-val %a@])" pp_q_dbg eps); let subst = Vec.to_seq self.vars |> Iter.fold @@ -673,6 +763,8 @@ module Make(Var: VAR) V_map.add x.var v subst) V_map.empty in + Log.debugf 5 + (fun k->k "(@[simplex2.model@ %a@])" Subst.pp subst); subst let check (self:t) : result = @@ -735,4 +827,16 @@ module Make(Var: VAR) ) else `Diff_not_0 expr_minus_x end *) + + let _check_invariants self : unit = + Vec.iteri (fun i v -> assert(v.idx = i)) self.vars; + let n = Vec.size self.vars in + assert (Matrix.n_rows self.matrix = 0 || Matrix.n_cols self.matrix = n); + for i = 0 to Matrix.n_rows self.matrix-1 do + let v = Vec.get self.vars (Matrix.get_row_var_idx self.matrix i) in + assert (Var_state.is_basic v); + assert (v.basic_idx = i); + done; + () (* TODO: more *) + end