wip: feat(intsolver): classify constraints into sets E,L,G,M

This commit is contained in:
Simon Cruanes 2022-01-19 11:43:35 -05:00
parent 04e9d5b93c
commit e77d2b81ca
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4

View file

@ -1,6 +1,6 @@
module type ARG = sig
module Z : Sidekick_arith.INT
module Z : Sidekick_arith.INT_FULL
type term
type lit
@ -78,6 +78,104 @@ module Make(A : ARG)
module A = A
open A
module ZTbl = CCHashtbl.Make(Z)
module Utils_ : sig
type divisor = {
prime : Z.t;
power : int;
}
val is_prime : Z.t -> bool
val prime_decomposition : Z.t -> divisor list
val primes_leq : Z.t -> Z.t Iter.t
end = struct
type divisor = {
prime : Z.t;
power : int;
}
let two = Z.of_int 2
(* table from numbers to some of their divisor (if any) *)
let _table = lazy (
let t = ZTbl.create 256 in
ZTbl.add t two None;
t)
let _divisors n = ZTbl.find (Lazy.force _table) n
let _add_prime n =
ZTbl.replace (Lazy.force _table) n None
(* add to the table the fact that [d] is a divisor of [n] *)
let _add_divisor n d =
assert (not (ZTbl.mem (Lazy.force _table) n));
ZTbl.add (Lazy.force _table) n (Some d)
(* primality test, modifies _table *)
let _is_prime n0 =
let n = ref two in
let bound = Z.succ (Z.sqrt n0) in
let is_prime = ref true in
while !is_prime && Z.(!n <= bound) do
if Z.(rem n0 !n = zero)
then begin
is_prime := false;
_add_divisor n0 !n;
end;
n := Z.succ !n;
done;
if !is_prime then _add_prime n0;
!is_prime
let is_prime n =
try
begin match _divisors n with
| None -> true
| Some _ -> false
end
with Not_found ->
if Z.probab_prime n && _is_prime n then (
_add_prime n; true
) else false
let rec _merge l1 l2 = match l1, l2 with
| [], _ -> l2
| _, [] -> l1
| p1::l1', p2::l2' ->
match Z.compare p1.prime p2.prime with
| 0 ->
{prime=p1.prime; power=p1.power+p2.power} :: _merge l1' l2'
| n when n < 0 ->
p1 :: _merge l1' l2
| _ -> p2 :: _merge l1 l2'
let rec _decompose n =
try
begin match _divisors n with
| None -> [{prime=n; power=1;}]
| Some q1 ->
let q2 = Z.divexact n q1 in
_merge (_decompose q1) (_decompose q2)
end
with Not_found ->
ignore (_is_prime n);
_decompose n
let prime_decomposition n =
if is_prime n
then [{prime=n; power=1;}]
else _decompose n
let primes_leq n0 k =
let n = ref two in
while Z.(!n <= n0) do
if is_prime !n then k !n
done
end
module Op = struct
type t =
| Leq
@ -106,6 +204,7 @@ module Make(A : ARG)
let iter = T_map.iter
let return t : t = T_map.add t Z.one empty
let neg self : t = T_map.map Z.neg self
let mem self t : bool = T_map.mem t self
let mult n self =
if Z.(n = zero) then empty
else T_map.map (fun c -> Z.(c * n)) self
@ -164,6 +263,9 @@ module Make(A : ARG)
lits: lit Bag.t;
}
(* FIXME: need to simplify: compute gcd(le.coeffs), then divide by that
and round const appropriately *)
let pp out self =
Fmt.fprintf out "(@[%a@ %a %a@])" Linexp.pp self.le Op.pp self.op Z.pp self.const
end
@ -213,13 +315,17 @@ module Make(A : ARG)
type op =
| Leq
| Lt
| Eq
| Eq_mod of {
prime: Z.t;
pow: int;
} (* modulo prime^pow *)
let pp_op out = function
| Leq -> Fmt.string out "<="
| Eq -> Fmt.string out "="
| Eq_mod {prime; pow} -> Fmt.fprintf out "%a^%d" Z.pp prime pow
type constr = {
le: LE.t;
const: Z.t;
@ -227,15 +333,21 @@ module Make(A : ARG)
lits: lit Bag.t;
}
let pp_constr out self =
Fmt.fprintf out "(@[%a@ %a %a@])" Linexp.pp self.le pp_op self.op Z.pp self.const
type state = {
mutable rw: LE.t T_map.t; (* rewrite rules *)
mutable vars: int T_map.t; (* variables in at least one constraint *)
mutable constrs: constr list;
mutable ok: (unit, constr) Result.t;
}
(* main solving state. mutable, but copied for backtracking.
invariant: variables in [rw] do not occur anywhere else
*)
let[@inline] is_ok_ self = CCResult.is_ok self.ok
(* perform rewriting on the linear expression *)
let norm_le (self:state) (le:LE.t) : LE.t =
LE.flat_map
@ -243,8 +355,10 @@ module Make(A : ARG)
le
let[@inline] count_v self t : int = T_map.get_or ~default:0 t self.vars
let[@inline] incr_v (self:state) (t:term) : unit =
self.vars <- T_map.add t (1 + count_v self t) self.vars
let decr_v (self:state) (t:term) : unit =
let n = count_v self t - 1 in
assert (n >= 0);
@ -252,12 +366,55 @@ module Make(A : ARG)
(if n=0 then T_map.remove t self.vars
else T_map.add t n self.vars)
let add_constr (self:state) (c:constr) =
let simplify_constr (c:constr) : (constr, unit) Result.t =
let exception E_unsat in
try
match T_map.choose_opt c.le with
| None -> Ok c
| Some (_, z0) ->
let c_gcd = T_map.fold (fun _ z m -> Z.gcd z m) c.le z0 in
if Z.(c_gcd > one) then (
let const = match c.op with
| Leq ->
(* round down, regardless of sign *)
let cplus = Z.abs c.const in
let cplus = Z.(cplus / c_gcd) in
if Z.sign c.const >= 0 then cplus else Z.neg cplus
| Eq | Eq_mod _ ->
if Z.equal (Z.rem c.const c_gcd) Z.zero then (
(* compatible constant *)
Z.(divexact c.const c_gcd)
) else (
raise E_unsat
)
in
let c' = {
c with
le=T_map.map (fun c -> Z.(c / c_gcd)) c.le;
const;
} in
Log.debugf 50
(fun k->k "(@[intsolver.simplify@ :from %a@ :into %a@])"
pp_constr c pp_constr c');
Ok c'
) else Ok c
with E_unsat ->
Log.debugf 50 (fun k->k "(@[intsolver.simplify.unsat@ %a@])" pp_constr c);
Error ()
let add_constr (self:state) (c:constr) : unit =
if is_ok_ self then (
let c = {c with le=norm_le self c.le } in
match simplify_constr c with
| Ok c ->
LE.iter (fun t _ -> incr_v self t) c.le;
self.constrs <- c :: self.constrs
| Error () ->
self.ok <- Error c
)
let remove_constr (self:state) (c:constr) =
let remove_constr (self:state) (c:constr) : unit =
LE.iter (fun t _ -> decr_v self t) c.le
let create (self:t) : state =
@ -265,6 +422,7 @@ module Make(A : ARG)
vars=T_map.empty;
rw=T_map.empty;
constrs=[];
ok=Ok();
} in
BVec.iter self.defs
~f:(fun (v,le) ->
@ -273,10 +431,10 @@ module Make(A : ARG)
BVec.iter self.cs
~f:(fun (c:Constr.t) ->
let {Constr.le; op; const; lits} = c in
let op = match op with
| Op.Eq -> Eq
| Op.Leq -> Leq
| Op.Lt -> Lt
let op, const = match op with
| Op.Eq -> Eq, const
| Op.Leq -> Leq, const
| Op.Lt -> Leq, Z.pred const (* [x < t] is [x <= t-1] *)
in
let c = {le;const;lits;op} in
add_constr state c
@ -289,22 +447,59 @@ module Make(A : ARG)
let m = Model.empty in
Sat m (* TODO: model *)
| Some (t, _) ->
self.vars <- T_map.remove t self.vars;
| Some (t, _) -> elim_var_ self t
end
and elim_var_ self t : result =
Log.debugf 30
(fun k->k "(@[intsolver.elim-var@ %a@ :remaining %d@])"
A.pp_term t (T_map.cardinal self.vars));
assert false (* TODO *)
assert (not (T_map.mem t self.rw)); (* woudl have been rewritten away *)
self.vars <- T_map.remove t self.vars;
end
(* gather the sets *)
let set_e = ref [] in (* eqns *)
let set_l = ref [] in (* t <= … *)
let set_g = ref [] in (* t >= … *)
let set_m = ref [] in (* t = … [n] *)
let others = ref [] in
let classify_constr (c:constr) =
match T_map.get t c.le, c.op with
| None, _ ->
others := c :: !others;
| Some n_t, Leq ->
if Z.sign n_t > 0 then
set_l := (n_t,c) :: !set_l
else
set_g := (Z.abs n_t,c) :: !set_g
| Some n_t, Eq ->
set_e := (n_t,c) :: !set_e
| Some n_t, Eq_mod _ ->
set_m := (n_t,c) :: !set_m
in
List.iter classify_constr self.constrs;
Log.debugf 50
(fun k->
let pps = Fmt.Dump.(list @@ pair Z.pp pp_constr) in
k "(@[intsolver.classify.for %a@ E=%a@ L=%a@ G=%a@ M=%a@])"
A.pp_term t pps !set_e pps !set_l pps !set_g pps !set_m);
assert false
end
let check (self:t) : result =
Log.debugf 10 (fun k->k "(@[intsolver.check@])");
let state = Check_.create self in
match state.ok with
| Ok () ->
Check_.solve_rec state
| Error c ->
Log.debugf 10 (fun k->k "(@[insolver.unsat-constraint@ %a@])" Check_.pp_constr c);
(* TODO proper certificate *)
Unsat ()
let _check_invariants _ = ()
end