From e77d2b81caae587c5302085a4dabb33cc99dda49 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Wed, 19 Jan 2022 11:43:35 -0500 Subject: [PATCH] wip: feat(intsolver): classify constraints into sets E,L,G,M --- src/intsolver/sidekick_intsolver.ml | 235 +++++++++++++++++++++++++--- 1 file changed, 215 insertions(+), 20 deletions(-) diff --git a/src/intsolver/sidekick_intsolver.ml b/src/intsolver/sidekick_intsolver.ml index 5ba5e62d..69c10a49 100644 --- a/src/intsolver/sidekick_intsolver.ml +++ b/src/intsolver/sidekick_intsolver.ml @@ -1,6 +1,6 @@ module type ARG = sig - module Z : Sidekick_arith.INT + module Z : Sidekick_arith.INT_FULL type term type lit @@ -78,6 +78,104 @@ module Make(A : ARG) module A = A open A + module ZTbl = CCHashtbl.Make(Z) + + module Utils_ : sig + type divisor = { + prime : Z.t; + power : int; + } + val is_prime : Z.t -> bool + val prime_decomposition : Z.t -> divisor list + val primes_leq : Z.t -> Z.t Iter.t + end = struct + + type divisor = { + prime : Z.t; + power : int; + } + + let two = Z.of_int 2 + + (* table from numbers to some of their divisor (if any) *) + let _table = lazy ( + let t = ZTbl.create 256 in + ZTbl.add t two None; + t) + + let _divisors n = ZTbl.find (Lazy.force _table) n + + let _add_prime n = + ZTbl.replace (Lazy.force _table) n None + + (* add to the table the fact that [d] is a divisor of [n] *) + let _add_divisor n d = + assert (not (ZTbl.mem (Lazy.force _table) n)); + ZTbl.add (Lazy.force _table) n (Some d) + + (* primality test, modifies _table *) + let _is_prime n0 = + let n = ref two in + let bound = Z.succ (Z.sqrt n0) in + let is_prime = ref true in + while !is_prime && Z.(!n <= bound) do + if Z.(rem n0 !n = zero) + then begin + is_prime := false; + _add_divisor n0 !n; + end; + n := Z.succ !n; + done; + if !is_prime then _add_prime n0; + !is_prime + + let is_prime n = + try + begin match _divisors n with + | None -> true + | Some _ -> false + end + with Not_found -> + if Z.probab_prime n && _is_prime n then ( + _add_prime n; true + ) else false + + let rec _merge l1 l2 = match l1, l2 with + | [], _ -> l2 + | _, [] -> l1 + | p1::l1', p2::l2' -> + match Z.compare p1.prime p2.prime with + | 0 -> + {prime=p1.prime; power=p1.power+p2.power} :: _merge l1' l2' + | n when n < 0 -> + p1 :: _merge l1' l2 + | _ -> p2 :: _merge l1 l2' + + let rec _decompose n = + try + begin match _divisors n with + | None -> [{prime=n; power=1;}] + | Some q1 -> + let q2 = Z.divexact n q1 in + _merge (_decompose q1) (_decompose q2) + end + with Not_found -> + ignore (_is_prime n); + _decompose n + + let prime_decomposition n = + if is_prime n + then [{prime=n; power=1;}] + else _decompose n + + let primes_leq n0 k = + let n = ref two in + while Z.(!n <= n0) do + if is_prime !n then k !n + done + end + + module Op = struct type t = | Leq @@ -106,6 +204,7 @@ module Make(A : ARG) let iter = T_map.iter let return t : t = T_map.add t Z.one empty let neg self : t = T_map.map Z.neg self + let mem self t : bool = T_map.mem t self let mult n self = if Z.(n = zero) then empty else T_map.map (fun c -> Z.(c * n)) self @@ -164,6 +263,9 @@ module Make(A : ARG) lits: lit Bag.t; } + (* FIXME: need to simplify: compute gcd(le.coeffs), then divide by that + and round const appropriately *) + let pp out self = Fmt.fprintf out "(@[%a@ %a %a@])" Linexp.pp self.le Op.pp self.op Z.pp self.const end @@ -213,13 +315,17 @@ module Make(A : ARG) type op = | Leq - | Lt | Eq | Eq_mod of { prime: Z.t; pow: int; } (* modulo prime^pow *) + let pp_op out = function + | Leq -> Fmt.string out "<=" + | Eq -> Fmt.string out "=" + | Eq_mod {prime; pow} -> Fmt.fprintf out "%a^%d" Z.pp prime pow + type constr = { le: LE.t; const: Z.t; @@ -227,15 +333,21 @@ module Make(A : ARG) lits: lit Bag.t; } + let pp_constr out self = + Fmt.fprintf out "(@[%a@ %a %a@])" Linexp.pp self.le pp_op self.op Z.pp self.const + type state = { mutable rw: LE.t T_map.t; (* rewrite rules *) mutable vars: int T_map.t; (* variables in at least one constraint *) mutable constrs: constr list; + mutable ok: (unit, constr) Result.t; } (* main solving state. mutable, but copied for backtracking. invariant: variables in [rw] do not occur anywhere else *) + let[@inline] is_ok_ self = CCResult.is_ok self.ok + (* perform rewriting on the linear expression *) let norm_le (self:state) (le:LE.t) : LE.t = LE.flat_map @@ -243,8 +355,10 @@ module Make(A : ARG) le let[@inline] count_v self t : int = T_map.get_or ~default:0 t self.vars + let[@inline] incr_v (self:state) (t:term) : unit = self.vars <- T_map.add t (1 + count_v self t) self.vars + let decr_v (self:state) (t:term) : unit = let n = count_v self t - 1 in assert (n >= 0); @@ -252,12 +366,55 @@ module Make(A : ARG) (if n=0 then T_map.remove t self.vars else T_map.add t n self.vars) - let add_constr (self:state) (c:constr) = - let c = {c with le=norm_le self c.le } in - LE.iter (fun t _ -> incr_v self t) c.le; - self.constrs <- c :: self.constrs + let simplify_constr (c:constr) : (constr, unit) Result.t = + let exception E_unsat in + try + match T_map.choose_opt c.le with + | None -> Ok c + | Some (_, z0) -> + let c_gcd = T_map.fold (fun _ z m -> Z.gcd z m) c.le z0 in + if Z.(c_gcd > one) then ( + let const = match c.op with + | Leq -> + (* round down, regardless of sign *) + let cplus = Z.abs c.const in + let cplus = Z.(cplus / c_gcd) in + if Z.sign c.const >= 0 then cplus else Z.neg cplus + | Eq | Eq_mod _ -> + if Z.equal (Z.rem c.const c_gcd) Z.zero then ( + (* compatible constant *) + Z.(divexact c.const c_gcd) + ) else ( + raise E_unsat + ) + in - let remove_constr (self:state) (c:constr) = + let c' = { + c with + le=T_map.map (fun c -> Z.(c / c_gcd)) c.le; + const; + } in + Log.debugf 50 + (fun k->k "(@[intsolver.simplify@ :from %a@ :into %a@])" + pp_constr c pp_constr c'); + Ok c' + ) else Ok c + with E_unsat -> + Log.debugf 50 (fun k->k "(@[intsolver.simplify.unsat@ %a@])" pp_constr c); + Error () + + let add_constr (self:state) (c:constr) : unit = + if is_ok_ self then ( + let c = {c with le=norm_le self c.le } in + match simplify_constr c with + | Ok c -> + LE.iter (fun t _ -> incr_v self t) c.le; + self.constrs <- c :: self.constrs + | Error () -> + self.ok <- Error c + ) + + let remove_constr (self:state) (c:constr) : unit = LE.iter (fun t _ -> decr_v self t) c.le let create (self:t) : state = @@ -265,6 +422,7 @@ module Make(A : ARG) vars=T_map.empty; rw=T_map.empty; constrs=[]; + ok=Ok(); } in BVec.iter self.defs ~f:(fun (v,le) -> @@ -273,10 +431,10 @@ module Make(A : ARG) BVec.iter self.cs ~f:(fun (c:Constr.t) -> let {Constr.le; op; const; lits} = c in - let op = match op with - | Op.Eq -> Eq - | Op.Leq -> Leq - | Op.Lt -> Lt + let op, const = match op with + | Op.Eq -> Eq, const + | Op.Leq -> Leq, const + | Op.Lt -> Leq, Z.pred const (* [x < t] is [x <= t-1] *) in let c = {le;const;lits;op} in add_constr state c @@ -289,22 +447,59 @@ module Make(A : ARG) let m = Model.empty in Sat m (* TODO: model *) - | Some (t, _) -> - self.vars <- T_map.remove t self.vars; - Log.debugf 30 - (fun k->k "(@[intsolver.elim-var@ %a@ :remaining %d@])" - A.pp_term t (T_map.cardinal self.vars)); - - assert false (* TODO *) - + | Some (t, _) -> elim_var_ self t end + + and elim_var_ self t : result = + Log.debugf 30 + (fun k->k "(@[intsolver.elim-var@ %a@ :remaining %d@])" + A.pp_term t (T_map.cardinal self.vars)); + + assert (not (T_map.mem t self.rw)); (* woudl have been rewritten away *) + self.vars <- T_map.remove t self.vars; + + (* gather the sets *) + let set_e = ref [] in (* eqns *) + let set_l = ref [] in (* t <= … *) + let set_g = ref [] in (* t >= … *) + let set_m = ref [] in (* t = … [n] *) + let others = ref [] in + + let classify_constr (c:constr) = + match T_map.get t c.le, c.op with + | None, _ -> + others := c :: !others; + | Some n_t, Leq -> + if Z.sign n_t > 0 then + set_l := (n_t,c) :: !set_l + else + set_g := (Z.abs n_t,c) :: !set_g + | Some n_t, Eq -> + set_e := (n_t,c) :: !set_e + | Some n_t, Eq_mod _ -> + set_m := (n_t,c) :: !set_m + in + List.iter classify_constr self.constrs; + + Log.debugf 50 + (fun k-> + let pps = Fmt.Dump.(list @@ pair Z.pp pp_constr) in + k "(@[intsolver.classify.for %a@ E=%a@ L=%a@ G=%a@ M=%a@])" + A.pp_term t pps !set_e pps !set_l pps !set_g pps !set_m); + assert false end let check (self:t) : result = Log.debugf 10 (fun k->k "(@[intsolver.check@])"); let state = Check_.create self in - Check_.solve_rec state + match state.ok with + | Ok () -> + Check_.solve_rec state + | Error c -> + Log.debugf 10 (fun k->k "(@[insolver.unsat-constraint@ %a@])" Check_.pp_constr c); + (* TODO proper certificate *) + Unsat () let _check_invariants _ = () end