module type ARG = sig (** terms *) module T : sig type t val equal : t -> t -> bool val hash : t -> int val compare : t -> t -> int val pp : t Fmt.printer end type tag val pp_tag : tag Fmt.printer end module Pred : sig type t = Lt | Leq | Geq | Gt | Neq | Eq val neg : t -> t val pp : t Fmt.printer val to_string : t -> string end = struct type t = Lt | Leq | Geq | Gt | Neq | Eq let to_string = function | Lt -> "<" | Leq -> "<=" | Eq -> "=" | Neq -> "!=" | Gt -> ">" | Geq -> ">=" let neg = function | Leq -> Gt | Lt -> Geq | Eq -> Neq | Neq -> Eq | Geq -> Lt | Gt -> Leq let pp out p = Fmt.string out (to_string p) end module type S = sig module A : ARG type t type term = A.T.t module LE : sig type t val const : Q.t -> t val zero : t val var : term -> t val neg : t -> t val find_exn : term -> t -> Q.t val find : term -> t -> Q.t option (* val map : (term -> term) -> t -> t *) module Infix : sig val (+) : t -> t -> t val (-) : t -> t -> t val ( * ) : Q.t -> t -> t end include module type of Infix val pp : t Fmt.printer end (** {3 Arithmetic constraint} *) module Constr : sig type t val pp : t Fmt.printer val mk : ?tag:A.tag -> Pred.t -> LE.t -> LE.t -> t val is_absurd : t -> bool end val create : unit -> t val assert_c : t -> Constr.t -> unit type res = | Sat | Unsat of A.tag list val solve : t -> res end module Make(A : ARG) : S with module A = A = struct module A = A module T = A.T module T_set = CCSet.Make(A.T) module T_map = CCMap.Make(A.T) type term = A.T.t module LE = struct module M = T_map type t = { le: Q.t M.t; const: Q.t; } let const x : t = {const=x; le=M.empty} let zero = const Q.zero let var x : t = {const=Q.zero; le=M.singleton x Q.one} let[@inline] find_exn v le = M.find v le.le let[@inline] find v le = M.get v le.le let[@inline] remove v le : t = {le with le=M.remove v le.le} let neg a : t = {const=Q.neg a.const; le=M.map Q.neg a.le; } let (+) a b : t = {const = Q.(a.const + b.const); le=M.merge_safe a.le b.le ~f:(fun _ -> function | `Left x | `Right x -> Some x | `Both (x,y) -> let z = Q.(x + y) in if Q.sign z = 0 then None else Some z) } let (-) a b : t = {const = Q.(a.const - b.const); le=M.merge_safe a.le b.le ~f:(fun _ -> function | `Left x -> Some x | `Right x -> Some (Q.neg x) | `Both (x,y) -> let z = Q.(x - y) in if Q.sign z = 0 then None else Some z) } let ( * ) x a : t = if Q.sign x = 0 then const Q.zero else ( {const=Q.( a.const * x ); le=M.map (fun y -> Q.(x * y)) a.le } ) let max_var self : T.t option = M.keys self.le |> Iter.max ~lt:(fun a b -> T.compare a b < 0) (* ensure coeff of [v] is 1 *) let normalize_wrt (v:T.t) le : t = let q = find_exn v le in Q.inv q * le module Infix = struct let (+) = (+) let (-) = (-) let ( * ) = ( * ) end let vars self = T_map.keys self.le let pp out (self:t) : unit = let pp_pair out (e,q) = if Q.equal Q.one q then T.pp out e else Fmt.fprintf out "%a * %a" Q.pp_print q T.pp e in let pp_sum out le = (Util.pp_iter ~sep:" + " pp_pair) out (M.to_iter le) in if Q.sign self.const = 0 then ( Fmt.fprintf out "(@[%a@])" pp_sum self.le ) else ( Fmt.fprintf out "(@[%a@ + %a@])" Q.pp_print self.const pp_sum self.le ) end (** {2 Constraints} *) module Constr = struct type t = { pred: Pred.t; le: LE.t; tag: A.tag list; } let pp out (c:t) : unit = Fmt.fprintf out "(@[constr@ :le %a@ :pred %s 0@])" LE.pp c.le (Pred.to_string c.pred) let mk_ ~tag pred le : t = {pred; tag; le; } let mk ?tag pred l1 l2 : t = mk_ ~tag:(CCOpt.to_list tag) pred LE.(l1 - l2) let is_absurd (self:t) : bool = T_map.is_empty self.le.le && let c = self.le.const in begin match self.pred with | Leq -> Q.compare c Q.zero > 0 | Lt -> Q.compare c Q.zero >= 0 | Geq -> Q.compare c Q.zero < 0 | Gt -> Q.compare c Q.zero <= 0 | Eq -> Q.compare c Q.zero <> 0 | Neq -> Q.compare c Q.zero = 0 end let is_trivial (self:t) : bool = T_map.is_empty self.le.le && not (is_absurd self) (* nornalize and return maximum variable *) let normalize (self:t) : t = match self.pred with | Geq -> mk_ ~tag:self.tag Lt (LE.neg self.le) | Gt -> mk_ ~tag:self.tag Leq (LE.neg self.le) | _ -> self let find_max (self:t) : T.t option * bool = match LE.max_var self.le with | None -> None, true | Some t -> Some t, Q.sign (T_map.find t self.le.le) > 0 end (** constraints for a variable (where the variable is maximal) *) type c_for_var = { occ_pos: Constr.t list; occ_eq: Constr.t list; occ_neg: Constr.t list; } type system = { empties: Constr.t list; (* no variables, check first *) idx: c_for_var T_map.t; (* map [t] to [c1,c2] where [c1] are normalized constraints whose maximum term is [t], with positive sign for [c1] and negative for [c2] respectively. *) } type t = { mutable cs: Constr.t list; mutable sys: system; } let empty_sys : system = {empties=[]; idx=T_map.empty} let empty_c_for_v : c_for_var = { occ_pos=[]; occ_neg=[]; occ_eq=[] } let create () : t = { cs=[]; sys=empty_sys; } let add_sys (sys:system) (c:Constr.t) : system = assert (match c.pred with Eq|Leq|Lt -> true | _ -> false); if Constr.is_trivial c then ( Log.debugf 10 (fun k->k"(@[FM.drop-trivial@ %a@])" Constr.pp c); sys ) else ( match Constr.find_max c with | None, _ -> {sys with empties=c :: sys.empties} | Some v, occ_pos -> Log.debugf 30 (fun k->k "(@[FM.add-sys %a@ :max_var %a@ :occurs-pos %B@])" Constr.pp c T.pp v occ_pos); let cs = T_map.get_or ~default:empty_c_for_v v sys.idx in let cs = if c.pred = Eq then {cs with occ_eq = c :: cs.occ_eq} else if occ_pos then {cs with occ_pos = c :: cs.occ_pos} else {cs with occ_neg = c :: cs.occ_neg } in let idx = T_map.add v cs sys.idx in {sys with idx} ) let assert_c (self:t) c0 : unit = Log.debugf 10 (fun k->k "(@[FM.add-constr@ %a@ :tags %a@])" Constr.pp c0 (Fmt.Dump.list A.pp_tag) c0.tag); let c = Constr.normalize c0 in if c.pred <> c0.pred then ( Log.debugf 30 (fun k->k "(@[FM.normalized %a@])" Constr.pp c); ); assert (match c.pred with Eq | Leq | Lt -> true | _ -> false); self.cs <- c :: self.cs; self.sys <- add_sys self.sys c; () let pp_system out (self:system) : unit = let pp_idxkv out (t,{occ_eq; occ_pos; occ_neg}) = Fmt.fprintf out "(@[for-var %a@ :occ-eq %a@ :occ-pos %a@ :occ-neg %a@])" T.pp t (Fmt.Dump.list Constr.pp) occ_eq (Fmt.Dump.list Constr.pp) occ_pos (Fmt.Dump.list Constr.pp) occ_neg in Fmt.fprintf out "(@[:empties %a@ :idx (@[%a@])@])" (Fmt.Dump.list Constr.pp) self.empties (Util.pp_iter pp_idxkv) (T_map.to_iter self.idx) (* TODO: be able to provide a model for SAT *) type res = | Sat | Unsat of A.tag list (* replace [x] with [by] inside [le] *) let subst_le (x:T.t) (le:LE.t) ~by:(le1:LE.t) : LE.t = let q = LE.find_exn x le in let le = LE.remove x le in LE.( le + q * le1 ) let subst_constr x c ~by : Constr.t = let c = {c with Constr.le=subst_le x ~by c.Constr.le} in Constr.normalize c let rec solve_ (self:system) : res = Log.debugf 50 (fun k->k "(@[FM.solve-rec@ :sys %a@])" pp_system self); begin match List.find Constr.is_absurd self.empties with | c -> Log.debugf 10 (fun k->k"(@[FM.unsat@ :by-absurd %a@])" Constr.pp c); Unsat c.tag | exception Not_found -> (* need to process biggest variable first *) match T_map.max_binding_opt self.idx with | None -> Sat | Some (v, {occ_eq=c0 :: ceq'; occ_pos; occ_neg}) -> (* remove [v] from [idx] *) let sys = {self with idx=T_map.remove v self.idx} in (* substitute using [c0] in the other constraints containing [v] *) assert (c0.pred = Eq); let c0 = LE.normalize_wrt v c0.le in (* build [v = rhs] *) let rhs = LE.neg @@ LE.remove v c0 in Log.debugf 50 (fun k->k "(@[FM.subst-from-eq@ :v %a@ :rhs %a@])" T.pp v LE.pp rhs); (* perform substitution *) let new_sys = [Iter.of_list ceq'; Iter.of_list occ_pos; Iter.of_list occ_neg] |> Iter.of_list |> Iter.flatten |> Iter.map (subst_constr v ~by:rhs) |> Iter.fold add_sys sys in solve_ new_sys | Some (v, {occ_eq=[]; occ_pos=l1; occ_neg=l2}) -> Log.debugf 10 (fun k->k "(@[@{FM.pivot@}@ :v %a@ :lpos %a@ :lneg %a@])" T.pp v (Fmt.Dump.list Constr.pp) l1 (Fmt.Dump.list Constr.pp) l2); (* remove [v] *) let sys = {self with idx=T_map.remove v self.idx} in let new_sys = Iter.product (Iter.of_list l1) (Iter.of_list l2) |> Iter.map (fun (c1,c2) -> let q1 = LE.find_exn v c1.Constr.le in let le = LE.( c1.Constr.le + (Q.inv q1 * c2.Constr.le) ) in let pred = match c1.Constr.pred, c2.Constr.pred with | Eq, Eq -> Pred.Eq | Lt, _ | _, Lt -> Pred.Lt | Leq, _ | _, Leq -> Pred.Leq | _ -> assert false in let c = Constr.mk_ ~tag:(c1.tag @ c2.tag) pred le in Log.debugf 50 (fun k->k "(@[FM.resolve@ %a@ %a@ :yields@ %a@])" Constr.pp c1 Constr.pp c2 Constr.pp c); c) |> Iter.fold add_sys sys in solve_ new_sys end let solve (self:t) : res = Log.debugf 5 (fun k->k"(@[@{FM.solve@}@ %a@])" (Util.pp_list Constr.pp) self.cs); solve_ self.sys end