diff --git a/src/core/Backtrackable_ref.ml b/src/core/Backtrackable_ref.ml new file mode 100644 index 00000000..bc91cfd5 --- /dev/null +++ b/src/core/Backtrackable_ref.ml @@ -0,0 +1,29 @@ + +type 'a t = { + mutable cur: 'a; + stack: 'a Vec.t; + copy: ('a -> 'a) option; +} + +let create ?copy x: _ t = + {cur=x; stack=Vec.create(); copy} + +let[@inline] get self = self.cur +let[@inline] set self x = self.cur <- x +let[@inline] update self f = self.cur <- f self.cur + +let[@inline] n_levels self = Vec.size self.stack + +let[@inline] push_level self : unit = + let x = self.cur in + let x = match self.copy with None -> x | Some f -> f x in + Vec.push self.stack x + +let pop_levels self n : unit = + assert (n>=0); + if n > Vec.size self.stack then invalid_arg "Backtrackable_ref.pop_levels"; + let i = Vec.size self.stack-n in + let x = Vec.get self.stack i in + self.cur <- x; + Vec.shrink self.stack i; + () diff --git a/src/core/Backtrackable_ref.mli b/src/core/Backtrackable_ref.mli new file mode 100644 index 00000000..a1755115 --- /dev/null +++ b/src/core/Backtrackable_ref.mli @@ -0,0 +1,30 @@ + +(** {1 Backtrackable ref} *) + +type 'a t + +val create : ?copy:('a -> 'a) -> 'a -> 'a t +(** Create a backtrackable reference holding the given value initially. + @param copy if provided, will be used to copy the value when [push_level] + is called. *) + +val set : 'a t -> 'a -> unit +(** Set the reference's current content *) + +val get : 'a t -> 'a +(** Get the reference's current content *) + +val update : 'a t -> ('a -> 'a) -> unit +(** Update the reference's current content *) + +val push_level : _ t -> unit +(** Push a backtracking level, copying the current value on top of some + stack. The [copy] function will be used if it was provided in {!create}. *) + +val n_levels : _ t -> int +(** Number of saved values *) + +val pop_levels : _ t -> int -> unit +(** Pop [n] levels, restoring to the value the reference was storing [n] calls + to [push_level] earlier. + @raise Invalid_argument if [n] is bigger than [n_levels]. *) diff --git a/src/core/Msat.ml b/src/core/Msat.ml index 1dbd99a3..be7f6ce4 100644 --- a/src/core/Msat.ml +++ b/src/core/Msat.ml @@ -49,6 +49,8 @@ module Make_mcsat = Solver.Make_mcsat module Make_cdcl_t = Solver.Make_cdcl_t module Make_pure_sat = Solver.Make_pure_sat +module Backtrackable_ref = Backtrackable_ref + (**/**) module Vec = Vec module Log = Log diff --git a/src/sudoku/sudoku_solve.ml b/src/sudoku/sudoku_solve.ml index 72dc1e6b..d69026eb 100644 --- a/src/sudoku/sudoku_solve.ml +++ b/src/sudoku/sudoku_solve.ml @@ -125,36 +125,7 @@ end = struct a end -(** Backtrackable ref *) -module B_ref : sig - type 'a t - val create : 'a -> 'a t - val set : 'a t -> 'a -> unit - val get : 'a t -> 'a - val update : 'a t -> ('a -> 'a) -> unit - val push_level : _ t -> unit - val pop_levels : _ t -> int -> unit -end = struct - type 'a t = { - mutable cur: 'a; - stack: 'a Vec.t; - } - - let create x: _ t = {cur=x; stack=Vec.create()} - - let[@inline] get self = self.cur - let[@inline] set self x = self.cur <- x - let[@inline] update self f = self.cur <- f self.cur - - let[@inline] push_level self : unit = Vec.push self.stack self.cur - let pop_levels self n : unit = - assert (n>=0 && n <= Vec.size self.stack); - let i = Vec.size self.stack-n in - let x = Vec.get self.stack i in - self.cur <- x; - Vec.shrink self.stack i; - () -end +module B_ref = Msat.Backtrackable_ref module Solver : sig type t