intsolver: partial implementation

This commit is contained in:
Simon Cruanes 2022-01-31 11:08:20 -05:00
parent be7451b070
commit 10c8006597
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
2 changed files with 152 additions and 40 deletions

View file

@ -205,6 +205,8 @@ module Make(A : ARG)
let return t : t = T_map.add t Z.one empty let return t : t = T_map.add t Z.one empty
let neg self : t = T_map.map Z.neg self let neg self : t = T_map.map Z.neg self
let mem self t : bool = T_map.mem t self let mem self t : bool = T_map.mem t self
let remove self t = T_map.remove t self
let find_exn self t = T_map.find t self
let mult n self = let mult n self =
if Z.(n = zero) then empty if Z.(n = zero) then empty
else T_map.map (fun c -> Z.(c * n)) self else T_map.map (fun c -> Z.(c * n)) self
@ -230,7 +232,12 @@ module Make(A : ARG)
let t_le = mult c (f t) in let t_le = mult c (f t) in
merge m t_le merge m t_le
) )
empty self self empty
let (+) = merge
let ( * ) = mult
let ( ~- ) = neg
let (-) a b = a + ~- b
end end
module Cert = struct module Cert = struct
@ -300,7 +307,7 @@ module Make(A : ARG)
let assert_ (self:t) l op c ~lit : unit = let assert_ (self:t) l op c ~lit : unit =
let le = Linexp.of_list l in let le = Linexp.of_list l in
let c = {Constr.le; const=c; op; lits=Bag.return lit} in let c = {Constr.le; const=c; op; lits=Bag.return lit} in
Log.debugf 10 (fun k->k "(@[sidekick.intsolver.assert@ %a@])" Constr.pp c); Log.debugf 15 (fun k->k "(@[sidekick.intsolver.assert@ %a@])" Constr.pp c);
BVec.push self.cs c BVec.push self.cs c
(* TODO: check before hand that [t] occurs nowhere else *) (* TODO: check before hand that [t] occurs nowhere else *)
@ -349,9 +356,13 @@ module Make(A : ARG)
let[@inline] is_ok_ self = CCResult.is_ok self.ok let[@inline] is_ok_ self = CCResult.is_ok self.ok
(* perform rewriting on the linear expression *) (* perform rewriting on the linear expression *)
let norm_le (self:state) (le:LE.t) : LE.t = let rec norm_le (self:state) (le:LE.t) : LE.t =
LE.flat_map LE.flat_map
(fun t -> try T_map.find t self.rw with Not_found -> LE.return t) (fun t ->
begin match T_map.find t self.rw with
| le -> norm_le self le
| exception Not_found -> LE.return t
end)
le le
let[@inline] count_v self t : int = T_map.get_or ~default:0 t self.vars let[@inline] count_v self t : int = T_map.get_or ~default:0 t self.vars
@ -359,6 +370,12 @@ module Make(A : ARG)
let[@inline] incr_v (self:state) (t:term) : unit = let[@inline] incr_v (self:state) (t:term) : unit =
self.vars <- T_map.add t (1 + count_v self t) self.vars self.vars <- T_map.add t (1 + count_v self t) self.vars
(* GCD of the coefficients of this linear expression *)
let gcd_coeffs (le:LE.t) : Z.t =
match T_map.choose_opt le with
| None -> Z.one
| Some (_, z0) -> T_map.fold (fun _ z m -> Z.gcd z m) le z0
let decr_v (self:state) (t:term) : unit = let decr_v (self:state) (t:term) : unit =
let n = count_v self t - 1 in let n = count_v self t - 1 in
assert (n >= 0); assert (n >= 0);
@ -377,9 +394,7 @@ module Make(A : ARG)
let const = match c.op with let const = match c.op with
| Leq -> | Leq ->
(* round down, regardless of sign *) (* round down, regardless of sign *)
let cplus = Z.abs c.const in Z.ediv c.const c_gcd
let cplus = Z.(cplus / c_gcd) in
if Z.sign c.const >= 0 then cplus else Z.neg cplus
| Eq | Eq_mod _ -> | Eq | Eq_mod _ ->
if Z.equal (Z.rem c.const c_gcd) Z.zero then ( if Z.equal (Z.rem c.const c_gcd) Z.zero then (
(* compatible constant *) (* compatible constant *)
@ -408,6 +423,7 @@ module Make(A : ARG)
let c = {c with le=norm_le self c.le } in let c = {c with le=norm_le self c.le } in
match simplify_constr c with match simplify_constr c with
| Ok c -> | Ok c ->
Log.debugf 50 (fun k->k "(@[intsolver.add-constr@ %a@])" pp_constr c);
LE.iter (fun t _ -> incr_v self t) c.le; LE.iter (fun t _ -> incr_v self t) c.le;
self.constrs <- c :: self.constrs self.constrs <- c :: self.constrs
| Error () -> | Error () ->
@ -427,7 +443,10 @@ module Make(A : ARG)
BVec.iter self.defs BVec.iter self.defs
~f:(fun (v,le) -> ~f:(fun (v,le) ->
assert (not (T_map.mem v state.rw)); assert (not (T_map.mem v state.rw));
state.rw <- T_map.add v (norm_le state le) state.rw); (* normalize as much as we can now *)
let le = norm_le state le in
Log.debugf 50 (fun k->k "(@[intsolver.add-rw %a@ := %a@])" pp_term v LE.pp le);
state.rw <- T_map.add v le state.rw);
BVec.iter self.cs BVec.iter self.cs
~f:(fun (c:Constr.t) -> ~f:(fun (c:Constr.t) ->
let {Constr.le; op; const; lits} = c in let {Constr.le; op; const; lits} = c in
@ -451,13 +470,13 @@ module Make(A : ARG)
end end
and elim_var_ self t : result = and elim_var_ self (x:term) : result =
Log.debugf 30 Log.debugf 20
(fun k->k "(@[intsolver.elim-var@ %a@ :remaining %d@])" (fun k->k "(@[@{<Yellow>intsolver.elim-var@}@ %a@ :remaining %d@])"
A.pp_term t (T_map.cardinal self.vars)); A.pp_term x (T_map.cardinal self.vars));
assert (not (T_map.mem t self.rw)); (* woudl have been rewritten away *) assert (not (T_map.mem x self.rw)); (* would have been rewritten away *)
self.vars <- T_map.remove t self.vars; self.vars <- T_map.remove x self.vars;
(* gather the sets *) (* gather the sets *)
let set_e = ref [] in (* eqns *) let set_e = ref [] in (* eqns *)
@ -467,32 +486,141 @@ module Make(A : ARG)
let others = ref [] in let others = ref [] in
let classify_constr (c:constr) = let classify_constr (c:constr) =
match T_map.get t c.le, c.op with match T_map.get x c.le, c.op with
| None, _ -> | None, _ ->
others := c :: !others; others := c :: !others;
| Some n_t, Leq -> | Some n_t, Leq ->
if Z.sign n_t > 0 then if Z.sign n_t > 0 then
set_l := (n_t,c) :: !set_l set_l := (n_t,c) :: !set_l
else else
set_g := (Z.abs n_t,c) :: !set_g set_g := (n_t,c) :: !set_g
| Some n_t, Eq -> | Some n_t, Eq ->
set_e := (n_t,c) :: !set_e set_e := (n_t,c) :: !set_e
| Some n_t, Eq_mod _ -> | Some n_t, Eq_mod _ ->
set_m := (n_t,c) :: !set_m set_m := (n_t,c) :: !set_m
in in
List.iter classify_constr self.constrs; List.iter classify_constr self.constrs;
self.constrs <- !others; (* remove all constraints involving [t] *)
Log.debugf 50 Log.debugf 50
(fun k-> (fun k->
let pps = Fmt.Dump.(list @@ pair Z.pp pp_constr) in let pps = Fmt.Dump.(list @@ pair Z.pp pp_constr) in
k "(@[intsolver.classify.for %a@ E=%a@ L=%a@ G=%a@ M=%a@])" k "(@[intsolver.classify.for %a@ E=%a@ L=%a@ G=%a@ M=%a@])"
A.pp_term t pps !set_e pps !set_l pps !set_g pps !set_m); A.pp_term x pps !set_e pps !set_l pps !set_g pps !set_m);
assert false
(* now apply the algorithm *)
if !set_e <> [] then (
(* case (a): eliminate via an equality. *)
(* pick an equality with a small coeff, if possible *)
let coeff1, c1 =
Iter.of_list !set_e
|> Iter.min_exn ~lt:(fun (n1,_)(n2,_) -> Z.(abs n1 < abs n2))
in
let le1 = LE.(neg @@ remove c1.le x) in
Log.debugf 30
(fun k->k "(@[intsolver.case_a.eqn@ :coeff %a@ :c %a@])"
Z.pp coeff1 pp_constr c1);
let elim_in_constr (coeff2, c2) =
let le2 = LE.(neg @@ remove c2.le x) in
let gcd12 = Z.gcd coeff1 coeff2 in
(* coeff1 × p1 = coeff2 × p2 = lcm = coeff1 × coeff2 / gcd,
because coeff1 × coeff2 = lcm × gcd *)
let lcm12 = Z.(abs coeff1 * abs coeff2 / gcd12) in
let p1 = Z.(lcm12 / coeff1) in
let p2 = Z.(lcm12 / coeff2) in
Log.debugf 50
(fun k->k "(@[intsolver.elim-in-constr@ %a@ :gcd %a :lcm %a@ :p1 %a :p2 %a@])"
pp_constr c2 Z.pp gcd12 Z.pp lcm12 Z.pp p1 Z.pp p2);
let c' =
let lits = Bag.append c1.lits c2.lits in
if Z.sign coeff1 <> Z.sign coeff2 then (
let le' = LE.(p1 * le1 + p2 * le2) in
let const' = Z.(p1 * c1.const + p2 * c2.const) in
{op=c2.op; le=le'; const=const'; lits}
) else (
let le' = LE.(p1 * le1 - p2 * le2) in
let const' = Z.(p1 * c1.const - p2 * c2.const) in
let le', const' =
if Z.sign coeff1 < 0 then LE.neg le', Z.neg const'
else le', const'
in
{op=c2.op; le=le'; const=const'; lits}
)
in
add_constr self c'
(* also add a divisibility constraint if needed *)
(* TODO:
if Z.(p1 > one) then (
let c' = {le=le2; op=Eq_mod p1; const=c2.const} in
add_constr self c'
)
*)
in
List.iter elim_in_constr !set_l;
List.iter elim_in_constr !set_g;
List.iter elim_in_constr !set_m;
(* FIXME: handle the congruence *)
) else if !set_l = [] || !set_g = [] then (
(* case (b): no bound on at least one side *)
assert (!set_e=[]);
() (* FIXME: handle the congruence *)
) else (
(* case (c): combine inequalities pairwise *)
let elim_pair (coeff1, c1) (coeff2, c2) : unit =
assert (Z.sign coeff1 > 0 && Z.sign coeff2 < 0);
let le1 = LE.remove c1.le x in
let le2 = LE.remove c2.le x in
let gcd12 = Z.gcd coeff1 coeff2 in
let lcm12 = Z.(coeff1 * abs coeff2 / gcd12) in
let p1 = Z.(lcm12 / coeff1) in
let p2 = Z.(lcm12 / Z.abs coeff2) in
Log.debugf 50
(fun k->k "(@[intsolver.case-b.elim-pair@ L=%a@ G=%a@ \
:gcd %a :lcm %a@ :p1 %a :p2 %a@])"
pp_constr c1 pp_constr c2 Z.pp gcd12 Z.pp lcm12 Z.pp p1 Z.pp p2);
let new_ineq =
let le = LE.(p2 * le1 - p1 * le2) in
let const = Z.(p2 * c1.const - p1 * c2.const) in
let lits = Bag.append c1.lits c2.lits in
{op=Leq; le; const; lits}
in
add_constr self new_ineq;
(* TODO: handle modulo constraints *)
in
List.iter (fun x1 -> List.iter (elim_pair x1) !set_g) !set_l;
);
(* now recurse *)
solve_rec self
end end
let check (self:t) : result = let check (self:t) : result =
Log.debugf 10 (fun k->k "(@[intsolver.check@])");
Log.debugf 10 (fun k->k "(@[@{<Yellow>intsolver.check@}@])");
let state = Check_.create self in let state = Check_.create self in
Log.debugf 10
(fun k->k "(@[intsolver.check.stat@ :n-vars %d@ :n-constr %d@])"
(T_map.cardinal state.vars) (List.length state.constrs));
match state.ok with match state.ok with
| Ok () -> | Ok () ->
Check_.solve_rec state Check_.solve_rec state

View file

@ -39,7 +39,7 @@ let rand_n low n : Z.t QC.arbitrary =
QC.map ~rev:ZarithZ.to_int Z.of_int QC.(low -- n) QC.map ~rev:ZarithZ.to_int Z.of_int QC.(low -- n)
(* TODO: fudge *) (* TODO: fudge *)
let rand_z = rand_n (-50) 100 let rand_z = rand_n (-15) 15
module Step = struct module Step = struct
module G = QC.Gen module G = QC.Gen
@ -113,7 +113,7 @@ module Step = struct
| _ -> | _ ->
let gen = let gen =
let+ le = gen_linexp let+ le = gen_linexp
and+ kind = oneofl [`Leq;`Lt;`Eq] and+ kind = frequencyl [5, `Leq; 5, `Lt; 3,`Eq]
and+ n = rand_z.QC.gen in and+ n = rand_z.QC.gen in
vars, (match kind with vars, (match kind with
| `Lt -> S_lt(le,n) | `Lt -> S_lt(le,n)
@ -159,7 +159,7 @@ module Step = struct
let print = Fmt.to_string (Fmt.Dump.list pp_) in let print = Fmt.to_string (Fmt.Dump.list pp_) in
QC.make ~shrink ~print (gen_for n1 n2) QC.make ~shrink ~print (gen_for n1 n2)
let rand : t list QC.arbitrary = rand_for 1 100 let rand : t list QC.arbitrary = rand_for 1 15
end end
let on_propagate _ ~reason:_ = () let on_propagate _ ~reason:_ = ()
@ -272,7 +272,7 @@ let set_stats_maybe ar =
let check_sound = let check_sound =
let ar = let ar =
Step.(rand_for 0 300) Step.(rand_for 0 15)
|> QC.set_collect (fun pb -> if check_pb_is_sat pb then "sat" else "unsat") |> QC.set_collect (fun pb -> if check_pb_is_sat pb then "sat" else "unsat")
|> set_stats_maybe |> set_stats_maybe
in in
@ -307,7 +307,7 @@ let prop_backtrack pb =
let check_backtrack = let check_backtrack =
let ar = let ar =
Step.(rand_for 0 300) Step.(rand_for 0 15)
|> QC.set_collect (fun pb -> if check_pb_is_sat pb then "sat" else "unsat") |> QC.set_collect (fun pb -> if check_pb_is_sat pb then "sat" else "unsat")
|> set_stats_maybe |> set_stats_maybe
in in
@ -315,25 +315,9 @@ let check_backtrack =
~long_factor:10 ~count:200 ~name:"solver2_backtrack" ~long_factor:10 ~count:200 ~name:"solver2_backtrack"
ar prop_backtrack ar prop_backtrack
let check_scalable =
let prop pb =
let solver = Solver.create () in
add_steps solver pb;
ignore (Solver.check solver : Solver.result);
true
in
let ar =
Step.(rand_for 3_000 5_000)
|> QC.set_collect (fun pb -> if check_pb_is_sat pb then "sat" else "unsat")
|> set_stats_maybe
in
QC.Test.make ~long_factor:2 ~count:10 ~name:"solver2_scalable"
ar prop
let props = [ let props = [
check_sound; check_sound;
check_backtrack; check_backtrack;
check_scalable;
] ]
(* regression tests *) (* regression tests *)