mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-07 11:45:41 -05:00
633 lines
17 KiB
OCaml
633 lines
17 KiB
OCaml
|
||
module type ARG = sig
|
||
module Z : Sidekick_arith.INT_FULL
|
||
|
||
type term
|
||
type lit
|
||
|
||
val pp_term : term Fmt.printer
|
||
val pp_lit : lit Fmt.printer
|
||
|
||
module T_map : CCMap.S with type key = term
|
||
end
|
||
|
||
module type S = sig
|
||
module A : ARG
|
||
|
||
module Op : sig
|
||
type t =
|
||
| Leq
|
||
| Lt
|
||
| Eq
|
||
val pp : t Fmt.printer
|
||
end
|
||
|
||
type t
|
||
|
||
val create : unit -> t
|
||
|
||
val push_level : t -> unit
|
||
|
||
val pop_levels : t -> int -> unit
|
||
|
||
val assert_ :
|
||
t ->
|
||
(A.Z.t * A.term) list -> Op.t -> A.Z.t ->
|
||
lit:A.lit ->
|
||
unit
|
||
|
||
val define :
|
||
t ->
|
||
A.term ->
|
||
(A.Z.t * A.term) list ->
|
||
unit
|
||
|
||
module Cert : sig
|
||
type t
|
||
val pp : t Fmt.printer
|
||
|
||
val lits : t -> A.lit Iter.t
|
||
end
|
||
|
||
module Model : sig
|
||
type t
|
||
val pp : t Fmt.printer
|
||
|
||
val eval : t -> A.term -> A.Z.t option
|
||
end
|
||
|
||
type result =
|
||
| Sat of Model.t
|
||
| Unsat of Cert.t
|
||
|
||
val pp_result : result Fmt.printer
|
||
|
||
val check : t -> result
|
||
|
||
(**/**)
|
||
val _check_invariants : t -> unit
|
||
(**/**)
|
||
end
|
||
|
||
|
||
|
||
module Make(A : ARG)
|
||
: S with module A = A
|
||
= struct
|
||
module BVec = Backtrack_stack
|
||
module A = A
|
||
open A
|
||
|
||
module ZTbl = CCHashtbl.Make(Z)
|
||
|
||
module Utils_ : sig
|
||
type divisor = {
|
||
prime : Z.t;
|
||
power : int;
|
||
}
|
||
val is_prime : Z.t -> bool
|
||
val prime_decomposition : Z.t -> divisor list
|
||
val primes_leq : Z.t -> Z.t Iter.t
|
||
end = struct
|
||
|
||
type divisor = {
|
||
prime : Z.t;
|
||
power : int;
|
||
}
|
||
|
||
let two = Z.of_int 2
|
||
|
||
(* table from numbers to some of their divisor (if any) *)
|
||
let _table = lazy (
|
||
let t = ZTbl.create 256 in
|
||
ZTbl.add t two None;
|
||
t)
|
||
|
||
let _divisors n = ZTbl.find (Lazy.force _table) n
|
||
|
||
let _add_prime n =
|
||
ZTbl.replace (Lazy.force _table) n None
|
||
|
||
(* add to the table the fact that [d] is a divisor of [n] *)
|
||
let _add_divisor n d =
|
||
assert (not (ZTbl.mem (Lazy.force _table) n));
|
||
ZTbl.add (Lazy.force _table) n (Some d)
|
||
|
||
(* primality test, modifies _table *)
|
||
let _is_prime n0 =
|
||
let n = ref two in
|
||
let bound = Z.succ (Z.sqrt n0) in
|
||
let is_prime = ref true in
|
||
while !is_prime && Z.(!n <= bound) do
|
||
if Z.(rem n0 !n = zero)
|
||
then begin
|
||
is_prime := false;
|
||
_add_divisor n0 !n;
|
||
end;
|
||
n := Z.succ !n;
|
||
done;
|
||
if !is_prime then _add_prime n0;
|
||
!is_prime
|
||
|
||
let is_prime n =
|
||
try
|
||
begin match _divisors n with
|
||
| None -> true
|
||
| Some _ -> false
|
||
end
|
||
with Not_found ->
|
||
if Z.probab_prime n && _is_prime n then (
|
||
_add_prime n; true
|
||
) else false
|
||
|
||
let rec _merge l1 l2 = match l1, l2 with
|
||
| [], _ -> l2
|
||
| _, [] -> l1
|
||
| p1::l1', p2::l2' ->
|
||
match Z.compare p1.prime p2.prime with
|
||
| 0 ->
|
||
{prime=p1.prime; power=p1.power+p2.power} :: _merge l1' l2'
|
||
| n when n < 0 ->
|
||
p1 :: _merge l1' l2
|
||
| _ -> p2 :: _merge l1 l2'
|
||
|
||
let rec _decompose n =
|
||
try
|
||
begin match _divisors n with
|
||
| None -> [{prime=n; power=1;}]
|
||
| Some q1 ->
|
||
let q2 = Z.divexact n q1 in
|
||
_merge (_decompose q1) (_decompose q2)
|
||
end
|
||
with Not_found ->
|
||
ignore (_is_prime n);
|
||
_decompose n
|
||
|
||
let prime_decomposition n =
|
||
if is_prime n
|
||
then [{prime=n; power=1;}]
|
||
else _decompose n
|
||
|
||
let primes_leq n0 k =
|
||
let n = ref two in
|
||
while Z.(!n <= n0) do
|
||
if is_prime !n then k !n
|
||
done
|
||
end [@@warning "-60"]
|
||
|
||
|
||
module Op = struct
|
||
type t =
|
||
| Leq
|
||
| Lt
|
||
| Eq
|
||
|
||
let pp out = function
|
||
| Leq -> Fmt.string out "<="
|
||
| Lt -> Fmt.string out "<"
|
||
| Eq -> Fmt.string out "="
|
||
end
|
||
|
||
module Linexp = struct
|
||
type t = Z.t T_map.t
|
||
let is_empty = T_map.is_empty
|
||
let empty : t = T_map.empty
|
||
|
||
let pp out (self:t) : unit =
|
||
let pp_pair out (t,z) =
|
||
if Z.(z = one) then A.pp_term out t
|
||
else Fmt.fprintf out "%a · %a" Z.pp z A.pp_term t in
|
||
if is_empty self then Fmt.string out "0"
|
||
else Fmt.fprintf out "(@[%a@])"
|
||
Fmt.(iter ~sep:(return "@ + ") pp_pair) (T_map.to_iter self)
|
||
|
||
let iter = T_map.iter
|
||
let return t : t = T_map.add t Z.one empty
|
||
let neg self : t = T_map.map Z.neg self
|
||
let mem self t : bool = T_map.mem t self
|
||
let remove self t = T_map.remove t self
|
||
let find_exn self t = T_map.find t self
|
||
let mult n self =
|
||
if Z.(n = zero) then empty
|
||
else T_map.map (fun c -> Z.(c * n)) self
|
||
|
||
let add (self:t) (c:Z.t) (t:term) : t =
|
||
let n = Z.(c + T_map.get_or ~default:Z.zero t self) in
|
||
if Z.(n = zero)
|
||
then T_map.remove t self
|
||
else T_map.add t n self
|
||
|
||
let merge (self:t) (other:t) : t =
|
||
T_map.fold
|
||
(fun t c m -> add m c t)
|
||
other self
|
||
|
||
let of_list l : t =
|
||
List.fold_left (fun self (c,t) -> add self c t) empty l
|
||
|
||
(* map each term to a linexp *)
|
||
let flat_map f (self:t) : t =
|
||
T_map.fold
|
||
(fun t c m ->
|
||
let t_le = mult c (f t) in
|
||
merge m t_le
|
||
)
|
||
self empty
|
||
|
||
let (+) = merge
|
||
let ( * ) = mult
|
||
let ( ~- ) = neg
|
||
let (-) a b = a + ~- b
|
||
end
|
||
|
||
module Cert = struct
|
||
type t = unit
|
||
let pp = Fmt.unit
|
||
|
||
let lits _ = Iter.empty (* TODO *)
|
||
end
|
||
|
||
module Model = struct
|
||
type t = {
|
||
m: Z.t T_map.t;
|
||
} [@@unboxed]
|
||
|
||
let pp out self =
|
||
let pp_pair out (t,z) = Fmt.fprintf out "(@[%a := %a@])" A.pp_term t Z.pp z in
|
||
Fmt.fprintf out "(@[model@ %a@])"
|
||
Fmt.(iter ~sep:(return "@ ") pp_pair) (T_map.to_iter self.m)
|
||
|
||
let empty : t = {m=T_map.empty}
|
||
|
||
let eval (self:t) t : Z.t option = T_map.get t self.m
|
||
end
|
||
|
||
module Constr = struct
|
||
type t = {
|
||
le: Linexp.t;
|
||
const: Z.t;
|
||
op: Op.t;
|
||
lits: lit Bag.t;
|
||
}
|
||
|
||
(* FIXME: need to simplify: compute gcd(le.coeffs), then divide by that
|
||
and round const appropriately *)
|
||
|
||
let pp out self =
|
||
Fmt.fprintf out "(@[%a@ %a %a@])" Linexp.pp self.le Op.pp self.op Z.pp self.const
|
||
end
|
||
|
||
type t = {
|
||
defs: (term * Linexp.t) BVec.t;
|
||
cs: Constr.t BVec.t;
|
||
}
|
||
|
||
let create() : t =
|
||
{ defs=BVec.create();
|
||
cs=BVec.create(); }
|
||
|
||
let push_level self =
|
||
BVec.push_level self.defs;
|
||
BVec.push_level self.cs;
|
||
()
|
||
|
||
let pop_levels self n =
|
||
BVec.pop_levels self.defs n ~f:(fun _ -> ());
|
||
BVec.pop_levels self.cs n ~f:(fun _ -> ());
|
||
()
|
||
|
||
type result =
|
||
| Sat of Model.t
|
||
| Unsat of Cert.t
|
||
|
||
let pp_result out = function
|
||
| Sat m -> Fmt.fprintf out "(@[SAT@ %a@])" Model.pp m
|
||
| Unsat cert -> Fmt.fprintf out "(@[UNSAT@ %a@])" Cert.pp cert
|
||
|
||
let assert_ (self:t) l op c ~lit : unit =
|
||
let le = Linexp.of_list l in
|
||
let c = {Constr.le; const=c; op; lits=Bag.return lit} in
|
||
Log.debugf 15 (fun k->k "(@[sidekick.intsolver.assert@ %a@])" Constr.pp c);
|
||
BVec.push self.cs c
|
||
|
||
(* TODO: check before hand that [t] occurs nowhere else *)
|
||
let define (self:t) t l : unit =
|
||
let le = Linexp.of_list l in
|
||
BVec.push self.defs (t,le)
|
||
|
||
(* #### checking #### *)
|
||
|
||
module Check_ = struct
|
||
module LE = Linexp
|
||
|
||
type op =
|
||
| Leq
|
||
| Eq
|
||
| Eq_mod of {
|
||
prime: Z.t;
|
||
pow: int;
|
||
} (* modulo prime^pow *)
|
||
|
||
let pp_op out = function
|
||
| Leq -> Fmt.string out "<="
|
||
| Eq -> Fmt.string out "="
|
||
| Eq_mod {prime; pow} -> Fmt.fprintf out "%a^%d" Z.pp prime pow
|
||
|
||
type constr = {
|
||
le: LE.t;
|
||
const: Z.t;
|
||
op: op;
|
||
lits: lit Bag.t;
|
||
}
|
||
|
||
let pp_constr out self =
|
||
Fmt.fprintf out "(@[%a@ %a %a@])" Linexp.pp self.le pp_op self.op Z.pp self.const
|
||
|
||
type state = {
|
||
mutable rw: LE.t T_map.t; (* rewrite rules *)
|
||
mutable vars: int T_map.t; (* variables in at least one constraint *)
|
||
mutable constrs: constr list;
|
||
mutable ok: (unit, constr) Result.t;
|
||
}
|
||
(* main solving state. mutable, but copied for backtracking.
|
||
invariant: variables in [rw] do not occur anywhere else
|
||
*)
|
||
|
||
let[@inline] is_ok_ self = CCResult.is_ok self.ok
|
||
|
||
(* perform rewriting on the linear expression *)
|
||
let rec norm_le (self:state) (le:LE.t) : LE.t =
|
||
LE.flat_map
|
||
(fun t ->
|
||
begin match T_map.find t self.rw with
|
||
| le -> norm_le self le
|
||
| exception Not_found -> LE.return t
|
||
end)
|
||
le
|
||
|
||
let[@inline] count_v self t : int = T_map.get_or ~default:0 t self.vars
|
||
|
||
let[@inline] incr_v (self:state) (t:term) : unit =
|
||
self.vars <- T_map.add t (1 + count_v self t) self.vars
|
||
|
||
(* GCD of the coefficients of this linear expression *)
|
||
let gcd_coeffs (le:LE.t) : Z.t =
|
||
match T_map.choose_opt le with
|
||
| None -> Z.one
|
||
| Some (_, z0) -> T_map.fold (fun _ z m -> Z.gcd z m) le z0
|
||
|
||
let decr_v (self:state) (t:term) : unit =
|
||
let n = count_v self t - 1 in
|
||
assert (n >= 0);
|
||
self.vars <-
|
||
(if n=0 then T_map.remove t self.vars
|
||
else T_map.add t n self.vars)
|
||
|
||
let simplify_constr (c:constr) : (constr, unit) Result.t =
|
||
let exception E_unsat in
|
||
try
|
||
match T_map.choose_opt c.le with
|
||
| None -> Ok c
|
||
| Some (_, z0) ->
|
||
let c_gcd = T_map.fold (fun _ z m -> Z.gcd z m) c.le z0 in
|
||
if Z.(c_gcd > one) then (
|
||
let const = match c.op with
|
||
| Leq ->
|
||
(* round down, regardless of sign *)
|
||
Z.ediv c.const c_gcd
|
||
| Eq | Eq_mod _ ->
|
||
if Z.equal (Z.rem c.const c_gcd) Z.zero then (
|
||
(* compatible constant *)
|
||
Z.(divexact c.const c_gcd)
|
||
) else (
|
||
raise E_unsat
|
||
)
|
||
in
|
||
|
||
let c' = {
|
||
c with
|
||
le=T_map.map (fun c -> Z.(c / c_gcd)) c.le;
|
||
const;
|
||
} in
|
||
Log.debugf 50
|
||
(fun k->k "(@[intsolver.simplify@ :from %a@ :into %a@])"
|
||
pp_constr c pp_constr c');
|
||
Ok c'
|
||
) else Ok c
|
||
with E_unsat ->
|
||
Log.debugf 50 (fun k->k "(@[intsolver.simplify.unsat@ %a@])" pp_constr c);
|
||
Error ()
|
||
|
||
let add_constr (self:state) (c:constr) : unit =
|
||
if is_ok_ self then (
|
||
let c = {c with le=norm_le self c.le } in
|
||
match simplify_constr c with
|
||
| Ok c ->
|
||
Log.debugf 50 (fun k->k "(@[intsolver.add-constr@ %a@])" pp_constr c);
|
||
LE.iter (fun t _ -> incr_v self t) c.le;
|
||
self.constrs <- c :: self.constrs
|
||
| Error () ->
|
||
self.ok <- Error c
|
||
)
|
||
|
||
let remove_constr (self:state) (c:constr) : unit =
|
||
LE.iter (fun t _ -> decr_v self t) c.le
|
||
|
||
let create (self:t) : state =
|
||
let state = {
|
||
vars=T_map.empty;
|
||
rw=T_map.empty;
|
||
constrs=[];
|
||
ok=Ok();
|
||
} in
|
||
BVec.iter self.defs
|
||
~f:(fun (v,le) ->
|
||
assert (not (T_map.mem v state.rw));
|
||
(* normalize as much as we can now *)
|
||
let le = norm_le state le in
|
||
Log.debugf 50 (fun k->k "(@[intsolver.add-rw %a@ := %a@])" pp_term v LE.pp le);
|
||
state.rw <- T_map.add v le state.rw);
|
||
BVec.iter self.cs
|
||
~f:(fun (c:Constr.t) ->
|
||
let {Constr.le; op; const; lits} = c in
|
||
let op, const = match op with
|
||
| Op.Eq -> Eq, const
|
||
| Op.Leq -> Leq, const
|
||
| Op.Lt -> Leq, Z.pred const (* [x < t] is [x <= t-1] *)
|
||
in
|
||
let c = {le;const;lits;op} in
|
||
add_constr state c
|
||
);
|
||
state
|
||
|
||
let rec solve_rec (self:state) : result =
|
||
begin match T_map.choose_opt self.vars with
|
||
| None ->
|
||
let m = Model.empty in
|
||
Sat m (* TODO: model *)
|
||
|
||
| Some (t, _) -> elim_var_ self t
|
||
end
|
||
|
||
|
||
and elim_var_ self (x:term) : result =
|
||
Log.debugf 20
|
||
(fun k->k "(@[@{<Yellow>intsolver.elim-var@}@ %a@ :remaining %d@])"
|
||
A.pp_term x (T_map.cardinal self.vars));
|
||
|
||
assert (not (T_map.mem x self.rw)); (* would have been rewritten away *)
|
||
self.vars <- T_map.remove x self.vars;
|
||
|
||
(* gather the sets *)
|
||
let set_e = ref [] in (* eqns *)
|
||
let set_l = ref [] in (* t <= … *)
|
||
let set_g = ref [] in (* t >= … *)
|
||
let set_m = ref [] in (* t = … [n] *)
|
||
let others = ref [] in
|
||
|
||
let classify_constr (c:constr) =
|
||
match T_map.get x c.le, c.op with
|
||
| None, _ ->
|
||
others := c :: !others;
|
||
| Some n_t, Leq ->
|
||
if Z.sign n_t > 0 then
|
||
set_l := (n_t,c) :: !set_l
|
||
else
|
||
set_g := (n_t,c) :: !set_g
|
||
| Some n_t, Eq ->
|
||
set_e := (n_t,c) :: !set_e
|
||
| Some n_t, Eq_mod _ ->
|
||
set_m := (n_t,c) :: !set_m
|
||
in
|
||
|
||
List.iter classify_constr self.constrs;
|
||
self.constrs <- !others; (* remove all constraints involving [t] *)
|
||
|
||
Log.debugf 50
|
||
(fun k->
|
||
let pps = Fmt.Dump.(list @@ pair Z.pp pp_constr) in
|
||
k "(@[intsolver.classify.for %a@ E=%a@ L=%a@ G=%a@ M=%a@])"
|
||
A.pp_term x pps !set_e pps !set_l pps !set_g pps !set_m);
|
||
|
||
(* now apply the algorithm *)
|
||
if !set_e <> [] then (
|
||
(* case (a): eliminate via an equality. *)
|
||
|
||
(* pick an equality with a small coeff, if possible *)
|
||
let coeff1, c1 =
|
||
Iter.of_list !set_e
|
||
|> Iter.min_exn ~lt:(fun (n1,_)(n2,_) -> Z.(abs n1 < abs n2))
|
||
in
|
||
|
||
let le1 = LE.(neg @@ remove c1.le x) in
|
||
|
||
Log.debugf 30
|
||
(fun k->k "(@[intsolver.case_a.eqn@ :coeff %a@ :c %a@])"
|
||
Z.pp coeff1 pp_constr c1);
|
||
|
||
let elim_in_constr (coeff2, c2) =
|
||
let le2 = LE.(neg @@ remove c2.le x) in
|
||
|
||
let gcd12 = Z.gcd coeff1 coeff2 in
|
||
(* coeff1 × p1 = coeff2 × p2 = lcm = coeff1 × coeff2 / gcd,
|
||
because coeff1 × coeff2 = lcm × gcd *)
|
||
let lcm12 = Z.(abs coeff1 * abs coeff2 / gcd12) in
|
||
let p1 = Z.(lcm12 / coeff1) in
|
||
let p2 = Z.(lcm12 / coeff2) in
|
||
Log.debugf 50
|
||
(fun k->k "(@[intsolver.elim-in-constr@ %a@ :gcd %a :lcm %a@ :p1 %a :p2 %a@])"
|
||
pp_constr c2 Z.pp gcd12 Z.pp lcm12 Z.pp p1 Z.pp p2);
|
||
|
||
let c' =
|
||
let lits = Bag.append c1.lits c2.lits in
|
||
if Z.sign coeff1 <> Z.sign coeff2 then (
|
||
let le' = LE.(p1 * le1 + p2 * le2) in
|
||
let const' = Z.(p1 * c1.const + p2 * c2.const) in
|
||
{op=c2.op; le=le'; const=const'; lits}
|
||
) else (
|
||
let le' = LE.(p1 * le1 - p2 * le2) in
|
||
let const' = Z.(p1 * c1.const - p2 * c2.const) in
|
||
let le', const' =
|
||
if Z.sign coeff1 < 0 then LE.neg le', Z.neg const'
|
||
else le', const'
|
||
in
|
||
{op=c2.op; le=le'; const=const'; lits}
|
||
)
|
||
in
|
||
add_constr self c'
|
||
|
||
(* also add a divisibility constraint if needed *)
|
||
(* TODO:
|
||
if Z.(p1 > one) then (
|
||
let c' = {le=le2; op=Eq_mod p1; const=c2.const} in
|
||
add_constr self c'
|
||
)
|
||
*)
|
||
in
|
||
List.iter elim_in_constr !set_l;
|
||
List.iter elim_in_constr !set_g;
|
||
List.iter elim_in_constr !set_m;
|
||
|
||
(* FIXME: handle the congruence *)
|
||
) else if !set_l = [] || !set_g = [] then (
|
||
(* case (b): no bound on at least one side *)
|
||
assert (!set_e=[]);
|
||
|
||
() (* FIXME: handle the congruence *)
|
||
) else (
|
||
(* case (c): combine inequalities pairwise *)
|
||
|
||
let elim_pair (coeff1, c1) (coeff2, c2) : unit =
|
||
assert (Z.sign coeff1 > 0 && Z.sign coeff2 < 0);
|
||
|
||
let le1 = LE.remove c1.le x in
|
||
let le2 = LE.remove c2.le x in
|
||
|
||
let gcd12 = Z.gcd coeff1 coeff2 in
|
||
let lcm12 = Z.(coeff1 * abs coeff2 / gcd12) in
|
||
|
||
let p1 = Z.(lcm12 / coeff1) in
|
||
let p2 = Z.(lcm12 / Z.abs coeff2) in
|
||
|
||
Log.debugf 50
|
||
(fun k->k "(@[intsolver.case-b.elim-pair@ L=%a@ G=%a@ \
|
||
:gcd %a :lcm %a@ :p1 %a :p2 %a@])"
|
||
pp_constr c1 pp_constr c2 Z.pp gcd12 Z.pp lcm12 Z.pp p1 Z.pp p2);
|
||
|
||
let new_ineq =
|
||
let le = LE.(p2 * le1 - p1 * le2) in
|
||
let const = Z.(p2 * c1.const - p1 * c2.const) in
|
||
let lits = Bag.append c1.lits c2.lits in
|
||
{op=Leq; le; const; lits}
|
||
in
|
||
|
||
add_constr self new_ineq;
|
||
(* TODO: handle modulo constraints *)
|
||
|
||
in
|
||
|
||
List.iter (fun x1 -> List.iter (elim_pair x1) !set_g) !set_l;
|
||
);
|
||
|
||
(* now recurse *)
|
||
solve_rec self
|
||
end
|
||
|
||
let check (self:t) : result =
|
||
|
||
Log.debugf 10 (fun k->k "(@[@{<Yellow>intsolver.check@}@])");
|
||
let state = Check_.create self in
|
||
Log.debugf 10
|
||
(fun k->k "(@[intsolver.check.stat@ :n-vars %d@ :n-constr %d@])"
|
||
(T_map.cardinal state.vars) (List.length state.constrs));
|
||
|
||
match state.ok with
|
||
| Ok () ->
|
||
Check_.solve_rec state
|
||
| Error c ->
|
||
Log.debugf 10 (fun k->k "(@[insolver.unsat-constraint@ %a@])" Check_.pp_constr c);
|
||
(* TODO proper certificate *)
|
||
Unsat ()
|
||
|
||
let _check_invariants _ = ()
|
||
end
|