(* copyright (c) 2014-2018, Guillaume Bury, Simon Cruanes *)

(* OPTIMS:
 * - distinguish separate systems (that do not interact), such as in
 *   { 1 <= 3x = 3y <= 2; z <= 3} ?
 * - Implement Gomory cuts ?
 *)

open Containers

module type VAR = Linear_expr_intf.VAR
module type FRESH = Linear_expr_intf.FRESH
module type VAR_GEN = Linear_expr_intf.VAR_GEN

module type S = Simplex_intf.S
module type S_FULL = Simplex_intf.S_FULL

module Vec = CCVector

(* Growable matrix, stored as a vector of rows. Every row is expected to
   have exactly [n_col] cells (see [check_invariants]). *)
module Matrix : sig
  type 'a t

  val create : unit -> 'a t
  val get : 'a t -> int -> int -> 'a
  val set : 'a t -> int -> int -> 'a -> unit
  val get_row : 'a t -> int -> 'a Vec.vector
  val copy : 'a t -> 'a t
  val n_row : _ t -> int
  val n_col : _ t -> int
  val push_row : 'a t -> 'a -> unit (* new row, filled with element *)
  val push_col : 'a t -> 'a -> unit (* new column, filled with element *)

  (**/**)
  val check_invariants : _ t -> bool
  (**/**)
end = struct
  type 'a t = {
    mutable n_col: int; (* num of columns *)
    tab: 'a Vec.vector Vec.vector;
  }

  (* empty matrix: no row, no column *)
  let[@inline] create() : _ = {tab=Vec.create(); n_col=0}

  (* [get m i j] is the cell at row [i], column [j] *)
  let[@inline] get m i j = Vec.get (Vec.get m.tab i) j

  (* [get_row m i] returns the row vector itself, not a copy *)
  let[@inline] get_row m i = Vec.get m.tab i

  let[@inline] set (m:_ t) i j x = Vec.set (Vec.get m.tab i) j x

  (* copy the record and each row, so that the copy can be mutated
     independently of the original *)
  let[@inline] copy m = {m with tab=Vec.map Vec.copy m.tab}

  let[@inline] n_row m = Vec.length m.tab
  let[@inline] n_col m = m.n_col

  (* append a row of [n_col m] cells, all equal to [x] *)
  let push_row m x = Vec.push m.tab (Vec.make (n_col m) x)

  (* bump the column count and append [x] at the end of every row *)
  let push_col m x =
    m.n_col <- m.n_col + 1;
    Vec.iter (fun row -> Vec.push row x) m.tab

  (* every row has exactly [n_col m] cells *)
  let check_invariants m =
    Vec.for_all (fun r -> Vec.length r = n_col m) m.tab
end

(* use non-polymorphic comparison ops *)
open Int.Infix

(* Simplex Implementation *)
module Make_inner
    (Var: VAR)
    (VMap : CCMap.S with type key=Var.t)
    (Param: sig type t val copy : t -> t end)
= struct
  module Var_map = VMap
  module M = Var_map

  (* Exceptions *)

  (* raised by [solve_aux] when no pivot candidate exists for the carried
     basic variable *)
  exception Unsat of Var.t

  (* raised by [check_bounds] when the lower bound of the carried variable
     is greater than its upper bound *)
  exception AbsurdBounds of Var.t

  (* raised by [find_suitable_nbasic_for_pivot] when no non-basic variable
     passes the suitability test *)
  exception NoneSuitable

  type param = Param.t
  type var = Var.t
  type lit = Var.lit

  type basic_var = var
  type nbasic_var = var

  type erat = {
    base: Q.t; (* reference number *)
    eps_factor: Q.t; (* coefficient for epsilon, the infinitesimal *)
  }

  (** Epsilon-rationals, used for strict bounds *)
  module Erat = struct
    type t = erat

    let zero : t = {base=Q.zero; eps_factor=Q.zero}

    let[@inline] make base eps_factor : t = {base; eps_factor}
    let[@inline] base t = t.base
    let[@inline] eps_factor t = t.eps_factor

    (* [mul k e] scales both components of [e] by the rational [k] *)
    let[@inline] mul k e = make Q.(k * e.base) Q.(k * e.eps_factor)

    (* component-wise sum *)
    let[@inline] sum e1 e2 =
      make Q.(e1.base + e2.base) Q.(e1.eps_factor + e2.eps_factor)

    (* lexicographic order: first on [base], then on [eps_factor] *)
    let[@inline] compare e1 e2 =
      match Q.compare e1.base e2.base with
      | 0 -> Q.compare e1.eps_factor e2.eps_factor
      | x -> x

    let lt a b = compare a b < 0
    let gt a b = compare a b > 0

    let[@inline] min x y = if compare x y <= 0 then x else y
    let[@inline] max x y = if compare x y >= 0 then x else y

    (* [evaluate epsilon e] is the rational [e.base + epsilon * e.eps_factor] *)
    let[@inline] evaluate (epsilon:Q.t) (e:t) : Q.t =
      Q.(e.base + epsilon * e.eps_factor)

    (* print only the base when the epsilon coefficient is zero *)
    let pp out e =
      if Q.equal Q.zero (eps_factor e)
      then Q.pp_print out (base e)
      else
        Format.fprintf out "(@[%a + @<1>ε * %a@])"
          Q.pp_print (base e) Q.pp_print (eps_factor e)
  end

  let str_of_var = Format.to_string Var.pp
  let str_of_erat = Format.to_string Erat.pp
  let str_of_q = Format.to_string Q.pp_print

  (* a bound value, with the optional literal that justifies it *)
  type bound = {
    value : Erat.t;
    reason : lit option;
  }

  type t = {
    param: param;
    tab : Q.t Matrix.t; (* the matrix of coefficients *)
    basic : basic_var Vec.vector; (* basic variables *)
    nbasic : nbasic_var Vec.vector; (* non basic variables *)
    mutable assign : Erat.t M.t; (* assignments *)
    mutable bounds : (bound * bound) M.t; (* (lower, upper) bounds for variables *)
    mutable idx_basic : int M.t; (* basic var -> its index in [basic] *)
    mutable idx_nbasic : int M.t; (* non basic var -> its index in [nbasic] *)
  }

  (* unsat certificate: a variable, and a linear expression for it
     (empty when the bounds of [cert_var] are themselves inconsistent,
     see [solve]) *)
  type cert = {
    cert_var: var;
    cert_expr: (Q.t * var) list;
  }

  type res =
    | Solution of Q.t Var_map.t
    | Unsatisfiable of cert

  (* fresh, empty simplex state carrying [param] *)
  let create param : t = {
    param: param;
    tab = Matrix.create ();
    basic = Vec.create ();
    nbasic = Vec.create ();
    assign = M.empty;
    bounds = M.empty;
    idx_basic = M.empty;
    idx_nbasic = M.empty;
  }

  (* copy of the state: the mutable containers (matrix, vectors) and the
     parameter are copied, the persistent maps are shared *)
  let copy t = {
    param = Param.copy t.param;
    tab = Matrix.copy t.tab;
    basic = Vec.copy t.basic;
    nbasic = Vec.copy t.nbasic;
    assign = t.assign;
    bounds = t.bounds;
    idx_nbasic = t.idx_nbasic;
    idx_basic = t.idx_basic;
  }

  (* index of [x] in [t.basic], or [-1] if [x] is not a basic variable *)
  let index_basic (t:t) (x:basic_var) : int =
    match M.find x t.idx_basic with
    | n -> n
    | exception Not_found -> -1

  (* index of [x] in [t.nbasic], or [-1] if [x] is not a non-basic variable *)
  let index_nbasic (t:t) (x:nbasic_var) : int =
    match M.find x t.idx_nbasic with
    | n -> n
    | exception Not_found -> -1

  let[@inline] mem_basic (t:t) (x:var) : bool = M.mem x t.idx_basic
  let[@inline] mem_nbasic (t:t) (x:var) : bool = M.mem x t.idx_nbasic

  (* check invariants, for test purposes *)
  let check_invariants (t:t) : bool =
    Matrix.check_invariants t.tab &&
    Vec.for_all (fun v -> mem_basic t v) t.basic &&
    Vec.for_all (fun v -> mem_nbasic t v) t.nbasic &&
    Vec.for_all (fun v -> not (mem_nbasic t v)) t.basic &&
    Vec.for_all (fun v -> not (mem_basic t v)) t.nbasic &&
    Vec.for_all (fun v -> Var_map.mem v t.assign) t.nbasic &&
    Vec.for_all (fun v -> not (Var_map.mem v t.assign)) t.basic &&
    true

  (* find the definition of the basic variable [x],
     as a linear combination of non basic variables.
     The result is the matrix row itself, not a copy. *)
  let find_expr_basic_opt t (x:var) : Q.t Vec.vector option =
    begin match index_basic t x with
      | -1 -> None
      | i -> Some (Matrix.get_row t.tab i)
    end

  (* same as [find_expr_basic_opt], but [x] must be basic
     (assertion failure otherwise) *)
  let find_expr_basic t (x:basic_var) : Q.t Vec.vector =
    begin match find_expr_basic_opt t x with
      | None -> assert false
      | Some e -> e
    end

  (* build the expression [y = \sum_i (if x_i=y then 1 else 0)·x_i] *)
  let find_expr_nbasic t (x:nbasic_var) : Q.t Vec.vector =
    Vec.map
      (fun y -> if Var.compare x y = 0 then Q.one else Q.zero)
      t.nbasic

  (* find expression of [x] *)
  let find_expr_total (t:t) (x:var) : Q.t Vec.vector =
    match find_expr_basic_opt t x with
    | Some e -> e
    | None ->
      assert (mem_nbasic t x);
      find_expr_nbasic t x

  (* compute value of basic variable.
     It can be computed by using [x]'s definition
     in terms of nbasic variables, which have values *)
  let value_basic (t:t) (x:basic_var) : Erat.t =
    assert (mem_basic t x);
    let res = ref Erat.zero in
    let expr = find_expr_basic t x in
    for i = 0 to Vec.length expr - 1 do
      let val_nbasic_i =
        try M.find (Vec.get t.nbasic i) t.assign
        with Not_found -> assert false
      in
      res := Erat.sum !res (Erat.mul (Vec.get expr i) val_nbasic_i)
    done;
    !res

  (* extract a value for [x] *)
  let[@inline] value (t:t) (x:var) : Erat.t =
    try M.find x t.assign (* nbasic variables are assigned *)
    with Not_found -> value_basic t x

  (* trivial bounds *)
  let empty_bounds : bound * bound =
    { value = Erat.make Q.minus_inf Q.zero; reason = None; },
    { value = Erat.make Q.inf Q.zero; reason = None; }

  (* find bounds of [x] *)
  let[@inline] get_bounds (t:t) (x:var) : bound * bound =
    try M.find x t.bounds
    with Not_found -> empty_bounds

  let[@inline] get_bounds_values (t:t) (x:var) : Erat.t * Erat.t =
    let l, u = get_bounds t x in
    l.value, u.value

  (* is [value x] within the bounds for [x]?
     Returns [true, value] when it is, and otherwise [false, b] where [b]
     is the bound that is violated. *)
  let is_within_bounds (t:t) (x:var) : bool * Erat.t =
    let v = value t x in
    let low, upp = get_bounds_values t x in
    if Erat.compare v low < 0 then
      false, low
    else if Erat.compare v upp > 0 then
      false, upp
    else
      true, v

  (* add nbasic variables *)
  let add_vars (t:t) (l:var list) : unit =
    (* add new variable to idx and array for nbasic, removing duplicates
       and variables already present *)
    let idx_nbasic, _, l =
      List.fold_left
        (fun ((idx_nbasic, offset, l) as acc) x ->
           if mem_basic t x then acc
           else if M.mem x idx_nbasic then acc
           else (
             (* allocate new index for [x] *)
             M.add x offset idx_nbasic, offset+1, x::l
           ))
        (t.idx_nbasic, Vec.length t.nbasic, [])
        l
    in
    (* add new columns to the matrix *)
    let old_dim = Matrix.n_col t.tab in
    List.iter (fun _ -> Matrix.push_col t.tab Q.zero) l;
    assert (old_dim + List.length l = Matrix.n_col t.tab);
    (* [l] was accumulated in reverse, restore insertion order so that
       positions in [t.nbasic] match the indices allocated above *)
    Vec.append_list t.nbasic (List.rev l);
    (* assign these variables *)
    t.assign <- List.fold_left (fun acc y -> M.add y Erat.zero acc) t.assign l;
    t.idx_nbasic <- idx_nbasic;
    ()

  (* define basic variable [x] by [eq] in [t].
     @raise Invalid_argument if [x] is already a basic or non-basic variable *)
  let add_eq (t:t) (x, eq : basic_var * _ list) : unit =
    if mem_basic t x || mem_nbasic t x then (
      invalid_arg (Format.sprintf "Variable `%a` already defined." Var.pp x);
    );
    add_vars t (List.map snd eq);
    (* add [x] as a basic var *)
    t.idx_basic <- M.add x (Vec.length t.basic) t.idx_basic;
    Vec.push t.basic x;
    (* add new row for defining [x] *)
    assert (Matrix.n_col t.tab > 0);
    Matrix.push_row t.tab Q.zero;
    let row_i = Matrix.n_row t.tab - 1 in
    assert (row_i >= 0);
    (* now put into the row the coefficients corresponding to [eq],
       expanding basic variables to their definition *)
    List.iter
      (fun (c, x) ->
         let expr = find_expr_total t x in
         assert (Vec.length expr = Matrix.n_col t.tab);
         Vec.iteri
           (fun j c' ->
              if not (Q.equal Q.zero c') then (
                Matrix.set t.tab row_i j
                  Q.(Matrix.get t.tab row_i j + c * c')
              ))
           expr)
      eq;
    ()

  (* add bounds to [x] in [t]. A new bound replaces the stored one unless
     it is strictly weaker (lower for [low], higher for [upp]). *)
  let add_bound_aux (t:t) (x:var)
      (low:Erat.t) (low_reason:lit option)
      (upp:Erat.t) (upp_reason:lit option) : unit =
    add_vars t [x];
    let l, u = get_bounds t x in
    let l' =
      if Erat.lt low l.value then l
      else { value = low; reason = low_reason; }
    in
    let u' =
      if Erat.gt upp u.value then u
      else { value = upp; reason = upp_reason; }
    in
    t.bounds <- M.add x (l', u') t.bounds

  (* [add_bounds t (x, l, u)] records [l <= x <= u]. A strict lower bound
     is encoded as [l + 1·ε], a strict upper bound as [u - 1·ε]. If [x] is
     non-basic and its value falls outside the bounds, it is re-assigned
     to the violated bound. *)
  let add_bounds (t:t)
      ?strict_lower:(slow=false) ?strict_upper:(supp=false)
      ?lower_reason ?upper_reason (x, l, u) : unit =
    let e1 = if slow then Q.one else Q.zero in
    let e2 = if supp then Q.neg Q.one else Q.zero in
    add_bound_aux t x (Erat.make l e1) lower_reason (Erat.make u e2) upper_reason;
    if mem_nbasic t x then (
      let b, v = is_within_bounds t x in
      if not b then (
        t.assign <- M.add x v t.assign;
      )
    )

  (* lower bound only: the upper bound passed along is [Q.inf] *)
  let add_lower_bound t ?strict ~reason x l =
    add_bounds t ?strict_lower:strict ~lower_reason:reason (x,l,Q.inf)

  (* upper bound only: the lower bound passed along is [Q.minus_inf] *)
  let add_upper_bound t ?strict ~reason x u =
    add_bounds t ?strict_upper:strict ~upper_reason:reason (x,Q.minus_inf,u)

  (* full assignment *)
  let full_assign (t:t) : (var * Erat.t) Iter.t =
    Iter.append (Vec.to_iter t.nbasic) (Vec.to_iter t.basic)
    |> Iter.map (fun x -> x, value t x)

  (* minimum of two rationals (returns [y] when they are equal) *)
  let[@inline] min x y = if Q.compare x y < 0 then x else y

  (* Find an epsilon that is small enough for finding a solution, yet
     it must be positive.

     {!Erat.t} values are used to turn strict bounds ([X > 0]) into
     non-strict bounds ([X >= 0 + ε]), because the simplex algorithm
     only deals with non-strict bounds.
     When a solution is found, we need to turn {!Erat.t} into {!Q.t} by
     finding a rational value that is small enough that it will fit into
     all the intervals of [t]. This rational will be the actual value of [ε].

     The result is capped at [Q.one].
  *)
  let solve_epsilon (t:t) : Q.t =
    let emax =
      M.fold
        (fun x ({ value = {base=low;eps_factor=e_low}; _},
                { value = {base=upp;eps_factor=e_upp}; _}) emax ->
           let {base=v; eps_factor=e_v} = value t x in
           (* lower bound *)
           let emax =
             if Q.compare low Q.minus_inf > 0 && Q.compare e_v e_low < 0
             then min emax Q.((low - v) / (e_v - e_low))
             else emax
           in
           (* upper bound *)
           if Q.compare upp Q.inf < 0 && Q.compare e_v e_upp > 0
           then min emax Q.((upp - v) / (e_v - e_upp))
           else emax)
        t.bounds
        Q.inf
    in
    if Q.compare emax Q.one >= 0 then Q.one else emax

  (* full assignment, where epsilon-rationals are evaluated into rationals
     using the epsilon computed by [solve_epsilon] *)
  let get_full_assign_seq (t:t) : _ Iter.t =
    let e = solve_epsilon t in
    let f = Erat.evaluate e in
    full_assign t
    |> Iter.map (fun (x,v) -> x, f v)

  let get_full_assign t : Q.t Var_map.t =
    Var_map.of_iter (get_full_assign_seq t)

  (* Find nbasic variable suitable for pivoting with [x].
     A nbasic variable [y] is suitable if it "goes into the right direction"
     (its coefficient in the definition of [x] is of the adequate sign)
     and if it hasn't reached its bound in this direction.

     precondition: [x] is a basic variable whose value in current assignment
       is outside its bounds

     We return the smallest (w.r.t Var.compare) suitable variable.
     This is important for termination.

     @raise NoneSuitable if no nbasic variable is suitable
  *)
  let find_suitable_nbasic_for_pivot (t:t) (x:basic_var) : nbasic_var * Q.t =
    assert (mem_basic t x);
    let _, v = is_within_bounds t x in
    (* [b] holds iff the value of [x] is below the violated bound,
       i.e. [x] must increase *)
    let b = Erat.compare (value t x) v < 0 in
    (* is nbasic var [y], with coeff [a] in definition of [x], suitable? *)
    let test (y:nbasic_var) (a:Q.t) : bool =
      assert (mem_nbasic t y);
      let v = value t y in
      let low, upp = get_bounds_values t y in
      if b then (
        (Erat.lt v upp && Q.compare a Q.zero > 0) ||
        (Erat.gt v low && Q.compare a Q.zero < 0)
      ) else (
        (Erat.gt v low && Q.compare a Q.zero > 0) ||
        (Erat.lt v upp && Q.compare a Q.zero < 0)
      )
    in
    let nbasic_vars = t.nbasic in
    let expr = find_expr_basic t x in
    assert (Vec.length nbasic_vars = Vec.length expr);
    (* find best suitable variable in a single left-to-right pass (constant
       stack usage); on equal variables the earliest index is kept *)
    let best = ref None in
    Vec.iteri
      (fun i y ->
         let a = Vec.get expr i in
         if test y a then (
           begin match !best with
             | Some (z, _) when Var.compare z y <= 0 -> ()
             | _ -> best := Some (y, a)
           end
         ))
      nbasic_vars;
    begin match !best with
      | Some res -> res
      | None -> raise NoneSuitable
    end

  (* pivot to exchange [x] and [y], where [a] is the coefficient of [y]
     in the current definition of [x] *)
  let pivot (t:t) (x:basic_var) (y:nbasic_var) (a:Q.t) : unit =
    (* swap values ([x] becomes assigned) *)
    let val_x = value t x in
    t.assign <- t.assign |> M.remove y |> M.add x val_x;
    (* Matrix pivot operation *)
    let kx = index_basic t x in
    let ky = index_nbasic t y in
    (* rewrite row [kx] so that it defines [y] in terms of [x] and the
       other nbasic variables *)
    for j = 0 to Vec.length t.nbasic - 1 do
      if Var.compare y (Vec.get t.nbasic j) = 0 then (
        Matrix.set t.tab kx j Q.(one / a)
      ) else (
        Matrix.set t.tab kx j Q.(neg (Matrix.get t.tab kx j) / a)
      )
    done;
    (* substitute the new definition of [y] into every other row *)
    for i = 0 to Vec.length t.basic - 1 do
      if i <> kx then (
        let c = Matrix.get t.tab i ky in
        Matrix.set t.tab i ky Q.zero;
        for j = 0 to Vec.length t.nbasic - 1 do
          Matrix.set t.tab i j
            Q.(Matrix.get t.tab i j + c * Matrix.get t.tab kx j)
        done
      )
    done;
    (* Switch x and y in basic and nbasic vars *)
    Vec.set t.basic kx y;
    Vec.set t.nbasic ky x;
    t.idx_basic <- t.idx_basic |> M.remove x |> M.add y kx;
    t.idx_nbasic <- t.idx_nbasic |> M.remove y |> M.add x ky;
    ()

  (* find minimum element of [arr] (wrt [cmp]) that satisfies predicate [f] *)
  let find_min_filter ~cmp (f:'a -> bool) (arr:('a,_) Vec.t) : 'a option =
    (* find the first element that satisfies [f] *)
    let rec aux_find_first i =
      if i = Vec.length arr then None
      else (
        let x = Vec.get arr i in
        if f x
        then aux_compare_with x (i+1)
        else aux_find_first (i+1)
      )
    (* find if any element of [arr] satisfies [f] and is smaller than [x] *)
    and aux_compare_with x i =
      if i = Vec.length arr then Some x
      else (
        let y = Vec.get arr i in
        let best = if f y && cmp y x < 0 then y else x in
        aux_compare_with best (i+1)
      )
    in
    aux_find_first 0

  (* check bounds.
     @raise AbsurdBounds on a variable whose lower bound exceeds its upper
     bound *)
  let check_bounds (t:t) : unit =
    M.iter
      (fun x (l, u) -> if Erat.gt l.value u.value then raise (AbsurdBounds x))
      t.bounds

  (* actual solving algorithm.
     @raise AbsurdBounds see [check_bounds]
     @raise Unsat on a basic variable that is out of its bounds and has no
     suitable pivot candidate *)
  let solve_aux (t:t) : unit =
    check_bounds t;
    (* select the smallest basic variable that is not satisfied in the current
       assignment. *)
    let rec aux_select_basic_var () =
      match
        find_min_filter ~cmp:Var.compare
          (fun x -> not (fst (is_within_bounds t x)))
          t.basic
      with
      | Some x -> aux_pivot_on_basic x
      | None -> ()
    (* remove the basic variable *)
    and aux_pivot_on_basic x =
      let _b, v = is_within_bounds t x in
      assert (not _b);
      match find_suitable_nbasic_for_pivot t x with
      | y, a ->
        (* exchange [x] and [y] by pivoting *)
        pivot t x y a;
        (* assign [x], now a nbasic variable, to the faulty bound [v] *)
        t.assign <- M.add x v t.assign;
        (* next iteration *)
        aux_select_basic_var ()
      | exception NoneSuitable ->
        raise (Unsat x)
    in
    aux_select_basic_var ();
    ()

  (* main method for the user to call.
     Returns [Solution] with the full rational assignment when [solve_aux]
     succeeds. On [Unsat x] the certificate pairs the coefficients of the
     row of [x] with the nbasic variables; on [AbsurdBounds x] the
     certificate expression is empty. *)
  let solve (t:t) : res =
    try
      solve_aux t;
      Solution (get_full_assign t)
    with
    | Unsat x ->
      let cert_expr =
        List.combine
          (Vec.to_list (find_expr_basic t x))
          (Vec.to_list t.nbasic)
      in
      Unsatisfiable { cert_var=x; cert_expr; } (* FIXME *)
    | AbsurdBounds x ->
      Unsatisfiable { cert_var=x; cert_expr=[]; }

  (* add [c·x] to [m] *)
  let add_expr_ (x:var) (c:Q.t) (m:Q.t M.t) =
    let c' = M.get_or ~default:Q.zero x m in
    let c' = Q.(c + c') in
    (* keep the map free of zero coefficients, so that [M.is_empty] can be
       used to test whether an expression is null *)
    if Q.equal Q.zero c' then M.remove x m else M.add x c' m

  (* dereference basic variables from [c·x], and add the result to [m] *)
  let rec deref_var_ t x c m =
    match find_expr_basic_opt t x with
    | None -> add_expr_ x c m
    | Some expr_x ->
      let m = ref m in
      Vec.iteri
        (fun i c_i ->
           let y_i = Vec.get t.nbasic i in
           m := deref_var_ t y_i Q.(c * c_i) !m)
        expr_x;
      !m

  (* maybe invert bounds, if [c < 0] *)
  let scale_bounds c (l,u) : bound * bound =
    match Q.compare c Q.zero with
    | 0 ->
      let b = { value = Erat.zero; reason = None; } in
      b, b
    | n when n<0 ->
      { u with value = Erat.mul c u.value; },
      { l with value = Erat.mul c l.value; }
    | _ ->
      { l with value = Erat.mul c l.value; },
      { u with value = Erat.mul c u.value; }

  (* prepend the reason, if any, to the accumulator *)
  let add_to_unsat_core acc = function
    | None -> acc
    | Some reason -> reason :: acc

  (* Check the certificate [c] against the bounds stored in [t].
     Returns [`Ok core] with the list of reasons involved when the
     certificate is valid, [`Bad_bounds (low, up)] when the bounds do not
     actually conflict, and [`Diff_not_0 e] when the certificate expression
     does not expand to the same expression as [c.cert_var]. *)
  let check_cert (t:t) (c:cert) =
    let x = c.cert_var in
    let { value = low_x; reason = low_x_reason; },
        { value = up_x; reason = upp_x_reason; } = get_bounds t x in
    begin match c.cert_expr with
      | [] ->
        if Erat.compare low_x up_x > 0
        then `Ok (add_to_unsat_core (add_to_unsat_core [] low_x_reason) upp_x_reason)
        else `Bad_bounds (str_of_erat low_x, str_of_erat up_x)
      | expr ->
        let e0 = deref_var_ t x (Q.neg Q.one) M.empty in
        (* compute bounds for the expression [c.cert_expr],
           and also compute [c.cert_expr - x] to check if it's 0] *)
        let low, low_unsat_core, up, up_unsat_core, expr_minus_x =
          List.fold_left
            (fun (l, luc, u, uuc, expr_minus_x) (c, y) ->
               let ly, uy = scale_bounds c (get_bounds t y) in
               assert (Erat.compare ly.value uy.value <= 0);
               let expr_minus_x = deref_var_ t y c expr_minus_x in
               let luc = add_to_unsat_core luc ly.reason in
               let uuc = add_to_unsat_core uuc uy.reason in
               Erat.sum l ly.value, luc, Erat.sum u uy.value, uuc, expr_minus_x)
            (Erat.zero, [], Erat.zero, [], e0)
            expr
        in
        (* check that the expanded expression is [x], and that
           one of the bounds on [x] is incompatible with bounds of [c.cert_expr] *)
        if M.is_empty expr_minus_x then (
          if Erat.compare low_x up > 0
          then `Ok (add_to_unsat_core up_unsat_core low_x_reason)
          else if Erat.compare up_x low < 0
          then `Ok (add_to_unsat_core low_unsat_core upp_x_reason)
          else `Bad_bounds (str_of_erat low, str_of_erat up)
        ) else `Diff_not_0 expr_minus_x
    end

  (* printer *)

  (* width of each cell printed by [pp_mat] *)
  let matrix_pp_width = ref 8

  let fmt_head = format_of_string "|%*s|| "
  let fmt_cell = format_of_string "%*s| "

  let pp_cert out (c:cert) = match c.cert_expr with
    | [] -> Format.fprintf out "(@[inconsistent-bounds %a@])" Var.pp c.cert_var
    | _ ->
      let pp_pair =
        Format.(hvbox ~i:2 @@ pair ~sep:(return "@ * ") Q.pp_print Var.pp)
      in
      Format.fprintf out "(@[cert@ :var %a@ :linexp %a@])"
        Var.pp c.cert_var
        Format.(within "[" "]" @@ hvbox @@ list ~sep:(return "@ + ") pp_pair)
        c.cert_expr

  (* print the matrix, one row per line: a header with the nbasic
     variables, then one line per basic variable with its coefficients *)
  let pp_mat out t =
    let open Format in
    (* vertical box: every [@,] cut below must produce a line break,
       otherwise the rows are packed on the same line *)
    fprintf out "@[<v>";
    (* header *)
    fprintf out fmt_head !matrix_pp_width "";
    Vec.iter (fun x -> fprintf out fmt_cell !matrix_pp_width (str_of_var x)) t.nbasic;
    fprintf out "@,";
    (* rows *)
    for i=0 to Matrix.n_row t.tab-1 do
      if i>0 then fprintf out "@,";
      let v = Vec.get t.basic i in
      fprintf out fmt_head !matrix_pp_width (str_of_var v);
      let row = Matrix.get_row t.tab i in
      Vec.iter (fun q -> fprintf out fmt_cell !matrix_pp_width (str_of_q q)) row;
    done;
    fprintf out "@]"

  let pp_assign =
    let open Format in
    let pp_pair =
      within "(" ")" @@ hvbox @@ pair ~sep:(return "@ := ") Var.pp Erat.pp
    in
    map Var_map.to_iter @@ within "(" ")" @@ hvbox @@ iter pp_pair

  let pp_bounds =
    let open Format in
    let pp_pairs out (x,(l,u)) =
      fprintf out "(@[%a =< %a =< %a@])" Erat.pp l.value Var.pp x Erat.pp u.value
    in
    map Var_map.to_iter @@ within "(" ")" @@ hvbox @@ iter pp_pairs

  let pp_full_state out (t:t) : unit =
    (* print main matrix *)
    Format.fprintf out
      "(@[simplex@ :n-row %d :n-col %d@ :mat %a@ :assign %a@ :bounds %a@])"
      (Matrix.n_row t.tab) (Matrix.n_col t.tab) pp_mat t pp_assign t.assign
      pp_bounds t.bounds
end

(* simplex over [Var], with a unit parameter *)
module Make(Var:VAR) =
  Make_inner(Var)(CCMap.Make(Var))(struct
    type t = unit
    let copy ()=()
  end)

module Make_full_for_expr(V : VAR_GEN)
    (L : Linear_expr.S
     with type Var.t = V.t
      and type C.t = Q.t
      and type Var.lit = V.lit)
  : S_FULL with type var = V.t
            and type lit = V.lit
            and module L = L
            and module Var_map = L.Var_map
            and type L.var = V.t
            and type L.Comb.t = L.Comb.t
            and type param = V.Fresh.t
= struct
  include Make_inner(V)(L.Var_map)(V.Fresh)
  module L = L

  type op = Predicate.t = Leq | Geq | Lt | Gt | Eq | Neq

  type constr = L.Constr.t

  (* add a constraint: a fresh variable [x] is defined as the linear
     combination of the constraint, and the constant side becomes a bound
     on [x]. [Neq] constraints are not supported (assertion failure). *)
  let add_constr (t:t) (c:constr) (reason:lit) : unit =
    let (x:var) = V.Fresh.fresh t.param in
    let e, op, q = L.Constr.split c in
    add_eq t (x, L.Comb.to_list e);
    begin match op with
      | Leq -> add_upper_bound t ~strict:false ~reason x q
      | Geq -> add_lower_bound t ~strict:false ~reason x q
      | Lt -> add_upper_bound t ~strict:true ~reason x q
      | Gt -> add_lower_bound t ~strict:true ~reason x q
      | Eq ->
        add_bounds t (x,q,q)
          ~strict_lower:false ~strict_upper:false
          ~lower_reason:reason ~upper_reason:reason
      | Neq -> assert false
    end
end

module Make_full(V : VAR_GEN)
  : S_FULL with type var = V.t
            and type lit = V.lit
            and type L.var = V.t
            and type param = V.Fresh.t
  = Make_full_for_expr(V)(Linear_expr.Make(struct include Q let pp = pp_print end)(V))