mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-06 03:05:31 -05:00
wip: feat(intsolver): classify constraints into sets E,L,G,M
This commit is contained in:
parent
04e9d5b93c
commit
e77d2b81ca
1 changed files with 215 additions and 20 deletions
|
|
@ -1,6 +1,6 @@
|
||||||
|
|
||||||
module type ARG = sig
|
module type ARG = sig
|
||||||
module Z : Sidekick_arith.INT
|
module Z : Sidekick_arith.INT_FULL
|
||||||
|
|
||||||
type term
|
type term
|
||||||
type lit
|
type lit
|
||||||
|
|
@ -78,6 +78,104 @@ module Make(A : ARG)
|
||||||
module A = A
|
module A = A
|
||||||
open A
|
open A
|
||||||
|
|
||||||
|
module ZTbl = CCHashtbl.Make(Z)
|
||||||
|
|
||||||
|
module Utils_ : sig
|
||||||
|
type divisor = {
|
||||||
|
prime : Z.t;
|
||||||
|
power : int;
|
||||||
|
}
|
||||||
|
val is_prime : Z.t -> bool
|
||||||
|
val prime_decomposition : Z.t -> divisor list
|
||||||
|
val primes_leq : Z.t -> Z.t Iter.t
|
||||||
|
end = struct
|
||||||
|
|
||||||
|
type divisor = {
|
||||||
|
prime : Z.t;
|
||||||
|
power : int;
|
||||||
|
}
|
||||||
|
|
||||||
|
let two = Z.of_int 2
|
||||||
|
|
||||||
|
(* table from numbers to some of their divisor (if any) *)
|
||||||
|
let _table = lazy (
|
||||||
|
let t = ZTbl.create 256 in
|
||||||
|
ZTbl.add t two None;
|
||||||
|
t)
|
||||||
|
|
||||||
|
let _divisors n = ZTbl.find (Lazy.force _table) n
|
||||||
|
|
||||||
|
let _add_prime n =
|
||||||
|
ZTbl.replace (Lazy.force _table) n None
|
||||||
|
|
||||||
|
(* add to the table the fact that [d] is a divisor of [n] *)
|
||||||
|
let _add_divisor n d =
|
||||||
|
assert (not (ZTbl.mem (Lazy.force _table) n));
|
||||||
|
ZTbl.add (Lazy.force _table) n (Some d)
|
||||||
|
|
||||||
|
(* primality test, modifies _table *)
|
||||||
|
let _is_prime n0 =
|
||||||
|
let n = ref two in
|
||||||
|
let bound = Z.succ (Z.sqrt n0) in
|
||||||
|
let is_prime = ref true in
|
||||||
|
while !is_prime && Z.(!n <= bound) do
|
||||||
|
if Z.(rem n0 !n = zero)
|
||||||
|
then begin
|
||||||
|
is_prime := false;
|
||||||
|
_add_divisor n0 !n;
|
||||||
|
end;
|
||||||
|
n := Z.succ !n;
|
||||||
|
done;
|
||||||
|
if !is_prime then _add_prime n0;
|
||||||
|
!is_prime
|
||||||
|
|
||||||
|
let is_prime n =
|
||||||
|
try
|
||||||
|
begin match _divisors n with
|
||||||
|
| None -> true
|
||||||
|
| Some _ -> false
|
||||||
|
end
|
||||||
|
with Not_found ->
|
||||||
|
if Z.probab_prime n && _is_prime n then (
|
||||||
|
_add_prime n; true
|
||||||
|
) else false
|
||||||
|
|
||||||
|
let rec _merge l1 l2 = match l1, l2 with
|
||||||
|
| [], _ -> l2
|
||||||
|
| _, [] -> l1
|
||||||
|
| p1::l1', p2::l2' ->
|
||||||
|
match Z.compare p1.prime p2.prime with
|
||||||
|
| 0 ->
|
||||||
|
{prime=p1.prime; power=p1.power+p2.power} :: _merge l1' l2'
|
||||||
|
| n when n < 0 ->
|
||||||
|
p1 :: _merge l1' l2
|
||||||
|
| _ -> p2 :: _merge l1 l2'
|
||||||
|
|
||||||
|
let rec _decompose n =
|
||||||
|
try
|
||||||
|
begin match _divisors n with
|
||||||
|
| None -> [{prime=n; power=1;}]
|
||||||
|
| Some q1 ->
|
||||||
|
let q2 = Z.divexact n q1 in
|
||||||
|
_merge (_decompose q1) (_decompose q2)
|
||||||
|
end
|
||||||
|
with Not_found ->
|
||||||
|
ignore (_is_prime n);
|
||||||
|
_decompose n
|
||||||
|
|
||||||
|
let prime_decomposition n =
|
||||||
|
if is_prime n
|
||||||
|
then [{prime=n; power=1;}]
|
||||||
|
else _decompose n
|
||||||
|
|
||||||
|
let primes_leq n0 k =
|
||||||
|
let n = ref two in
|
||||||
|
while Z.(!n <= n0) do
|
||||||
|
if is_prime !n then k !n
|
||||||
|
done
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
module Op = struct
|
module Op = struct
|
||||||
type t =
|
type t =
|
||||||
| Leq
|
| Leq
|
||||||
|
|
@ -106,6 +204,7 @@ module Make(A : ARG)
|
||||||
let iter = T_map.iter
|
let iter = T_map.iter
|
||||||
let return t : t = T_map.add t Z.one empty
|
let return t : t = T_map.add t Z.one empty
|
||||||
let neg self : t = T_map.map Z.neg self
|
let neg self : t = T_map.map Z.neg self
|
||||||
|
let mem self t : bool = T_map.mem t self
|
||||||
let mult n self =
|
let mult n self =
|
||||||
if Z.(n = zero) then empty
|
if Z.(n = zero) then empty
|
||||||
else T_map.map (fun c -> Z.(c * n)) self
|
else T_map.map (fun c -> Z.(c * n)) self
|
||||||
|
|
@ -164,6 +263,9 @@ module Make(A : ARG)
|
||||||
lits: lit Bag.t;
|
lits: lit Bag.t;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
(* FIXME: need to simplify: compute gcd(le.coeffs), then divide by that
|
||||||
|
and round const appropriately *)
|
||||||
|
|
||||||
let pp out self =
|
let pp out self =
|
||||||
Fmt.fprintf out "(@[%a@ %a %a@])" Linexp.pp self.le Op.pp self.op Z.pp self.const
|
Fmt.fprintf out "(@[%a@ %a %a@])" Linexp.pp self.le Op.pp self.op Z.pp self.const
|
||||||
end
|
end
|
||||||
|
|
@ -213,13 +315,17 @@ module Make(A : ARG)
|
||||||
|
|
||||||
type op =
|
type op =
|
||||||
| Leq
|
| Leq
|
||||||
| Lt
|
|
||||||
| Eq
|
| Eq
|
||||||
| Eq_mod of {
|
| Eq_mod of {
|
||||||
prime: Z.t;
|
prime: Z.t;
|
||||||
pow: int;
|
pow: int;
|
||||||
} (* modulo prime^pow *)
|
} (* modulo prime^pow *)
|
||||||
|
|
||||||
|
let pp_op out = function
|
||||||
|
| Leq -> Fmt.string out "<="
|
||||||
|
| Eq -> Fmt.string out "="
|
||||||
|
| Eq_mod {prime; pow} -> Fmt.fprintf out "%a^%d" Z.pp prime pow
|
||||||
|
|
||||||
type constr = {
|
type constr = {
|
||||||
le: LE.t;
|
le: LE.t;
|
||||||
const: Z.t;
|
const: Z.t;
|
||||||
|
|
@ -227,15 +333,21 @@ module Make(A : ARG)
|
||||||
lits: lit Bag.t;
|
lits: lit Bag.t;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let pp_constr out self =
|
||||||
|
Fmt.fprintf out "(@[%a@ %a %a@])" Linexp.pp self.le pp_op self.op Z.pp self.const
|
||||||
|
|
||||||
type state = {
|
type state = {
|
||||||
mutable rw: LE.t T_map.t; (* rewrite rules *)
|
mutable rw: LE.t T_map.t; (* rewrite rules *)
|
||||||
mutable vars: int T_map.t; (* variables in at least one constraint *)
|
mutable vars: int T_map.t; (* variables in at least one constraint *)
|
||||||
mutable constrs: constr list;
|
mutable constrs: constr list;
|
||||||
|
mutable ok: (unit, constr) Result.t;
|
||||||
}
|
}
|
||||||
(* main solving state. mutable, but copied for backtracking.
|
(* main solving state. mutable, but copied for backtracking.
|
||||||
invariant: variables in [rw] do not occur anywhere else
|
invariant: variables in [rw] do not occur anywhere else
|
||||||
*)
|
*)
|
||||||
|
|
||||||
|
let[@inline] is_ok_ self = CCResult.is_ok self.ok
|
||||||
|
|
||||||
(* perform rewriting on the linear expression *)
|
(* perform rewriting on the linear expression *)
|
||||||
let norm_le (self:state) (le:LE.t) : LE.t =
|
let norm_le (self:state) (le:LE.t) : LE.t =
|
||||||
LE.flat_map
|
LE.flat_map
|
||||||
|
|
@ -243,8 +355,10 @@ module Make(A : ARG)
|
||||||
le
|
le
|
||||||
|
|
||||||
let[@inline] count_v self t : int = T_map.get_or ~default:0 t self.vars
|
let[@inline] count_v self t : int = T_map.get_or ~default:0 t self.vars
|
||||||
|
|
||||||
let[@inline] incr_v (self:state) (t:term) : unit =
|
let[@inline] incr_v (self:state) (t:term) : unit =
|
||||||
self.vars <- T_map.add t (1 + count_v self t) self.vars
|
self.vars <- T_map.add t (1 + count_v self t) self.vars
|
||||||
|
|
||||||
let decr_v (self:state) (t:term) : unit =
|
let decr_v (self:state) (t:term) : unit =
|
||||||
let n = count_v self t - 1 in
|
let n = count_v self t - 1 in
|
||||||
assert (n >= 0);
|
assert (n >= 0);
|
||||||
|
|
@ -252,12 +366,55 @@ module Make(A : ARG)
|
||||||
(if n=0 then T_map.remove t self.vars
|
(if n=0 then T_map.remove t self.vars
|
||||||
else T_map.add t n self.vars)
|
else T_map.add t n self.vars)
|
||||||
|
|
||||||
let add_constr (self:state) (c:constr) =
|
let simplify_constr (c:constr) : (constr, unit) Result.t =
|
||||||
let c = {c with le=norm_le self c.le } in
|
let exception E_unsat in
|
||||||
LE.iter (fun t _ -> incr_v self t) c.le;
|
try
|
||||||
self.constrs <- c :: self.constrs
|
match T_map.choose_opt c.le with
|
||||||
|
| None -> Ok c
|
||||||
|
| Some (_, z0) ->
|
||||||
|
let c_gcd = T_map.fold (fun _ z m -> Z.gcd z m) c.le z0 in
|
||||||
|
if Z.(c_gcd > one) then (
|
||||||
|
let const = match c.op with
|
||||||
|
| Leq ->
|
||||||
|
(* round down, regardless of sign *)
|
||||||
|
let cplus = Z.abs c.const in
|
||||||
|
let cplus = Z.(cplus / c_gcd) in
|
||||||
|
if Z.sign c.const >= 0 then cplus else Z.neg cplus
|
||||||
|
| Eq | Eq_mod _ ->
|
||||||
|
if Z.equal (Z.rem c.const c_gcd) Z.zero then (
|
||||||
|
(* compatible constant *)
|
||||||
|
Z.(divexact c.const c_gcd)
|
||||||
|
) else (
|
||||||
|
raise E_unsat
|
||||||
|
)
|
||||||
|
in
|
||||||
|
|
||||||
let remove_constr (self:state) (c:constr) =
|
let c' = {
|
||||||
|
c with
|
||||||
|
le=T_map.map (fun c -> Z.(c / c_gcd)) c.le;
|
||||||
|
const;
|
||||||
|
} in
|
||||||
|
Log.debugf 50
|
||||||
|
(fun k->k "(@[intsolver.simplify@ :from %a@ :into %a@])"
|
||||||
|
pp_constr c pp_constr c');
|
||||||
|
Ok c'
|
||||||
|
) else Ok c
|
||||||
|
with E_unsat ->
|
||||||
|
Log.debugf 50 (fun k->k "(@[intsolver.simplify.unsat@ %a@])" pp_constr c);
|
||||||
|
Error ()
|
||||||
|
|
||||||
|
let add_constr (self:state) (c:constr) : unit =
|
||||||
|
if is_ok_ self then (
|
||||||
|
let c = {c with le=norm_le self c.le } in
|
||||||
|
match simplify_constr c with
|
||||||
|
| Ok c ->
|
||||||
|
LE.iter (fun t _ -> incr_v self t) c.le;
|
||||||
|
self.constrs <- c :: self.constrs
|
||||||
|
| Error () ->
|
||||||
|
self.ok <- Error c
|
||||||
|
)
|
||||||
|
|
||||||
|
let remove_constr (self:state) (c:constr) : unit =
|
||||||
LE.iter (fun t _ -> decr_v self t) c.le
|
LE.iter (fun t _ -> decr_v self t) c.le
|
||||||
|
|
||||||
let create (self:t) : state =
|
let create (self:t) : state =
|
||||||
|
|
@ -265,6 +422,7 @@ module Make(A : ARG)
|
||||||
vars=T_map.empty;
|
vars=T_map.empty;
|
||||||
rw=T_map.empty;
|
rw=T_map.empty;
|
||||||
constrs=[];
|
constrs=[];
|
||||||
|
ok=Ok();
|
||||||
} in
|
} in
|
||||||
BVec.iter self.defs
|
BVec.iter self.defs
|
||||||
~f:(fun (v,le) ->
|
~f:(fun (v,le) ->
|
||||||
|
|
@ -273,10 +431,10 @@ module Make(A : ARG)
|
||||||
BVec.iter self.cs
|
BVec.iter self.cs
|
||||||
~f:(fun (c:Constr.t) ->
|
~f:(fun (c:Constr.t) ->
|
||||||
let {Constr.le; op; const; lits} = c in
|
let {Constr.le; op; const; lits} = c in
|
||||||
let op = match op with
|
let op, const = match op with
|
||||||
| Op.Eq -> Eq
|
| Op.Eq -> Eq, const
|
||||||
| Op.Leq -> Leq
|
| Op.Leq -> Leq, const
|
||||||
| Op.Lt -> Lt
|
| Op.Lt -> Leq, Z.pred const (* [x < t] is [x <= t-1] *)
|
||||||
in
|
in
|
||||||
let c = {le;const;lits;op} in
|
let c = {le;const;lits;op} in
|
||||||
add_constr state c
|
add_constr state c
|
||||||
|
|
@ -289,22 +447,59 @@ module Make(A : ARG)
|
||||||
let m = Model.empty in
|
let m = Model.empty in
|
||||||
Sat m (* TODO: model *)
|
Sat m (* TODO: model *)
|
||||||
|
|
||||||
| Some (t, _) ->
|
| Some (t, _) -> elim_var_ self t
|
||||||
self.vars <- T_map.remove t self.vars;
|
|
||||||
Log.debugf 30
|
|
||||||
(fun k->k "(@[intsolver.elim-var@ %a@ :remaining %d@])"
|
|
||||||
A.pp_term t (T_map.cardinal self.vars));
|
|
||||||
|
|
||||||
assert false (* TODO *)
|
|
||||||
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
and elim_var_ self t : result =
|
||||||
|
Log.debugf 30
|
||||||
|
(fun k->k "(@[intsolver.elim-var@ %a@ :remaining %d@])"
|
||||||
|
A.pp_term t (T_map.cardinal self.vars));
|
||||||
|
|
||||||
|
assert (not (T_map.mem t self.rw)); (* woudl have been rewritten away *)
|
||||||
|
self.vars <- T_map.remove t self.vars;
|
||||||
|
|
||||||
|
(* gather the sets *)
|
||||||
|
let set_e = ref [] in (* eqns *)
|
||||||
|
let set_l = ref [] in (* t <= … *)
|
||||||
|
let set_g = ref [] in (* t >= … *)
|
||||||
|
let set_m = ref [] in (* t = … [n] *)
|
||||||
|
let others = ref [] in
|
||||||
|
|
||||||
|
let classify_constr (c:constr) =
|
||||||
|
match T_map.get t c.le, c.op with
|
||||||
|
| None, _ ->
|
||||||
|
others := c :: !others;
|
||||||
|
| Some n_t, Leq ->
|
||||||
|
if Z.sign n_t > 0 then
|
||||||
|
set_l := (n_t,c) :: !set_l
|
||||||
|
else
|
||||||
|
set_g := (Z.abs n_t,c) :: !set_g
|
||||||
|
| Some n_t, Eq ->
|
||||||
|
set_e := (n_t,c) :: !set_e
|
||||||
|
| Some n_t, Eq_mod _ ->
|
||||||
|
set_m := (n_t,c) :: !set_m
|
||||||
|
in
|
||||||
|
List.iter classify_constr self.constrs;
|
||||||
|
|
||||||
|
Log.debugf 50
|
||||||
|
(fun k->
|
||||||
|
let pps = Fmt.Dump.(list @@ pair Z.pp pp_constr) in
|
||||||
|
k "(@[intsolver.classify.for %a@ E=%a@ L=%a@ G=%a@ M=%a@])"
|
||||||
|
A.pp_term t pps !set_e pps !set_l pps !set_g pps !set_m);
|
||||||
|
assert false
|
||||||
end
|
end
|
||||||
|
|
||||||
let check (self:t) : result =
|
let check (self:t) : result =
|
||||||
Log.debugf 10 (fun k->k "(@[intsolver.check@])");
|
Log.debugf 10 (fun k->k "(@[intsolver.check@])");
|
||||||
let state = Check_.create self in
|
let state = Check_.create self in
|
||||||
Check_.solve_rec state
|
match state.ok with
|
||||||
|
| Ok () ->
|
||||||
|
Check_.solve_rec state
|
||||||
|
| Error c ->
|
||||||
|
Log.debugf 10 (fun k->k "(@[insolver.unsat-constraint@ %a@])" Check_.pp_constr c);
|
||||||
|
(* TODO proper certificate *)
|
||||||
|
Unsat ()
|
||||||
|
|
||||||
let _check_invariants _ = ()
|
let _check_invariants _ = ()
|
||||||
end
|
end
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue