mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-06 03:05:31 -05:00
wip: feat(intsolver): classify constraints into sets E,L,G,M
This commit is contained in:
parent
04e9d5b93c
commit
e77d2b81ca
1 changed files with 215 additions and 20 deletions
|
|
@ -1,6 +1,6 @@
|
|||
|
||||
module type ARG = sig
|
||||
module Z : Sidekick_arith.INT
|
||||
module Z : Sidekick_arith.INT_FULL
|
||||
|
||||
type term
|
||||
type lit
|
||||
|
|
@ -78,6 +78,104 @@ module Make(A : ARG)
|
|||
module A = A
|
||||
open A
|
||||
|
||||
module ZTbl = CCHashtbl.Make(Z)
|
||||
|
||||
module Utils_ : sig
|
||||
type divisor = {
|
||||
prime : Z.t;
|
||||
power : int;
|
||||
}
|
||||
val is_prime : Z.t -> bool
|
||||
val prime_decomposition : Z.t -> divisor list
|
||||
val primes_leq : Z.t -> Z.t Iter.t
|
||||
end = struct
|
||||
|
||||
type divisor = {
|
||||
prime : Z.t;
|
||||
power : int;
|
||||
}
|
||||
|
||||
let two = Z.of_int 2
|
||||
|
||||
(* table from numbers to some of their divisor (if any) *)
|
||||
let _table = lazy (
|
||||
let t = ZTbl.create 256 in
|
||||
ZTbl.add t two None;
|
||||
t)
|
||||
|
||||
let _divisors n = ZTbl.find (Lazy.force _table) n
|
||||
|
||||
let _add_prime n =
|
||||
ZTbl.replace (Lazy.force _table) n None
|
||||
|
||||
(* add to the table the fact that [d] is a divisor of [n] *)
|
||||
let _add_divisor n d =
|
||||
assert (not (ZTbl.mem (Lazy.force _table) n));
|
||||
ZTbl.add (Lazy.force _table) n (Some d)
|
||||
|
||||
(* primality test, modifies _table *)
|
||||
let _is_prime n0 =
|
||||
let n = ref two in
|
||||
let bound = Z.succ (Z.sqrt n0) in
|
||||
let is_prime = ref true in
|
||||
while !is_prime && Z.(!n <= bound) do
|
||||
if Z.(rem n0 !n = zero)
|
||||
then begin
|
||||
is_prime := false;
|
||||
_add_divisor n0 !n;
|
||||
end;
|
||||
n := Z.succ !n;
|
||||
done;
|
||||
if !is_prime then _add_prime n0;
|
||||
!is_prime
|
||||
|
||||
let is_prime n =
|
||||
try
|
||||
begin match _divisors n with
|
||||
| None -> true
|
||||
| Some _ -> false
|
||||
end
|
||||
with Not_found ->
|
||||
if Z.probab_prime n && _is_prime n then (
|
||||
_add_prime n; true
|
||||
) else false
|
||||
|
||||
let rec _merge l1 l2 = match l1, l2 with
|
||||
| [], _ -> l2
|
||||
| _, [] -> l1
|
||||
| p1::l1', p2::l2' ->
|
||||
match Z.compare p1.prime p2.prime with
|
||||
| 0 ->
|
||||
{prime=p1.prime; power=p1.power+p2.power} :: _merge l1' l2'
|
||||
| n when n < 0 ->
|
||||
p1 :: _merge l1' l2
|
||||
| _ -> p2 :: _merge l1 l2'
|
||||
|
||||
let rec _decompose n =
|
||||
try
|
||||
begin match _divisors n with
|
||||
| None -> [{prime=n; power=1;}]
|
||||
| Some q1 ->
|
||||
let q2 = Z.divexact n q1 in
|
||||
_merge (_decompose q1) (_decompose q2)
|
||||
end
|
||||
with Not_found ->
|
||||
ignore (_is_prime n);
|
||||
_decompose n
|
||||
|
||||
let prime_decomposition n =
|
||||
if is_prime n
|
||||
then [{prime=n; power=1;}]
|
||||
else _decompose n
|
||||
|
||||
let primes_leq n0 k =
|
||||
let n = ref two in
|
||||
while Z.(!n <= n0) do
|
||||
if is_prime !n then k !n
|
||||
done
|
||||
end
|
||||
|
||||
|
||||
module Op = struct
|
||||
type t =
|
||||
| Leq
|
||||
|
|
@ -106,6 +204,7 @@ module Make(A : ARG)
|
|||
let iter = T_map.iter
|
||||
let return t : t = T_map.add t Z.one empty
|
||||
let neg self : t = T_map.map Z.neg self
|
||||
let mem self t : bool = T_map.mem t self
|
||||
let mult n self =
|
||||
if Z.(n = zero) then empty
|
||||
else T_map.map (fun c -> Z.(c * n)) self
|
||||
|
|
@ -164,6 +263,9 @@ module Make(A : ARG)
|
|||
lits: lit Bag.t;
|
||||
}
|
||||
|
||||
(* FIXME: need to simplify: compute gcd(le.coeffs), then divide by that
|
||||
and round const appropriately *)
|
||||
|
||||
let pp out self =
|
||||
Fmt.fprintf out "(@[%a@ %a %a@])" Linexp.pp self.le Op.pp self.op Z.pp self.const
|
||||
end
|
||||
|
|
@ -213,13 +315,17 @@ module Make(A : ARG)
|
|||
|
||||
type op =
|
||||
| Leq
|
||||
| Lt
|
||||
| Eq
|
||||
| Eq_mod of {
|
||||
prime: Z.t;
|
||||
pow: int;
|
||||
} (* modulo prime^pow *)
|
||||
|
||||
let pp_op out = function
|
||||
| Leq -> Fmt.string out "<="
|
||||
| Eq -> Fmt.string out "="
|
||||
| Eq_mod {prime; pow} -> Fmt.fprintf out "%a^%d" Z.pp prime pow
|
||||
|
||||
type constr = {
|
||||
le: LE.t;
|
||||
const: Z.t;
|
||||
|
|
@ -227,15 +333,21 @@ module Make(A : ARG)
|
|||
lits: lit Bag.t;
|
||||
}
|
||||
|
||||
let pp_constr out self =
|
||||
Fmt.fprintf out "(@[%a@ %a %a@])" Linexp.pp self.le pp_op self.op Z.pp self.const
|
||||
|
||||
type state = {
|
||||
mutable rw: LE.t T_map.t; (* rewrite rules *)
|
||||
mutable vars: int T_map.t; (* variables in at least one constraint *)
|
||||
mutable constrs: constr list;
|
||||
mutable ok: (unit, constr) Result.t;
|
||||
}
|
||||
(* main solving state. mutable, but copied for backtracking.
|
||||
invariant: variables in [rw] do not occur anywhere else
|
||||
*)
|
||||
|
||||
let[@inline] is_ok_ self = CCResult.is_ok self.ok
|
||||
|
||||
(* perform rewriting on the linear expression *)
|
||||
let norm_le (self:state) (le:LE.t) : LE.t =
|
||||
LE.flat_map
|
||||
|
|
@ -243,8 +355,10 @@ module Make(A : ARG)
|
|||
le
|
||||
|
||||
let[@inline] count_v self t : int = T_map.get_or ~default:0 t self.vars
|
||||
|
||||
let[@inline] incr_v (self:state) (t:term) : unit =
|
||||
self.vars <- T_map.add t (1 + count_v self t) self.vars
|
||||
|
||||
let decr_v (self:state) (t:term) : unit =
|
||||
let n = count_v self t - 1 in
|
||||
assert (n >= 0);
|
||||
|
|
@ -252,12 +366,55 @@ module Make(A : ARG)
|
|||
(if n=0 then T_map.remove t self.vars
|
||||
else T_map.add t n self.vars)
|
||||
|
||||
let add_constr (self:state) (c:constr) =
|
||||
let simplify_constr (c:constr) : (constr, unit) Result.t =
|
||||
let exception E_unsat in
|
||||
try
|
||||
match T_map.choose_opt c.le with
|
||||
| None -> Ok c
|
||||
| Some (_, z0) ->
|
||||
let c_gcd = T_map.fold (fun _ z m -> Z.gcd z m) c.le z0 in
|
||||
if Z.(c_gcd > one) then (
|
||||
let const = match c.op with
|
||||
| Leq ->
|
||||
(* round down, regardless of sign *)
|
||||
let cplus = Z.abs c.const in
|
||||
let cplus = Z.(cplus / c_gcd) in
|
||||
if Z.sign c.const >= 0 then cplus else Z.neg cplus
|
||||
| Eq | Eq_mod _ ->
|
||||
if Z.equal (Z.rem c.const c_gcd) Z.zero then (
|
||||
(* compatible constant *)
|
||||
Z.(divexact c.const c_gcd)
|
||||
) else (
|
||||
raise E_unsat
|
||||
)
|
||||
in
|
||||
|
||||
let c' = {
|
||||
c with
|
||||
le=T_map.map (fun c -> Z.(c / c_gcd)) c.le;
|
||||
const;
|
||||
} in
|
||||
Log.debugf 50
|
||||
(fun k->k "(@[intsolver.simplify@ :from %a@ :into %a@])"
|
||||
pp_constr c pp_constr c');
|
||||
Ok c'
|
||||
) else Ok c
|
||||
with E_unsat ->
|
||||
Log.debugf 50 (fun k->k "(@[intsolver.simplify.unsat@ %a@])" pp_constr c);
|
||||
Error ()
|
||||
|
||||
let add_constr (self:state) (c:constr) : unit =
|
||||
if is_ok_ self then (
|
||||
let c = {c with le=norm_le self c.le } in
|
||||
match simplify_constr c with
|
||||
| Ok c ->
|
||||
LE.iter (fun t _ -> incr_v self t) c.le;
|
||||
self.constrs <- c :: self.constrs
|
||||
| Error () ->
|
||||
self.ok <- Error c
|
||||
)
|
||||
|
||||
let remove_constr (self:state) (c:constr) =
|
||||
let remove_constr (self:state) (c:constr) : unit =
|
||||
LE.iter (fun t _ -> decr_v self t) c.le
|
||||
|
||||
let create (self:t) : state =
|
||||
|
|
@ -265,6 +422,7 @@ module Make(A : ARG)
|
|||
vars=T_map.empty;
|
||||
rw=T_map.empty;
|
||||
constrs=[];
|
||||
ok=Ok();
|
||||
} in
|
||||
BVec.iter self.defs
|
||||
~f:(fun (v,le) ->
|
||||
|
|
@ -273,10 +431,10 @@ module Make(A : ARG)
|
|||
BVec.iter self.cs
|
||||
~f:(fun (c:Constr.t) ->
|
||||
let {Constr.le; op; const; lits} = c in
|
||||
let op = match op with
|
||||
| Op.Eq -> Eq
|
||||
| Op.Leq -> Leq
|
||||
| Op.Lt -> Lt
|
||||
let op, const = match op with
|
||||
| Op.Eq -> Eq, const
|
||||
| Op.Leq -> Leq, const
|
||||
| Op.Lt -> Leq, Z.pred const (* [x < t] is [x <= t-1] *)
|
||||
in
|
||||
let c = {le;const;lits;op} in
|
||||
add_constr state c
|
||||
|
|
@ -289,22 +447,59 @@ module Make(A : ARG)
|
|||
let m = Model.empty in
|
||||
Sat m (* TODO: model *)
|
||||
|
||||
| Some (t, _) ->
|
||||
self.vars <- T_map.remove t self.vars;
|
||||
| Some (t, _) -> elim_var_ self t
|
||||
end
|
||||
|
||||
|
||||
and elim_var_ self t : result =
|
||||
Log.debugf 30
|
||||
(fun k->k "(@[intsolver.elim-var@ %a@ :remaining %d@])"
|
||||
A.pp_term t (T_map.cardinal self.vars));
|
||||
|
||||
assert false (* TODO *)
|
||||
assert (not (T_map.mem t self.rw)); (* woudl have been rewritten away *)
|
||||
self.vars <- T_map.remove t self.vars;
|
||||
|
||||
end
|
||||
(* gather the sets *)
|
||||
let set_e = ref [] in (* eqns *)
|
||||
let set_l = ref [] in (* t <= … *)
|
||||
let set_g = ref [] in (* t >= … *)
|
||||
let set_m = ref [] in (* t = … [n] *)
|
||||
let others = ref [] in
|
||||
|
||||
let classify_constr (c:constr) =
|
||||
match T_map.get t c.le, c.op with
|
||||
| None, _ ->
|
||||
others := c :: !others;
|
||||
| Some n_t, Leq ->
|
||||
if Z.sign n_t > 0 then
|
||||
set_l := (n_t,c) :: !set_l
|
||||
else
|
||||
set_g := (Z.abs n_t,c) :: !set_g
|
||||
| Some n_t, Eq ->
|
||||
set_e := (n_t,c) :: !set_e
|
||||
| Some n_t, Eq_mod _ ->
|
||||
set_m := (n_t,c) :: !set_m
|
||||
in
|
||||
List.iter classify_constr self.constrs;
|
||||
|
||||
Log.debugf 50
|
||||
(fun k->
|
||||
let pps = Fmt.Dump.(list @@ pair Z.pp pp_constr) in
|
||||
k "(@[intsolver.classify.for %a@ E=%a@ L=%a@ G=%a@ M=%a@])"
|
||||
A.pp_term t pps !set_e pps !set_l pps !set_g pps !set_m);
|
||||
assert false
|
||||
end
|
||||
|
||||
let check (self:t) : result =
|
||||
Log.debugf 10 (fun k->k "(@[intsolver.check@])");
|
||||
let state = Check_.create self in
|
||||
match state.ok with
|
||||
| Ok () ->
|
||||
Check_.solve_rec state
|
||||
| Error c ->
|
||||
Log.debugf 10 (fun k->k "(@[insolver.unsat-constraint@ %a@])" Check_.pp_constr c);
|
||||
(* TODO proper certificate *)
|
||||
Unsat ()
|
||||
|
||||
let _check_invariants _ = ()
|
||||
end
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue