mirror of
https://github.com/c-cube/sidekick.git
synced 2025-12-06 03:05:31 -05:00
intsolver: partial implementation
This commit is contained in:
parent
be7451b070
commit
10c8006597
2 changed files with 152 additions and 40 deletions
|
|
@ -205,6 +205,8 @@ module Make(A : ARG)
|
|||
let return t : t = T_map.add t Z.one empty
|
||||
let neg self : t = T_map.map Z.neg self
|
||||
let mem self t : bool = T_map.mem t self
|
||||
let remove self t = T_map.remove t self
|
||||
let find_exn self t = T_map.find t self
|
||||
let mult n self =
|
||||
if Z.(n = zero) then empty
|
||||
else T_map.map (fun c -> Z.(c * n)) self
|
||||
|
|
@ -230,7 +232,12 @@ module Make(A : ARG)
|
|||
let t_le = mult c (f t) in
|
||||
merge m t_le
|
||||
)
|
||||
empty self
|
||||
self empty
|
||||
|
||||
let (+) = merge
|
||||
let ( * ) = mult
|
||||
let ( ~- ) = neg
|
||||
let (-) a b = a + ~- b
|
||||
end
|
||||
|
||||
module Cert = struct
|
||||
|
|
@ -300,7 +307,7 @@ module Make(A : ARG)
|
|||
let assert_ (self:t) l op c ~lit : unit =
|
||||
let le = Linexp.of_list l in
|
||||
let c = {Constr.le; const=c; op; lits=Bag.return lit} in
|
||||
Log.debugf 10 (fun k->k "(@[sidekick.intsolver.assert@ %a@])" Constr.pp c);
|
||||
Log.debugf 15 (fun k->k "(@[sidekick.intsolver.assert@ %a@])" Constr.pp c);
|
||||
BVec.push self.cs c
|
||||
|
||||
(* TODO: check before hand that [t] occurs nowhere else *)
|
||||
|
|
@ -349,9 +356,13 @@ module Make(A : ARG)
|
|||
let[@inline] is_ok_ self = CCResult.is_ok self.ok
|
||||
|
||||
(* perform rewriting on the linear expression *)
|
||||
let norm_le (self:state) (le:LE.t) : LE.t =
|
||||
let rec norm_le (self:state) (le:LE.t) : LE.t =
|
||||
LE.flat_map
|
||||
(fun t -> try T_map.find t self.rw with Not_found -> LE.return t)
|
||||
(fun t ->
|
||||
begin match T_map.find t self.rw with
|
||||
| le -> norm_le self le
|
||||
| exception Not_found -> LE.return t
|
||||
end)
|
||||
le
|
||||
|
||||
let[@inline] count_v self t : int = T_map.get_or ~default:0 t self.vars
|
||||
|
|
@ -359,6 +370,12 @@ module Make(A : ARG)
|
|||
let[@inline] incr_v (self:state) (t:term) : unit =
|
||||
self.vars <- T_map.add t (1 + count_v self t) self.vars
|
||||
|
||||
(* GCD of the coefficients of this linear expression *)
|
||||
let gcd_coeffs (le:LE.t) : Z.t =
|
||||
match T_map.choose_opt le with
|
||||
| None -> Z.one
|
||||
| Some (_, z0) -> T_map.fold (fun _ z m -> Z.gcd z m) le z0
|
||||
|
||||
let decr_v (self:state) (t:term) : unit =
|
||||
let n = count_v self t - 1 in
|
||||
assert (n >= 0);
|
||||
|
|
@ -377,9 +394,7 @@ module Make(A : ARG)
|
|||
let const = match c.op with
|
||||
| Leq ->
|
||||
(* round down, regardless of sign *)
|
||||
let cplus = Z.abs c.const in
|
||||
let cplus = Z.(cplus / c_gcd) in
|
||||
if Z.sign c.const >= 0 then cplus else Z.neg cplus
|
||||
Z.ediv c.const c_gcd
|
||||
| Eq | Eq_mod _ ->
|
||||
if Z.equal (Z.rem c.const c_gcd) Z.zero then (
|
||||
(* compatible constant *)
|
||||
|
|
@ -408,6 +423,7 @@ module Make(A : ARG)
|
|||
let c = {c with le=norm_le self c.le } in
|
||||
match simplify_constr c with
|
||||
| Ok c ->
|
||||
Log.debugf 50 (fun k->k "(@[intsolver.add-constr@ %a@])" pp_constr c);
|
||||
LE.iter (fun t _ -> incr_v self t) c.le;
|
||||
self.constrs <- c :: self.constrs
|
||||
| Error () ->
|
||||
|
|
@ -427,7 +443,10 @@ module Make(A : ARG)
|
|||
BVec.iter self.defs
|
||||
~f:(fun (v,le) ->
|
||||
assert (not (T_map.mem v state.rw));
|
||||
state.rw <- T_map.add v (norm_le state le) state.rw);
|
||||
(* normalize as much as we can now *)
|
||||
let le = norm_le state le in
|
||||
Log.debugf 50 (fun k->k "(@[intsolver.add-rw %a@ := %a@])" pp_term v LE.pp le);
|
||||
state.rw <- T_map.add v le state.rw);
|
||||
BVec.iter self.cs
|
||||
~f:(fun (c:Constr.t) ->
|
||||
let {Constr.le; op; const; lits} = c in
|
||||
|
|
@ -451,13 +470,13 @@ module Make(A : ARG)
|
|||
end
|
||||
|
||||
|
||||
and elim_var_ self t : result =
|
||||
Log.debugf 30
|
||||
(fun k->k "(@[intsolver.elim-var@ %a@ :remaining %d@])"
|
||||
A.pp_term t (T_map.cardinal self.vars));
|
||||
and elim_var_ self (x:term) : result =
|
||||
Log.debugf 20
|
||||
(fun k->k "(@[@{<Yellow>intsolver.elim-var@}@ %a@ :remaining %d@])"
|
||||
A.pp_term x (T_map.cardinal self.vars));
|
||||
|
||||
assert (not (T_map.mem t self.rw)); (* woudl have been rewritten away *)
|
||||
self.vars <- T_map.remove t self.vars;
|
||||
assert (not (T_map.mem x self.rw)); (* would have been rewritten away *)
|
||||
self.vars <- T_map.remove x self.vars;
|
||||
|
||||
(* gather the sets *)
|
||||
let set_e = ref [] in (* eqns *)
|
||||
|
|
@ -467,32 +486,141 @@ module Make(A : ARG)
|
|||
let others = ref [] in
|
||||
|
||||
let classify_constr (c:constr) =
|
||||
match T_map.get t c.le, c.op with
|
||||
match T_map.get x c.le, c.op with
|
||||
| None, _ ->
|
||||
others := c :: !others;
|
||||
| Some n_t, Leq ->
|
||||
if Z.sign n_t > 0 then
|
||||
set_l := (n_t,c) :: !set_l
|
||||
else
|
||||
set_g := (Z.abs n_t,c) :: !set_g
|
||||
set_g := (n_t,c) :: !set_g
|
||||
| Some n_t, Eq ->
|
||||
set_e := (n_t,c) :: !set_e
|
||||
| Some n_t, Eq_mod _ ->
|
||||
set_m := (n_t,c) :: !set_m
|
||||
in
|
||||
|
||||
List.iter classify_constr self.constrs;
|
||||
self.constrs <- !others; (* remove all constraints involving [t] *)
|
||||
|
||||
Log.debugf 50
|
||||
(fun k->
|
||||
let pps = Fmt.Dump.(list @@ pair Z.pp pp_constr) in
|
||||
k "(@[intsolver.classify.for %a@ E=%a@ L=%a@ G=%a@ M=%a@])"
|
||||
A.pp_term t pps !set_e pps !set_l pps !set_g pps !set_m);
|
||||
assert false
|
||||
A.pp_term x pps !set_e pps !set_l pps !set_g pps !set_m);
|
||||
|
||||
(* now apply the algorithm *)
|
||||
if !set_e <> [] then (
|
||||
(* case (a): eliminate via an equality. *)
|
||||
|
||||
(* pick an equality with a small coeff, if possible *)
|
||||
let coeff1, c1 =
|
||||
Iter.of_list !set_e
|
||||
|> Iter.min_exn ~lt:(fun (n1,_)(n2,_) -> Z.(abs n1 < abs n2))
|
||||
in
|
||||
|
||||
let le1 = LE.(neg @@ remove c1.le x) in
|
||||
|
||||
Log.debugf 30
|
||||
(fun k->k "(@[intsolver.case_a.eqn@ :coeff %a@ :c %a@])"
|
||||
Z.pp coeff1 pp_constr c1);
|
||||
|
||||
let elim_in_constr (coeff2, c2) =
|
||||
let le2 = LE.(neg @@ remove c2.le x) in
|
||||
|
||||
let gcd12 = Z.gcd coeff1 coeff2 in
|
||||
(* coeff1 × p1 = coeff2 × p2 = lcm = coeff1 × coeff2 / gcd,
|
||||
because coeff1 × coeff2 = lcm × gcd *)
|
||||
let lcm12 = Z.(abs coeff1 * abs coeff2 / gcd12) in
|
||||
let p1 = Z.(lcm12 / coeff1) in
|
||||
let p2 = Z.(lcm12 / coeff2) in
|
||||
Log.debugf 50
|
||||
(fun k->k "(@[intsolver.elim-in-constr@ %a@ :gcd %a :lcm %a@ :p1 %a :p2 %a@])"
|
||||
pp_constr c2 Z.pp gcd12 Z.pp lcm12 Z.pp p1 Z.pp p2);
|
||||
|
||||
let c' =
|
||||
let lits = Bag.append c1.lits c2.lits in
|
||||
if Z.sign coeff1 <> Z.sign coeff2 then (
|
||||
let le' = LE.(p1 * le1 + p2 * le2) in
|
||||
let const' = Z.(p1 * c1.const + p2 * c2.const) in
|
||||
{op=c2.op; le=le'; const=const'; lits}
|
||||
) else (
|
||||
let le' = LE.(p1 * le1 - p2 * le2) in
|
||||
let const' = Z.(p1 * c1.const - p2 * c2.const) in
|
||||
let le', const' =
|
||||
if Z.sign coeff1 < 0 then LE.neg le', Z.neg const'
|
||||
else le', const'
|
||||
in
|
||||
{op=c2.op; le=le'; const=const'; lits}
|
||||
)
|
||||
in
|
||||
add_constr self c'
|
||||
|
||||
(* also add a divisibility constraint if needed *)
|
||||
(* TODO:
|
||||
if Z.(p1 > one) then (
|
||||
let c' = {le=le2; op=Eq_mod p1; const=c2.const} in
|
||||
add_constr self c'
|
||||
)
|
||||
*)
|
||||
in
|
||||
List.iter elim_in_constr !set_l;
|
||||
List.iter elim_in_constr !set_g;
|
||||
List.iter elim_in_constr !set_m;
|
||||
|
||||
(* FIXME: handle the congruence *)
|
||||
) else if !set_l = [] || !set_g = [] then (
|
||||
(* case (b): no bound on at least one side *)
|
||||
assert (!set_e=[]);
|
||||
|
||||
() (* FIXME: handle the congruence *)
|
||||
) else (
|
||||
(* case (c): combine inequalities pairwise *)
|
||||
|
||||
let elim_pair (coeff1, c1) (coeff2, c2) : unit =
|
||||
assert (Z.sign coeff1 > 0 && Z.sign coeff2 < 0);
|
||||
|
||||
let le1 = LE.remove c1.le x in
|
||||
let le2 = LE.remove c2.le x in
|
||||
|
||||
let gcd12 = Z.gcd coeff1 coeff2 in
|
||||
let lcm12 = Z.(coeff1 * abs coeff2 / gcd12) in
|
||||
|
||||
let p1 = Z.(lcm12 / coeff1) in
|
||||
let p2 = Z.(lcm12 / Z.abs coeff2) in
|
||||
|
||||
Log.debugf 50
|
||||
(fun k->k "(@[intsolver.case-b.elim-pair@ L=%a@ G=%a@ \
|
||||
:gcd %a :lcm %a@ :p1 %a :p2 %a@])"
|
||||
pp_constr c1 pp_constr c2 Z.pp gcd12 Z.pp lcm12 Z.pp p1 Z.pp p2);
|
||||
|
||||
let new_ineq =
|
||||
let le = LE.(p2 * le1 - p1 * le2) in
|
||||
let const = Z.(p2 * c1.const - p1 * c2.const) in
|
||||
let lits = Bag.append c1.lits c2.lits in
|
||||
{op=Leq; le; const; lits}
|
||||
in
|
||||
|
||||
add_constr self new_ineq;
|
||||
(* TODO: handle modulo constraints *)
|
||||
|
||||
in
|
||||
|
||||
List.iter (fun x1 -> List.iter (elim_pair x1) !set_g) !set_l;
|
||||
);
|
||||
|
||||
(* now recurse *)
|
||||
solve_rec self
|
||||
end
|
||||
|
||||
let check (self:t) : result =
|
||||
Log.debugf 10 (fun k->k "(@[intsolver.check@])");
|
||||
|
||||
Log.debugf 10 (fun k->k "(@[@{<Yellow>intsolver.check@}@])");
|
||||
let state = Check_.create self in
|
||||
Log.debugf 10
|
||||
(fun k->k "(@[intsolver.check.stat@ :n-vars %d@ :n-constr %d@])"
|
||||
(T_map.cardinal state.vars) (List.length state.constrs));
|
||||
|
||||
match state.ok with
|
||||
| Ok () ->
|
||||
Check_.solve_rec state
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ let rand_n low n : Z.t QC.arbitrary =
|
|||
QC.map ~rev:ZarithZ.to_int Z.of_int QC.(low -- n)
|
||||
|
||||
(* TODO: fudge *)
|
||||
let rand_z = rand_n (-50) 100
|
||||
let rand_z = rand_n (-15) 15
|
||||
|
||||
module Step = struct
|
||||
module G = QC.Gen
|
||||
|
|
@ -113,7 +113,7 @@ module Step = struct
|
|||
| _ ->
|
||||
let gen =
|
||||
let+ le = gen_linexp
|
||||
and+ kind = oneofl [`Leq;`Lt;`Eq]
|
||||
and+ kind = frequencyl [5, `Leq; 5, `Lt; 3,`Eq]
|
||||
and+ n = rand_z.QC.gen in
|
||||
vars, (match kind with
|
||||
| `Lt -> S_lt(le,n)
|
||||
|
|
@ -159,7 +159,7 @@ module Step = struct
|
|||
let print = Fmt.to_string (Fmt.Dump.list pp_) in
|
||||
QC.make ~shrink ~print (gen_for n1 n2)
|
||||
|
||||
let rand : t list QC.arbitrary = rand_for 1 100
|
||||
let rand : t list QC.arbitrary = rand_for 1 15
|
||||
end
|
||||
|
||||
let on_propagate _ ~reason:_ = ()
|
||||
|
|
@ -272,7 +272,7 @@ let set_stats_maybe ar =
|
|||
|
||||
let check_sound =
|
||||
let ar =
|
||||
Step.(rand_for 0 300)
|
||||
Step.(rand_for 0 15)
|
||||
|> QC.set_collect (fun pb -> if check_pb_is_sat pb then "sat" else "unsat")
|
||||
|> set_stats_maybe
|
||||
in
|
||||
|
|
@ -307,7 +307,7 @@ let prop_backtrack pb =
|
|||
|
||||
let check_backtrack =
|
||||
let ar =
|
||||
Step.(rand_for 0 300)
|
||||
Step.(rand_for 0 15)
|
||||
|> QC.set_collect (fun pb -> if check_pb_is_sat pb then "sat" else "unsat")
|
||||
|> set_stats_maybe
|
||||
in
|
||||
|
|
@ -315,25 +315,9 @@ let check_backtrack =
|
|||
~long_factor:10 ~count:200 ~name:"solver2_backtrack"
|
||||
ar prop_backtrack
|
||||
|
||||
let check_scalable =
|
||||
let prop pb =
|
||||
let solver = Solver.create () in
|
||||
add_steps solver pb;
|
||||
ignore (Solver.check solver : Solver.result);
|
||||
true
|
||||
in
|
||||
let ar =
|
||||
Step.(rand_for 3_000 5_000)
|
||||
|> QC.set_collect (fun pb -> if check_pb_is_sat pb then "sat" else "unsat")
|
||||
|> set_stats_maybe
|
||||
in
|
||||
QC.Test.make ~long_factor:2 ~count:10 ~name:"solver2_scalable"
|
||||
ar prop
|
||||
|
||||
let props = [
|
||||
check_sound;
|
||||
check_backtrack;
|
||||
check_scalable;
|
||||
]
|
||||
|
||||
(* regression tests *)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue