diff --git a/src/intsolver/dune b/src/intsolver/dune new file mode 100644 index 00000000..04779850 --- /dev/null +++ b/src/intsolver/dune @@ -0,0 +1,6 @@ +(library + (name sidekick_intsolver) + (public_name sidekick.intsolver) + (synopsis "Simple integer solver") + (flags :standard -warn-error -a+8 -w -32 -open Sidekick_util) + (libraries containers sidekick.core sidekick.arith)) diff --git a/src/intsolver/sidekick_intsolver.ml b/src/intsolver/sidekick_intsolver.ml new file mode 100644 index 00000000..f5e4fd1d --- /dev/null +++ b/src/intsolver/sidekick_intsolver.ml @@ -0,0 +1,304 @@ + +module type ARG = sig + module Z : Sidekick_arith.INT + + type term + type lit + + val pp_term : term Fmt.printer + val pp_lit : lit Fmt.printer + + module T_map : CCMap.S with type key = term +end + +module type S = sig + module A : ARG + + module Op : sig + type t = + | Leq + | Lt + | Eq + val pp : t Fmt.printer + end + + type t + + val create : unit -> t + + val push_level : t -> unit + + val pop_levels : t -> int -> unit + + val assert_ : + t -> + (A.Z.t * A.term) list -> Op.t -> A.Z.t -> + lit:A.lit -> + unit + + val define : + t -> + A.term -> + (A.Z.t * A.term) list -> + unit + + module Cert : sig + type t + val pp : t Fmt.printer + + val lits : t -> A.lit Iter.t + end + + module Model : sig + type t + val pp : t Fmt.printer + + val eval : t -> A.term -> A.Z.t option + end + + type result = + | Sat of Model.t + | Unsat of Cert.t + + val pp_result : result Fmt.printer + + val check : t -> result + + (**/**) + val _check_invariants : t -> unit + (**/**) +end + + + +module Make(A : ARG) + : S with module A = A += struct + module BVec = Backtrack_stack + module A = A + open A + + module Op = struct + type t = + | Leq + | Lt + | Eq + + let pp out = function + | Leq -> Fmt.string out "<=" + | Lt -> Fmt.string out "<" + | Eq -> Fmt.string out "=" + end + + module Linexp = struct + type t = Z.t T_map.t + let is_empty = T_map.is_empty + let empty : t = T_map.empty + + let pp out (self:t) : unit = + let pp_pair out (t,z) = Fmt.fprintf out "%a ยท %a" Z.pp z A.pp_term t in + if is_empty self then Fmt.string out "0" + else Fmt.fprintf out "(@[+@ %a@])" + Fmt.(iter ~sep:(return "@ ") pp_pair) (T_map.to_iter self) + + let iter = T_map.iter + let return t : t = T_map.add t Z.one empty + let neg self : t = T_map.map Z.neg self + let mult n self = + if Z.(n = zero) then empty + else T_map.map (fun c -> Z.(c * n)) self + + let add (self:t) (c:Z.t) (t:term) : t = + let n = Z.(c + T_map.get_or ~default:Z.zero t self) in + if Z.(n = zero) + then T_map.remove t self + else T_map.add t n self + + let merge (self:t) (other:t) : t = + T_map.fold + (fun t c m -> add m c t) + other self + + let of_list l : t = + List.fold_left (fun self (c,t) -> add self c t) empty l + + (* map each term to a linexp *) + let flat_map f (self:t) : t = + T_map.fold + (fun t c m -> + let t_le = mult c (f t) in + merge m t_le + ) + empty self + end + + module Cert = struct + type t = unit + let pp = Fmt.unit + + let lits _ = Iter.empty (* TODO *) + end + + module Model = struct + type t = { + m: Z.t T_map.t; + } [@@unboxed] + + let pp out self = + let pp_pair out (t,z) = Fmt.fprintf out "(@[%a := %a@])" A.pp_term t Z.pp z in + Fmt.fprintf out "(@[model@ %a@])" + Fmt.(iter ~sep:(return "@ ") pp_pair) (T_map.to_iter self.m) + + let empty : t = {m=T_map.empty} + + let eval (self:t) t : Z.t option = T_map.get t self.m + end + + module Constr = struct + type t = { + le: Linexp.t; + const: Z.t; + op: Op.t; + lits: lit Bag.t; + } + + let pp out self = + Fmt.fprintf out "(@[%a@ %a %a@])" Linexp.pp self.le Op.pp self.op Z.pp self.const + end + + type t = { + defs: (term * Linexp.t) BVec.t; + cs: Constr.t BVec.t; + } + + let create() : t = + { defs=BVec.create(); + cs=BVec.create(); } + + let push_level self = + BVec.push_level self.defs; + BVec.push_level self.cs; + () + + let pop_levels self n = + BVec.pop_levels self.defs n ~f:(fun _ -> ()); + BVec.pop_levels self.cs n ~f:(fun _ -> ()); + () + + type result = + | Sat of Model.t + | Unsat of Cert.t + + let pp_result out = function + | Sat m -> Fmt.fprintf out "(@[SAT@ %a@])" Model.pp m + | Unsat cert -> Fmt.fprintf out "(@[UNSAT@ %a@])" Cert.pp cert + + let assert_ (self:t) l op c ~lit : unit = + let le = Linexp.of_list l in + let c = {Constr.le; const=c; op; lits=Bag.return lit} in + Log.debugf 10 (fun k->k "(@[sidekick.intsolver.assert@ %a@])" Constr.pp c); + BVec.push self.cs c + + (* TODO: check before hand that [t] occurs nowhere else *) + let define (self:t) t l : unit = + let le = Linexp.of_list l in + BVec.push self.defs (t,le) + + (* #### checking #### *) + + module Check_ = struct + module LE = Linexp + + type op = + | Leq + | Lt + | Eq + | Eq_mod of { + prime: Z.t; + pow: int; + } (* modulo prime^pow *) + + type constr = { + le: LE.t; + const: Z.t; + op: op; + lits: lit Bag.t; + } + + type state = { + mutable rw: LE.t T_map.t; (* rewrite rules *) + mutable vars: int T_map.t; (* variables in at least one constraint *) + mutable constrs: constr list; + } + (* main solving state. mutable, but copied for backtracking. + invariant: variables in [rw] do not occur anywhere else + *) + + (* perform rewriting on the linear expression *) + let norm_le (self:state) (le:LE.t) : LE.t = + LE.flat_map + (fun t -> try T_map.find t self.rw with Not_found -> LE.return t) + le + + let[@inline] count_v self t : int = T_map.get_or ~default:0 t self.vars + let[@inline] incr_v (self:state) (t:term) : unit = + self.vars <- T_map.add t (1 + count_v self t) self.vars + let decr_v (self:state) (t:term) : unit = + let n = count_v self t - 1 in + assert (n >= 0); + self.vars <- + (if n=0 then T_map.remove t self.vars + else T_map.add t n self.vars) + + let add_constr (self:state) (c:constr) = + let c = {c with le=norm_le self c.le } in + LE.iter (fun t _ -> incr_v self t) c.le; + self.constrs <- c :: self.constrs + + let remove_constr (self:state) (c:constr) = + LE.iter (fun t _ -> decr_v self t) c.le + + let create (self:t) : state = + let state = { + vars=T_map.empty; + rw=T_map.empty; + constrs=[]; + } in + BVec.iter self.defs + ~f:(fun (v,le) -> + assert (not (T_map.mem v state.rw)); + state.rw <- T_map.add v (norm_le state le) state.rw); + BVec.iter self.cs + ~f:(fun (c:Constr.t) -> + let {Constr.le; op; const; lits} = c in + let op = match op with + | Op.Eq -> Eq + | Op.Leq -> Leq + | Op.Lt -> Lt + in + let c = {le;const;lits;op} in + add_constr state c + ); + state + + let rec solve_rec (self:state) : result = + begin match T_map.choose_opt self.vars with + | None -> + let m = Model.empty in + Sat m (* TODO: model *) + + | Some (t, _) -> + Log.debugf 30 (fun k->k "(@[intsolver.elim-var@ %a@])" A.pp_term t); + assert false + + end + + end + + let check (self:t) : result = + Log.debugf 10 (fun k->k "(@[intsolver.check@])"); + let state = Check_.create self in + Check_.solve_rec state + + let _check_invariants _ = () +end diff --git a/src/intsolver/tests/dune b/src/intsolver/tests/dune new file mode 100644 index 00000000..3f2210d0 --- /dev/null +++ b/src/intsolver/tests/dune @@ -0,0 +1,15 @@ + +(library + (name sidekick_test_intsolver) + (libraries zarith sidekick.intsolver sidekick.util sidekick.zarith + qcheck alcotest)) + +;(rule +; (targets sidekick_test_intsolver.ml) +; (enabled_if (>= %{ocaml_version} 4.08.0)) +; (action (copy test_intsolver.real.ml %{targets}))) +; +;(rule +; (targets sidekick_test_intsolver.ml) +; (enabled_if (< %{ocaml_version} 4.08.0)) +; (action (with-stdout-to %{targets} (echo "let props=[];; let tests=\"intsolver\",[]")))) diff --git a/src/intsolver/tests/sidekick_test_intsolver.ml b/src/intsolver/tests/sidekick_test_intsolver.ml new file mode 100644 index 00000000..ed025f36 --- /dev/null +++ b/src/intsolver/tests/sidekick_test_intsolver.ml @@ -0,0 +1,359 @@ + +open CCMonomorphic + +module Fmt = CCFormat +module QC = QCheck +module Log = Sidekick_util.Log +let spf = Printf.sprintf + +module ZarithZ = Z +module Z = Sidekick_zarith.Int + +module Var = struct + include CCInt + + let pp out x = Format.fprintf out "X_%d" x + + let rand n : t QC.arbitrary = QC.make ~print:(Fmt.to_string pp) @@ QC.Gen.(0--n) + type lit = int + let pp_lit = Fmt.int + let not_lit i = Some (- i) +end + +module Var_map = CCMap.Make(Var) + +module Solver = Sidekick_intsolver.Make(struct + module Z = Z + type term = Var.t + let pp_term = Var.pp + type lit = Var.lit + let pp_lit = Var.pp_lit + module T_map = Var_map + end) + +let unwrap_opt_ msg = function + | Some x -> x + | None -> failwith msg + +let rand_n low n : Z.t QC.arbitrary = + QC.map ~rev:ZarithZ.to_int Z.of_int QC.(low -- n) + +let rand_z = rand_n (-1000) 30_000 + +module Step = struct + module G = QC.Gen + + type linexp = (Z.t * Var.t) list + + type t = + | S_new_var of Var.t + | S_define of Var.t * (Z.t * Var.t) list + | S_leq of linexp * Z.t + | S_lt of linexp * Z.t + | S_eq of linexp * Z.t + + let pp_le out (le:linexp) = + let pp_pair out (n,x) = + if Z.equal Z.one n then Var.pp out x + else Fmt.fprintf out "%a . %a" Z.pp n Var.pp x in + Fmt.fprintf out "(@[%a@])" + Fmt.(list ~sep:(return " +@ ") pp_pair) le + + let pp_ out = function + | S_new_var v -> Fmt.fprintf out "(@[new-var %a@])" Var.pp v + | S_define (v,le) -> Fmt.fprintf out "(@[define %a@ := %a@])" Var.pp v pp_le le + | S_leq (le,n) -> Fmt.fprintf out "(@[upper %a <= %a@])" pp_le le Z.pp n + | S_lt (le,n) -> Fmt.fprintf out "(@[upper %a < %a@])" pp_le le Z.pp n + | S_eq (le,n) -> Fmt.fprintf out "(@[lower %a > %a@])" pp_le le Z.pp n + + (* check that a sequence is well formed *) + let well_formed (l:t list) : bool = + let rec aux vars = function + | [] -> true + | S_new_var v :: tl -> + not (List.mem v vars) && aux (v::vars) tl + | (S_leq (le,_) | S_lt (le,_) | S_eq (le,_)) :: tl -> + List.for_all (fun (_,x) -> List.mem x vars) le && aux vars tl + | S_define (x,le) :: tl-> + not (List.mem x vars) && + List.for_all (fun (_,y) -> List.mem y vars) le && + aux (x::vars) tl + in + aux [] l + + let shrink_step self = + let module S = QC.Shrink in + match self with + | S_new_var _ + | S_leq _ | S_lt _ | S_eq _ -> QC.Iter.empty + | S_define (x, le) -> + let open QC.Iter in + let* le = S.list le in + if List.length le >= 2 then return (S_define (x,le)) else empty + + let rand_steps (n:int) : t list QC.Gen.t = + let open G in + let rec aux n vars acc = + if n<=0 then return (List.rev acc) + else ( + let gen_linexp = + let* vars' = G.shuffle_l vars in + let* n = 1 -- List.length vars' in + let vars' = CCList.take n vars' in + assert (List.length vars' = n); + let* coeffs = list_repeat n rand_z.gen in + return (List.combine coeffs vars') + in + let* vars, proof_rule = + frequency @@ List.flatten [ + (* add a constraint *) + (match vars with + | [] -> [] + | _ -> + let gen = + let+ le = gen_linexp + and+ kind = oneofl [`Leq;`Lt;`Eq] + and+ n = rand_z.QC.gen in + vars, (match kind with + | `Lt -> S_lt(le,n) + | `Leq -> S_leq(le,n) + | `Eq -> S_eq(le,n) + ) + in + [6, gen]); + (* make a new non-basic var *) + (let gen = + let v = List.length vars in + return ((v::vars), S_new_var v) + in + [2, gen]); + (* make a definition *) + (if List.length vars>2 + then ( + let v = List.length vars in + let gen = + let+ le = gen_linexp in + v::vars, S_define (v, le) + in + [5, gen] + ) else []); + ] + in + aux (n-1) vars (proof_rule::acc) + ) + in + aux n [] [] + + (* shrink a list but keep it well formed *) + let shrink : t list QC.Shrink.t = + QC.Shrink.(filter well_formed @@ list ~shrink:shrink_step) + + let gen_for n1 n2 = + let open G in + assert (n1 < n2); + let* n = n1 -- n2 in + rand_steps n + + let rand_for n1 n2 : t list QC.arbitrary = + let print = Fmt.to_string (Fmt.Dump.list pp_) in + QC.make ~shrink ~print (gen_for n1 n2) + + let rand : t list QC.arbitrary = rand_for 1 100 +end + +let on_propagate _ ~reason:_ = () + +(* add a single proof_rule to the solvere *) +let add_step solver (s:Step.t) : unit = + begin match s with + | Step.S_new_var _v -> () + | Step.S_leq (le,n) -> + Solver.assert_ solver le Solver.Op.Leq n ~lit:0 + | Step.S_lt (le,n) -> + Solver.assert_ solver le Solver.Op.Lt n ~lit:0 + | Step.S_eq (le,n) -> + Solver.assert_ solver le Solver.Op.Eq n ~lit:0 + | Step.S_define (x,le) -> + Solver.define solver x le + end + +let add_steps ?(f=fun()->()) (solver:Solver.t) l : unit = + f(); + List.iter + (fun s -> add_step solver s; f()) + l + +(* is this solver's state sat? *) +let check_solver_is_sat solver : bool = + match Solver.check solver with + | Solver.Sat _ -> true + | Solver.Unsat _ -> false + +(* is this problem sat? *) +let check_pb_is_sat pb : bool = + let solver = Solver.create() in + add_steps solver pb; + check_solver_is_sat solver + +(* basic debug printer for Q.t *) +let str_z n = ZarithZ.to_string n + +let prop_sound ?(inv=false) pb = + let solver = Solver.create () in + begin match + add_steps solver pb; + Solver.check solver + with + | Sat model -> + + let get_val v = + match Solver.Model.eval model v with + | Some n -> n + | None -> assert false + in + + let eval_le le = + List.fold_left (fun s (n,y) -> Z.(s + n * get_val y)) Z.zero le + in + + let check_step s = + (try + if inv then Solver._check_invariants solver; + match s with + | Step.S_new_var _ -> () + | Step.S_define (x, le) -> + let v_x = get_val x in + let v_le = eval_le le in + if Z.(v_x <> v_le) then ( + failwith (spf "bad def (X_%d): val(x)=%s, val(expr)=%s" x (str_z v_x)(str_z v_le)) + ); + | Step.S_lt (x, n) -> + let v_x = eval_le x in + if Z.(v_x >= n) then failwith (spf "val=%s, n=%s"(str_z v_x)(str_z n)) + | Step.S_leq (x, n) -> + let v_x = eval_le x in + if Z.(v_x > n) then failwith (spf "val=%s, n=%s"(str_z v_x)(str_z n)) + | Step.S_eq (x, n) -> + let v_x = eval_le x in + if Z.(v_x <> n) then failwith (spf "val=%s, n=%s"(str_z v_x)(str_z n)) + with e -> + QC.Test.fail_reportf "proof_rule failed: %a@.exn:@.%s@." + Step.pp_ s (Printexc.to_string e) + ); + if inv then Solver._check_invariants solver; + true + in + List.for_all check_step pb + + | Solver.Unsat _cert -> + (* FIXME: + Solver._check_cert cert; + *) + true + end + +(* a bunch of useful stats for a problem *) +let steps_stats = [ + "n-define", Step.(List.fold_left (fun n -> function S_define _ -> n+1 | _->n) 0); + "n-bnd", + Step.(List.fold_left + (fun n -> function (S_leq _ | S_lt _ | S_eq _) -> n+1 | _->n) 0); + "n-vars", + Step.(List.fold_left + (fun n -> function S_define _ | S_new_var _ -> n+1 | _ -> n) 0); +] + +let enable_stats = + match Sys.getenv_opt "TEST_STAT" with Some("1"|"true") -> true | _ -> false + +let set_stats_maybe ar = + if enable_stats then QC.set_stats steps_stats ar else ar + +let check_sound = + let ar = + Step.(rand_for 0 300) + |> QC.set_collect (fun pb -> if check_pb_is_sat pb then "sat" else "unsat") + |> set_stats_maybe + in + QC.Test.make ~long_factor:10 ~count:500 ~name:"solver2_sound" ar prop_sound + +let prop_backtrack pb = + let solver = Solver.create () in + let stack = Stack.create() in + let res = ref true in + begin try + List.iter + (fun s -> + let is_sat = check_solver_is_sat solver in + Solver.push_level solver; + Stack.push is_sat stack; + if not is_sat then (res := false; raise Exit); + add_step solver s; + ) + pb; + with Exit -> () + end; + res := !res && check_solver_is_sat solver; + Log.debugf 50 (fun k->k "res=%b, expected=%b" !res (check_pb_is_sat pb)); + assert CCBool.(equal !res (check_pb_is_sat pb)); + (* now backtrack and check at each level *) + while not (Stack.is_empty stack) do + let res = Stack.pop stack in + Solver.pop_levels solver 1; + assert CCBool.(equal res (check_solver_is_sat solver)) + done; + true + +let check_backtrack = + let ar = + Step.(rand_for 0 300) + |> QC.set_collect (fun pb -> if check_pb_is_sat pb then "sat" else "unsat") + |> set_stats_maybe + in + QC.Test.make + ~long_factor:10 ~count:200 ~name:"solver2_backtrack" + ar prop_backtrack + +let check_scalable = + let prop pb = + let solver = Solver.create () in + add_steps solver pb; + ignore (Solver.check solver : Solver.result); + true + in + let ar = + Step.(rand_for 3_000 5_000) + |> QC.set_collect (fun pb -> if check_pb_is_sat pb then "sat" else "unsat") + |> set_stats_maybe + in + QC.Test.make ~long_factor:2 ~count:10 ~name:"solver2_scalable" + ar prop + +let props = [ + check_sound; + check_backtrack; + check_scalable; +] + +(* regression tests *) + +module Reg = struct + let alco_mk name f = name, `Quick, f + + let reg_prop_sound ?inv name l = + alco_mk name @@ fun () -> + if not (prop_sound ?inv l) then Alcotest.fail "fail"; + () + + let reg_prop_backtrack name l = + alco_mk name @@ fun () -> + if not (prop_backtrack l) then Alcotest.fail "fail"; + () + + open Step + let tests = [ + ] +end + +let tests = + "solver", List.flatten [ Reg.tests ]