feat(util): add backtrackable ref

This commit is contained in:
Simon Cruanes 2021-12-07 14:07:41 -05:00
parent a614fdb2e1
commit 9517e88467
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
3 changed files with 60 additions and 0 deletions

View file

@ -0,0 +1,29 @@
type 'a t = {
mutable cur: 'a;
stack: 'a Vec.t;
copy: ('a -> 'a) option;
}
let create ?copy x: _ t =
{cur=x; stack=Vec.create(); copy}
let[@inline] get self = self.cur
let[@inline] set self x = self.cur <- x
let[@inline] update self f = self.cur <- f self.cur
let[@inline] n_levels self = Vec.size self.stack
let[@inline] push_level self : unit =
let x = self.cur in
let x = match self.copy with None -> x | Some f -> f x in
Vec.push self.stack x
let pop_levels self n : unit =
assert (n>=0);
if n > Vec.size self.stack then invalid_arg "Backtrackable_ref.pop_levels";
let i = Vec.size self.stack-n in
let x = Vec.get self.stack i in
self.cur <- x;
Vec.shrink self.stack i;
()

View file

@ -0,0 +1,30 @@
(** {1 Backtrackable ref} *)
type 'a t
val create : ?copy:('a -> 'a) -> 'a -> 'a t
(** Create a backtrackable reference holding the given value initially.
@param copy if provided, will be used to copy the value when [push_level]
is called. *)
val set : 'a t -> 'a -> unit
(** Set the reference's current content *)
val get : 'a t -> 'a
(** Get the reference's current content *)
val update : 'a t -> ('a -> 'a) -> unit
(** Update the reference's current content *)
val push_level : _ t -> unit
(** Push a backtracking level, copying the current value on top of some
stack. The [copy] function will be used if it was provided in {!create}. *)
val n_levels : _ t -> int
(** Number of saved values *)
val pop_levels : _ t -> int -> unit
(** Pop [n] levels, restoring to the value the reference was storing [n] calls
to [push_level] earlier.
@raise Invalid_argument if [n] is bigger than [n_levels]. *)

View file

@ -22,6 +22,7 @@ module Int_map = Util.Int_map
module IArray = IArray
module Backtrack_stack = Backtrack_stack
module Backtrackable_tbl = Backtrackable_tbl
module Backtrackable_ref = Backtrackable_ref
module Log = Log
module Error = Error