sidekick/examples/sudoku/sudoku_solve.ml

403 lines
11 KiB
OCaml
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

(** simple sudoku solver *)
open Sidekick_util
(* [errorf fmt args...] renders the message described by [fmt] and raises
   [Failure] carrying the resulting string. *)
let errorf fmt = Fmt.kasprintf (fun rendered -> failwith rendered) fmt
(* Contents of a single sudoku cell, stored as a small integer. *)
module Cell : sig
  type t = private int
  (** [0] stands for an empty cell; [1] to [9] are the digits of a filled
      cell. *)

  val equal : t -> t -> bool
  (** [equal a b] holds iff both cells carry the same integer. *)

  val neq : t -> t -> bool
  (** Negation of {!equal}. *)

  val hash : t -> int
  (** Hash of the underlying integer. *)

  val empty : t
  (** The empty cell. *)

  val is_empty : t -> bool
  (** [is_empty c] holds iff [c] is {!empty}. *)

  val is_full : t -> bool
  (** [is_full c] holds iff [c] is strictly positive, i.e. carries a digit. *)

  val make : int -> t
  (** [make i] turns [i] into a cell; an assertion checks that [i] lies in
      [0..9]. *)

  val pp : t Fmt.printer
  (** Prints ['.'] for the empty cell and the digit otherwise. *)
end = struct
  type t = int

  let empty = 0

  let[@inline] make n =
    assert (0 <= n && n <= 9);
    n

  let[@inline] is_empty (c : t) = c = 0
  let[@inline] is_full (c : t) = c > 0
  let hash (c : t) = CCHash.int c
  let[@inline] equal (a : t) (b : t) = a = b
  let[@inline] neq (a : t) (b : t) = not (equal a b)

  let pp out (c : t) =
    match c with
    | 0 -> Fmt.char out '.'
    | digit -> Fmt.int out digit
end
(* A 9x9 sudoku grid together with the views (rows, columns, squares) needed
   to check the sudoku rules. *)
module Grid : sig
  type t
  (** A 9x9 grid of cells, addressed by (row, column) indices in [0..8].
      Values of this type are never mutated once returned by this module. *)

  val get : t -> int -> int -> Cell.t
  (** [get g i j] is the cell at row [i], column [j]. *)

  val set : t -> int -> int -> Cell.t -> t
  (** [set g i j c] is a copy of [g] in which the cell at row [i], column [j]
      is [c]; [g] itself is left untouched. *)

  type set = (int * int * Cell.t) Iter.t
  (** A set of related cells *)

  val rows : t -> set Iter.t
  (** The 9 rows, each as the set of its [(row, column, cell)] triples. *)

  val cols : t -> set Iter.t
  (** The 9 columns, each as the set of its [(row, column, cell)] triples. *)

  val squares : t -> set Iter.t
  (** The 9 3x3 squares, each as the set of its [(row, column, cell)]
      triples. *)

  val all_cells : t -> (int * int * Cell.t) Iter.t
  (** All 81 cells with their coordinates, in row-major order. *)

  val parse : string -> t
  (** [parse s] reads the first 81 characters of [s] in row-major order:
      ['.'] or ['0'] is an empty cell, ['1'] to ['9'] a filled one.
      Characters after the 81st are ignored.
      @raise Failure if [s] has fewer than 81 characters or if one of the
      first 81 characters is neither ['.'] nor a decimal digit. *)

  val is_full : t -> bool
  (** [is_full g] holds iff no cell of [g] is empty. *)

  val is_valid : t -> bool
  (** [is_valid g] holds iff no row, column or 3x3 square of [g] contains the
      same digit in two filled cells. Empty cells are ignored, so a partially
      filled grid can be valid. *)

  val matches : pat:t -> t -> bool
  (** [matches ~pat g] holds iff every filled cell of [pat] has the same
      value in [g]. *)

  val pp : t Fmt.printer
  (** Prints the grid one row per line, using {!Cell.pp} for each cell. *)
end = struct
  (* Row-major array of 81 cells: cell (i, j) lives at index [i * 9 + j]. *)
  type t = Cell.t array

  let[@inline] get (s : t) i j = s.((i * 9) + j)

  (* Functional update: the input array is copied, never modified. *)
  let[@inline] set (s : t) i j n =
    let s' = Array.copy s in
    s'.((i * 9) + j) <- n;
    s'

  type set = (int * int * Cell.t) Iter.t
  (** A set of related cells *)

  open Iter.Infix

  let all_cells (g : t) =
    0 -- 8 >>= fun i ->
    0 -- 8 >|= fun j -> i, j, get g i j

  let rows (g : t) =
    0 -- 8 >|= fun i ->
    0 -- 8 >|= fun j -> i, j, get g i j

  let cols g =
    0 -- 8 >|= fun j ->
    0 -- 8 >|= fun i -> i, j, get g i j

  (* [sq_i, sq_j] select one of the 3x3 squares, [off_i, off_j] a cell
     inside it. *)
  let squares g =
    0 -- 2 >>= fun sq_i ->
    0 -- 2 >|= fun sq_j ->
    0 -- 2 >>= fun off_i ->
    0 -- 2 >|= fun off_j ->
    let i = (3 * sq_i) + off_i in
    let j = (3 * sq_j) + off_j in
    i, j, get g i j

  let is_full g = Array.for_all Cell.is_full g

  let is_valid g =
    (* Only filled cells take part in the comparison: all empty cells share
       one representation and must not be reported as duplicates of each
       other. *)
    let all_distinct (s : set) =
      s
      |> Iter.filter (fun (_, _, c) -> Cell.is_full c)
      |> Iter.map (fun (_, _, c) -> c)
      |> Iter.diagonal
      |> Iter.for_all (fun (c1, c2) -> Cell.neq c1 c2)
    in
    (Iter.for_all all_distinct @@ rows g)
    && (Iter.for_all all_distinct @@ cols g)
    && (Iter.for_all all_distinct @@ squares g)

  let matches ~pat:g1 g2 : bool =
    all_cells g1
    |> Iter.filter (fun (_, _, c) -> Cell.is_full c)
    |> Iter.for_all (fun (x, y, c) -> Cell.equal c @@ get g2 x y)

  let pp out g =
    Fmt.fprintf out "@[<v>";
    Array.iteri
      (fun i n ->
        Cell.pp out n;
        (* break the line after the last cell of each row *)
        if i mod 9 = 8 then Fmt.fprintf out "@,")
      g;
    Fmt.fprintf out "@]"

  let parse (s : string) : t =
    if String.length s < 81 then
      errorf "line is too short, expected 81 chars, not %d" (String.length s);
    let a = Array.make 81 Cell.empty in
    for i = 0 to 80 do
      let c = String.get s i in
      let n =
        if c = '.' then
          0
        else
          Char.code c - Char.code '0'
      in
      (* reject anything that is not a decimal digit before building a cell *)
      if n < 0 || n > 9 then errorf "invalid char %c" c;
      a.(i) <- Cell.make n
    done;
    a
end
(* Short alias for the backtrackable reference used to hold the theory's
   current grid. *)
module B_ref = Sidekick_util.Backtrackable_ref
(* Sudoku solver built on the Sidekick SAT solver: each possible cell value is
   a boolean term, and the sudoku rules are enforced by a theory plugin. *)
module Solver : sig
type t
(* [create g] builds a solver whose initial grid is [g]. *)
val create : Grid.t -> t
(* [solve s] is [Some g'] with the grid held by the theory when the SAT solver
   answers Sat, and [None] when it answers Unsat. *)
val solve : t -> Grid.t option
end = struct
open Sidekick_core
(* Constant view standing for the proposition: the cell at ([x], [y]) holds
   [value]. *)
type Const.view += Cell_is of { x: int; y: int; value: Cell.t }
(* Printer, hash and equality for [Cell_is] constants. On any other view,
   [pp] prints nothing and [hash] fails an assertion; [equal] returns [false]
   unless both arguments are [Cell_is]. *)
let ops =
(module struct
let pp out = function
| Cell_is { x; y; value } ->
Fmt.fprintf out "(%d:%d=%a)" x y Cell.pp value
| _ -> ()
let hash = function
| Cell_is { x; y; value } ->
Hash.(combine3 (int x) (int y) (Cell.hash value))
| _ -> assert false
let equal a b =
match a, b with
| Cell_is a, Cell_is b ->
a.x = b.x && a.y = b.y && Cell.equal a.value b.value
| _ -> false
end : Const.DYN_OPS)
module Sat = Sidekick_sat
(* [mk_cell tst x y value] is the boolean-typed constant term for
   [Cell_is { x; y; value }], created in the store [tst]. *)
let mk_cell tst x y value : Term.t =
Term.const tst
@@ Const.make (Cell_is { x; y; value }) ops ~ty:(Term.bool tst)
(* Literal over the term built by [mk_cell]; [?sign] is forwarded unchanged to
   [Lit.atom]. *)
let mk_cell_lit ?sign tst x y value : Lit.t =
Lit.atom ?sign @@ mk_cell tst x y value
(* Theory plugin: maintains the current grid from the literals it is shown and
   reports violations of the sudoku rules to the SAT solver. *)
module Theory : sig
type t
(* Current content of the backtrackable grid. *)
val grid : t -> Grid.t
val create : Term.store -> Grid.t -> t
(* Packages the theory callbacks as a SAT plugin. *)
val to_plugin : t -> Sat.plugin
end = struct
(* [tst] is the store used to build cell terms; [grid] is the current grid,
   kept in a backtrackable reference. *)
type t = { tst: Term.store; grid: Grid.t B_ref.t }
let[@inline] grid self : Grid.t = B_ref.get self.grid
let[@inline] set_grid self g : unit = B_ref.set self.grid g
(* Backtracking levels are delegated to the backtrackable reference. *)
let push_level self = B_ref.push_level self.grid
let pop_levels self n = B_ref.pop_levels self.grid n
(* Printer for a clause, i.e. a list of literals. *)
let pp_c_ = Fmt.(list ~sep:(return "@ ")) Lit.pp
(* Debug-log (level 4) the conflict clause [c], tagged with [kind]. *)
let[@inline] logs_conflict kind c : unit =
Log.debugf 4 (fun k -> k "(@[conflict.%s@ %a@])" kind pp_c_ c)
(* check that all cells are full *)
(* For every empty cell of the current grid, a clause with one literal per
   value 1..9 for that cell is added through [A.add_clause ~keep:true]. *)
let check_full_ (self : t) (acts : Sat.acts) : unit =
(*let@ () = Profile.with_ "check-full" in*)
let (module A) = acts in
Grid.all_cells (grid self) (fun (x, y, c) ->
if Cell.is_empty c then (
let c =
CCList.init 9 (fun c ->
mk_cell_lit self.tst x y (Cell.make (c + 1)))
in
Log.debugf 4 (fun k -> k "(@[add-clause@ %a@])" pp_c_ c);
A.add_clause ~keep:true c Proof_trace.dummy_step_id
))
(* check constraints *)
(* For rows, then columns, then squares: every pair of filled cells of one set
   that carry the same value is reported through [A.raise_conflict], with a
   clause made of the two negated cell literals. *)
let check_ (self : t) (acts : Sat.acts) : unit =
(*let@ () = Profile.with_ "check-constraints" in*)
Log.debugf 4 (fun k ->
k "(@[sudoku.check@ @[:g %a@]@])" Grid.pp (B_ref.get self.grid));
let (module A) = acts in
(* [f] extracts the sets to check from the grid; [kind] only labels the log. *)
let[@inline] all_diff kind f =
let pairs =
f (grid self)
|> Iter.flat_map (fun set ->
set
|> Iter.filter (fun (_, _, c) -> Cell.is_full c)
|> Iter.diagonal)
in
pairs (fun ((x1, y1, c1), (x2, y2, c2)) ->
if Cell.equal c1 c2 then (
assert (x1 <> x2 || y1 <> y2);
let c =
[
mk_cell_lit self.tst ~sign:false x1 y1 c1;
mk_cell_lit self.tst ~sign:false x2 y2 c2;
]
in
logs_conflict ("all-diff." ^ kind) c;
A.raise_conflict c Proof_trace.dummy_step_id
))
in
all_diff "rows" Grid.rows;
all_diff "cols" Grid.cols;
all_diff "squares" Grid.squares;
()
(* Iterator over the literals given by [A.iter_assumptions].
   NOTE(review): named [trail_] but built from [iter_assumptions] -- confirm
   this is the intended source of literals. *)
let trail_ (acts : Sat.acts) =
let (module A) = acts in
A.iter_assumptions
(* update current grid with the given slice *)
(* Only literals with a true sign over a [Cell_is] constant are considered:
   an empty target cell is filled with the literal's value; a target cell
   already holding a different value is reported as a conflict. *)
let add_slice (self : t) (acts : Sat.acts) : unit =
let (module A) = acts in
trail_ acts (fun lit ->
match Lit.sign lit, Term.view (Lit.term lit) with
| true, E_const { Const.c_view = Cell_is { x; y; value = c }; _ } ->
assert (Cell.is_full c);
let grid = grid self in
let c' = Grid.get grid x y in
if Cell.is_empty c' then
set_grid self (Grid.set grid x y c)
else if Cell.neq c c' then (
(* conflict: at most one value *)
let c =
[
mk_cell_lit self.tst ~sign:false x y c;
mk_cell_lit self.tst ~sign:false x y c';
]
in
logs_conflict "at-most-one" c;
A.raise_conflict c Proof_trace.dummy_step_id
)
| _ -> ())
(* Partial check callback: folds the literals into the grid, then checks the
   all-different constraints. *)
