From 10c8006597a3ebc27629ef46ea5d6aec1b99c045 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Mon, 31 Jan 2022 11:08:20 -0500 Subject: [PATCH] intsolver: partial implementation --- src/intsolver/sidekick_intsolver.ml | 166 ++++++++++++++++-- .../tests/sidekick_test_intsolver.ml | 26 +-- 2 files changed, 152 insertions(+), 40 deletions(-) diff --git a/src/intsolver/sidekick_intsolver.ml b/src/intsolver/sidekick_intsolver.ml index 69c10a49..853a1817 100644 --- a/src/intsolver/sidekick_intsolver.ml +++ b/src/intsolver/sidekick_intsolver.ml @@ -205,6 +205,8 @@ module Make(A : ARG) let return t : t = T_map.add t Z.one empty let neg self : t = T_map.map Z.neg self let mem self t : bool = T_map.mem t self + let remove self t = T_map.remove t self + let find_exn self t = T_map.find t self let mult n self = if Z.(n = zero) then empty else T_map.map (fun c -> Z.(c * n)) self @@ -230,7 +232,12 @@ module Make(A : ARG) let t_le = mult c (f t) in merge m t_le ) - empty self + self empty + + let (+) = merge + let ( * ) = mult + let ( ~- ) = neg + let (-) a b = a + ~- b end module Cert = struct @@ -300,7 +307,7 @@ module Make(A : ARG) let assert_ (self:t) l op c ~lit : unit = let le = Linexp.of_list l in let c = {Constr.le; const=c; op; lits=Bag.return lit} in - Log.debugf 10 (fun k->k "(@[sidekick.intsolver.assert@ %a@])" Constr.pp c); + Log.debugf 15 (fun k->k "(@[sidekick.intsolver.assert@ %a@])" Constr.pp c); BVec.push self.cs c (* TODO: check before hand that [t] occurs nowhere else *) @@ -349,9 +356,13 @@ module Make(A : ARG) let[@inline] is_ok_ self = CCResult.is_ok self.ok (* perform rewriting on the linear expression *) - let norm_le (self:state) (le:LE.t) : LE.t = + let rec norm_le (self:state) (le:LE.t) : LE.t = LE.flat_map - (fun t -> try T_map.find t self.rw with Not_found -> LE.return t) + (fun t -> + begin match T_map.find t self.rw with + | le -> norm_le self le + | exception Not_found -> LE.return t + end) le let[@inline] count_v self t : int = T_map.get_or ~default:0 t self.vars @@ -359,6 +370,12 @@ module Make(A : ARG) let[@inline] incr_v (self:state) (t:term) : unit = self.vars <- T_map.add t (1 + count_v self t) self.vars + (* GCD of the coefficients of this linear expression *) + let gcd_coeffs (le:LE.t) : Z.t = + match T_map.choose_opt le with + | None -> Z.one + | Some (_, z0) -> T_map.fold (fun _ z m -> Z.gcd z m) le z0 + let decr_v (self:state) (t:term) : unit = let n = count_v self t - 1 in assert (n >= 0); @@ -377,9 +394,7 @@ module Make(A : ARG) let const = match c.op with | Leq -> (* round down, regardless of sign *) - let cplus = Z.abs c.const in - let cplus = Z.(cplus / c_gcd) in - if Z.sign c.const >= 0 then cplus else Z.neg cplus + Z.ediv c.const c_gcd | Eq | Eq_mod _ -> if Z.equal (Z.rem c.const c_gcd) Z.zero then ( (* compatible constant *) @@ -408,6 +423,7 @@ module Make(A : ARG) let c = {c with le=norm_le self c.le } in match simplify_constr c with | Ok c -> + Log.debugf 50 (fun k->k "(@[intsolver.add-constr@ %a@])" pp_constr c); LE.iter (fun t _ -> incr_v self t) c.le; self.constrs <- c :: self.constrs | Error () -> @@ -427,7 +443,10 @@ module Make(A : ARG) BVec.iter self.defs ~f:(fun (v,le) -> assert (not (T_map.mem v state.rw)); - state.rw <- T_map.add v (norm_le state le) state.rw); + (* normalize as much as we can now *) + let le = norm_le state le in + Log.debugf 50 (fun k->k "(@[intsolver.add-rw %a@ := %a@])" pp_term v LE.pp le); + state.rw <- T_map.add v le state.rw); BVec.iter self.cs ~f:(fun (c:Constr.t) -> let {Constr.le; op; const; lits} = c in @@ -451,13 +470,13 @@ module Make(A : ARG) end - and elim_var_ self t : result = - Log.debugf 30 - (fun k->k "(@[intsolver.elim-var@ %a@ :remaining %d@])" - A.pp_term t (T_map.cardinal self.vars)); + and elim_var_ self (x:term) : result = + Log.debugf 20 + (fun k->k "(@[@{intsolver.elim-var@}@ %a@ :remaining %d@])" + A.pp_term x (T_map.cardinal self.vars)); - assert (not (T_map.mem t self.rw)); (* woudl have been rewritten away *) - self.vars <- T_map.remove t self.vars; + assert (not (T_map.mem x self.rw)); (* would have been rewritten away *) + self.vars <- T_map.remove x self.vars; (* gather the sets *) let set_e = ref [] in (* eqns *) @@ -467,32 +486,141 @@ module Make(A : ARG) let others = ref [] in let classify_constr (c:constr) = - match T_map.get t c.le, c.op with + match T_map.get x c.le, c.op with | None, _ -> others := c :: !others; | Some n_t, Leq -> if Z.sign n_t > 0 then set_l := (n_t,c) :: !set_l else - set_g := (Z.abs n_t,c) :: !set_g + set_g := (n_t,c) :: !set_g | Some n_t, Eq -> set_e := (n_t,c) :: !set_e | Some n_t, Eq_mod _ -> set_m := (n_t,c) :: !set_m in + List.iter classify_constr self.constrs; + self.constrs <- !others; (* remove all constraints involving [t] *) Log.debugf 50 (fun k-> let pps = Fmt.Dump.(list @@ pair Z.pp pp_constr) in k "(@[intsolver.classify.for %a@ E=%a@ L=%a@ G=%a@ M=%a@])" - A.pp_term t pps !set_e pps !set_l pps !set_g pps !set_m); - assert false + A.pp_term x pps !set_e pps !set_l pps !set_g pps !set_m); + + (* now apply the algorithm *) + if !set_e <> [] then ( + (* case (a): eliminate via an equality. *) + + (* pick an equality with a small coeff, if possible *) + let coeff1, c1 = + Iter.of_list !set_e + |> Iter.min_exn ~lt:(fun (n1,_)(n2,_) -> Z.(abs n1 < abs n2)) + in + + let le1 = LE.(neg @@ remove c1.le x) in + + Log.debugf 30 + (fun k->k "(@[intsolver.case_a.eqn@ :coeff %a@ :c %a@])" + Z.pp coeff1 pp_constr c1); + + let elim_in_constr (coeff2, c2) = + let le2 = LE.(neg @@ remove c2.le x) in + + let gcd12 = Z.gcd coeff1 coeff2 in + (* coeff1 × p1 = coeff2 × p2 = lcm = coeff1 × coeff2 / gcd, + because coeff1 × coeff2 = lcm × gcd *) + let lcm12 = Z.(abs coeff1 * abs coeff2 / gcd12) in + let p1 = Z.(lcm12 / coeff1) in + let p2 = Z.(lcm12 / coeff2) in + Log.debugf 50 + (fun k->k "(@[intsolver.elim-in-constr@ %a@ :gcd %a :lcm %a@ :p1 %a :p2 %a@])" + pp_constr c2 Z.pp gcd12 Z.pp lcm12 Z.pp p1 Z.pp p2); + + let c' = + let lits = Bag.append c1.lits c2.lits in + if Z.sign coeff1 <> Z.sign coeff2 then ( + let le' = LE.(p1 * le1 + p2 * le2) in + let const' = Z.(p1 * c1.const + p2 * c2.const) in + {op=c2.op; le=le'; const=const'; lits} + ) else ( + let le' = LE.(p1 * le1 - p2 * le2) in + let const' = Z.(p1 * c1.const - p2 * c2.const) in + let le', const' = + if Z.sign coeff1 < 0 then LE.neg le', Z.neg const' + else le', const' + in + {op=c2.op; le=le'; const=const'; lits} + ) + in + add_constr self c' + + (* also add a divisibility constraint if needed *) + (* TODO: + if Z.(p1 > one) then ( + let c' = {le=le2; op=Eq_mod p1; const=c2.const} in + add_constr self c' + ) + *) + in + List.iter elim_in_constr !set_l; + List.iter elim_in_constr !set_g; + List.iter elim_in_constr !set_m; + + (* FIXME: handle the congruence *) + ) else if !set_l = [] || !set_g = [] then ( + (* case (b): no bound on at least one side *) + assert (!set_e=[]); + + () (* FIXME: handle the congruence *) + ) else ( + (* case (c): combine inequalities pairwise *) + + let elim_pair (coeff1, c1) (coeff2, c2) : unit = + assert (Z.sign coeff1 > 0 && Z.sign coeff2 < 0); + + let le1 = LE.remove c1.le x in + let le2 = LE.remove c2.le x in + + let gcd12 = Z.gcd coeff1 coeff2 in + let lcm12 = Z.(coeff1 * abs coeff2 / gcd12) in + + let p1 = Z.(lcm12 / coeff1) in + let p2 = Z.(lcm12 / Z.abs coeff2) in + + Log.debugf 50 + (fun k->k "(@[intsolver.case-b.elim-pair@ L=%a@ G=%a@ \ + :gcd %a :lcm %a@ :p1 %a :p2 %a@])" + pp_constr c1 pp_constr c2 Z.pp gcd12 Z.pp lcm12 Z.pp p1 Z.pp p2); + + let new_ineq = + let le = LE.(p2 * le1 - p1 * le2) in + let const = Z.(p2 * c1.const - p1 * c2.const) in + let lits = Bag.append c1.lits c2.lits in + {op=Leq; le; const; lits} + in + + add_constr self new_ineq; + (* TODO: handle modulo constraints *) + + in + + List.iter (fun x1 -> List.iter (elim_pair x1) !set_g) !set_l; + ); + + (* now recurse *) + solve_rec self end let check (self:t) : result = - Log.debugf 10 (fun k->k "(@[intsolver.check@])"); + + Log.debugf 10 (fun k->k "(@[@{intsolver.check@}@])"); let state = Check_.create self in + Log.debugf 10 + (fun k->k "(@[intsolver.check.stat@ :n-vars %d@ :n-constr %d@])" + (T_map.cardinal state.vars) (List.length state.constrs)); + match state.ok with | Ok () -> Check_.solve_rec state diff --git a/src/intsolver/tests/sidekick_test_intsolver.ml b/src/intsolver/tests/sidekick_test_intsolver.ml index 892538df..98ec09ca 100644 --- a/src/intsolver/tests/sidekick_test_intsolver.ml +++ b/src/intsolver/tests/sidekick_test_intsolver.ml @@ -39,7 +39,7 @@ let rand_n low n : Z.t QC.arbitrary = QC.map ~rev:ZarithZ.to_int Z.of_int QC.(low -- n) (* TODO: fudge *) -let rand_z = rand_n (-50) 100 +let rand_z = rand_n (-15) 15 module Step = struct module G = QC.Gen @@ -113,7 +113,7 @@ module Step = struct | _ -> let gen = let+ le = gen_linexp - and+ kind = oneofl [`Leq;`Lt;`Eq] + and+ kind = frequencyl [5, `Leq; 5, `Lt; 3,`Eq] and+ n = rand_z.QC.gen in vars, (match kind with | `Lt -> S_lt(le,n) @@ -159,7 +159,7 @@ module Step = struct let print = Fmt.to_string (Fmt.Dump.list pp_) in QC.make ~shrink ~print (gen_for n1 n2) - let rand : t list QC.arbitrary = rand_for 1 100 + let rand : t list QC.arbitrary = rand_for 1 15 end let on_propagate _ ~reason:_ = () @@ -272,7 +272,7 @@ let set_stats_maybe ar = let check_sound = let ar = - Step.(rand_for 0 300) + Step.(rand_for 0 15) |> QC.set_collect (fun pb -> if check_pb_is_sat pb then "sat" else "unsat") |> set_stats_maybe in @@ -307,7 +307,7 @@ let prop_backtrack pb = let check_backtrack = let ar = - Step.(rand_for 0 300) + Step.(rand_for 0 15) |> QC.set_collect (fun pb -> if check_pb_is_sat pb then "sat" else "unsat") |> set_stats_maybe in @@ -315,25 +315,9 @@ let check_backtrack = ~long_factor:10 ~count:200 ~name:"solver2_backtrack" ar prop_backtrack -let check_scalable = - let prop pb = - let solver = Solver.create () in - add_steps solver pb; - ignore (Solver.check solver : Solver.result); - true - in - let ar = - Step.(rand_for 3_000 5_000) - |> QC.set_collect (fun pb -> if check_pb_is_sat pb then "sat" else "unsat") - |> set_stats_maybe - in - QC.Test.make ~long_factor:2 ~count:10 ~name:"solver2_scalable" - ar prop - let props = [ check_sound; check_backtrack; - check_scalable; ] (* regression tests *)