let partial_check (self : t) acts : unit =
(* let@ () = Profile.with_ "partial-check" in*)
Log.debugf 4 (fun k ->
k "(@[sudoku.partial-check@ :trail [@[%a@]]@])" (Fmt.iter Lit.pp)
(trail_ acts));
add_slice self acts;
check_ self acts
(* Final check callback: adds clauses for the cells that are still empty,
   then checks the all-different constraints. *)
let final_check (self : t) acts : unit =
(*let@ () = Profile.with_ "final-check" in*)
Log.debugf 4 (fun k -> k "(@[sudoku.final-check@])");
check_full_ self acts;
check_ self acts
let create tst g : t = { tst; grid = B_ref.create g }
let to_plugin (self : t) : Sat.plugin =
Sat.mk_plugin_cdcl_t
~push_level:(fun () -> push_level self)
~pop_levels:(fun n -> pop_levels self n)
~partial_check:(partial_check self) ~final_check:(final_check self) ()
end
(* [grid0] is the grid passed to [create]; [theory] and [solver] share the
   term store [tst]. *)
type t = { grid0: Grid.t; tst: Term.store; theory: Theory.t; solver: Sat.t }
(* The filled cells of [grid0] are passed to the SAT solver as assumptions,
   one literal per cell, built with the default sign of [mk_cell_lit]
   (presumably positive -- confirm against [Lit.atom]). *)
let solve (self : t) : _ option =
let@ () = Profile.with_ "sudoku.solve" in
let assumptions =
Grid.all_cells self.grid0
|> Iter.filter (fun (_, _, c) -> Cell.is_full c)
|> Iter.map (fun (x, y, c) -> mk_cell_lit self.tst x y c)
|> Iter.to_rev_list
in
Log.debugf 2 (fun k ->
k "(@[sudoku.solve@ :assumptions %a@])" (Fmt.Dump.list Lit.pp)
assumptions);
let r =
match Sat.solve self.solver ~assumptions with
| Sat.Sat _ -> Some (Theory.grid self.theory)
| Sat.Unsat _ -> None
in
(* TODO: print some stats *)
r
(* Creates a fresh term store, a theory over [g], and a SAT solver using that
   theory as its plugin, with a dummy proof trace. *)
let create g : t =
let tst = Term.Store.create () in
let theory = Theory.create tst g in
let plugin : Sat.plugin = Theory.to_plugin theory in
{
tst;
solver = Sat.create ~proof:Proof_trace.dummy plugin;
theory;
grid0 = g;
}
end
(* [solve_grid g] builds a fresh solver for [g] and runs it, returning the
   solver's answer unchanged. *)
let solve_grid (g : Grid.t) : Grid.t option =
  g |> Solver.create |> Solver.solve
(* Interface of a timer module. [pp_elapsed] has the shape expected by the
   [%t] format directive, which is how it is used in this file. *)
module type CHRONO = sig
val pp_elapsed : Fmt.formatter -> unit
end
(* [chrono ~pp_time] starts a timer: it records [Sys.time ()] at the moment of
   the call and returns a module whose [pp_elapsed] prints the time spent
   since then as " (in <seconds>s)". When [pp_time] is [false], [pp_elapsed]
   prints nothing. *)
let chrono ~pp_time : (module CHRONO) =
  (module struct
    let t0 = Sys.time ()

    let pp_elapsed out =
      if pp_time then (
        let dt = Sys.time () -. t0 in
        Fmt.fprintf out " (in %.3fs)" dt
      )
  end : CHRONO)
(* [solve_file ~pp_time file] reads one grid per non-blank line of [file],
   solves each grid in turn and prints the outcome on standard output.
   [pp_time] controls whether timings are printed.
   Raises [Failure] (through [errorf]) when a line cannot be parsed, or when a
   returned solution is not full, not valid, or does not agree with the
   filled cells of the input grid. *)
let solve_file ~pp_time file =
  let@ () = Profile.with_ "solve-file" in
  (* timer covering the whole file *)
  let (module Total) = chrono ~pp_time in
  Format.printf "solve grids in file %S@." file;
  let parse_line line =
    match String.trim line with
    | "" -> None
    | trimmed ->
      (match Grid.parse trimmed with
      | g -> Some g
      | exception e ->
        errorf "cannot parse sudoku %S: %s@." trimmed (Printexc.to_string e))
  in
  let grids =
    CCList.filter_map parse_line (CCIO.with_in file CCIO.read_lines_l)
  in
  let n_grids = List.length grids in
  Format.printf "parsed %d grids%t@." n_grids Total.pp_elapsed;
  let solve_one g =
    Format.printf
      "@[<v>@,#########################@,@[<2>solve grid:@ %a@]@]@." Grid.pp g;
    (* timer covering this grid only *)
    let (module Local) = chrono ~pp_time in
    match solve_grid g with
    | None -> Format.printf "no solution%t@." Local.pp_elapsed
    | Some sol ->
      (* sanity-check the answer before printing it *)
      if not (Grid.is_full sol) then
        errorf "grid %a@ is not full" Grid.pp sol
      else if not (Grid.is_valid sol) then
        errorf "grid %a@ is not valid" Grid.pp sol
      else if not (Grid.matches ~pat:g sol) then
        errorf "grid %a@ @[<2>does not match original@ %a@]" Grid.pp sol
          Grid.pp g
      else
        Format.printf "@[<v>@[<2>solution%t:@ %a@]@,###################@]@."
          Local.pp_elapsed Grid.pp sol
  in
  List.iter solve_one grids;
  Format.printf "@.solved %d grids%t@." n_grids Total.pp_elapsed;
  ()
(* Entry point: parses the command line, then solves every file given as an
   anonymous argument, in the order in which the files appear on the command
   line. A [Failure] or [Invalid_argument] raised while solving is printed as
   an error and the program exits with status 1. *)
let () =
  let@ () = Sidekick_tef.with_setup in
  Fmt.set_color_default true;
  let files = ref [] in
  let debug = ref 0 in
  let pp_time = ref true in
  let opts =
    [
      "--debug", Arg.Set_int debug, " debug";
      "-d", Arg.Set_int debug, " debug";
      "--no-time", Arg.Clear pp_time, " do not print solve time";
    ]
    |> Arg.align
  in
  Arg.parse opts (fun f -> files := f :: !files) "sudoku_solve [options] <file>";
  Log.set_debug !debug;
  (* Anonymous arguments were consed onto [files], so the list is in reverse
     command-line order; put it back before solving. *)
  let files = List.rev !files in
  try List.iter (fun f -> solve_file ~pp_time:!pp_time f) files
  with Failure msg | Invalid_argument msg ->
    Format.printf "@{<Red>Error@}:@.%s@." msg;
    exit 1