From e74c85e3d2e6260448d67d43c23909e4f268ec1d Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Wed, 19 Nov 2014 17:44:55 +0100 Subject: [PATCH] more modern interface to Mixtbl; added a way to iterate on all bindings --- misc/mixtbl.ml | 82 +++++++++++++++++++++++++------------------- misc/mixtbl.mli | 48 ++++++++++++++++---------- tests/test_mixtbl.ml | 29 ++++++++-------- 3 files changed, 91 insertions(+), 68 deletions(-) diff --git a/misc/mixtbl.ml b/misc/mixtbl.ml index 95d3413b..d89e6e67 100644 --- a/misc/mixtbl.ml +++ b/misc/mixtbl.ml @@ -26,35 +26,32 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. (** {1 Hash Table with Heterogeneous Keys} *) -type 'a t = ('a, (unit -> unit)) Hashtbl.t - -type ('a, 'b) injection = { - getter : 'a t -> 'a -> 'b option; - setter : 'a t -> 'a -> 'b -> unit; +type 'b injection = { + get : (unit -> unit) -> 'b option; + set : 'b -> (unit -> unit); } +type 'a t = ('a, unit -> unit) Hashtbl.t + let create n = Hashtbl.create n -let access () = +let create_inj () = let r = ref None in - let getter tbl k = - r := None; (* reset state in case last operation was not a get *) - try - (Hashtbl.find tbl k) (); - let result = !r in - r := None; (* clean up here in order to avoid memory leak *) - result - with Not_found -> None + let get f = + r := None; + f (); + !r + and set v = + (fun () -> r := Some v) in - let setter tbl k v = - let v_opt = Some v in - Hashtbl.replace tbl k (fun () -> r := v_opt) - in - { getter; setter; } + {get;set} -let get ~inj tbl x = inj.getter tbl x +let get ~inj tbl x = + try inj.get (Hashtbl.find tbl x) + with Not_found -> None -let set ~inj tbl x y = inj.setter tbl x y +let set ~inj tbl x y = + Hashtbl.replace tbl x (inj.set y) let length tbl = Hashtbl.length tbl @@ -65,14 +62,14 @@ let remove tbl x = Hashtbl.remove tbl x let copy tbl = Hashtbl.copy tbl let mem ~inj tbl x = - match inj.getter tbl x with - | None -> false - | Some _ -> true + try + inj.get (Hashtbl.find tbl x) <> None + 
with Not_found -> false let find ~inj tbl x = - match inj.getter tbl x with - | None -> raise Not_found - | Some y -> y + match inj.get (Hashtbl.find tbl x) with + | None -> raise Not_found + | Some v -> v let iter_keys tbl f = Hashtbl.iter (fun x _ -> f x) tbl @@ -80,12 +77,27 @@ let iter_keys tbl f = let fold_keys tbl acc f = Hashtbl.fold (fun x _ acc -> f acc x) tbl acc -let keys tbl = - Hashtbl.fold (fun x _ acc -> x :: acc) tbl [] +(** {2 Iterators} *) -let bindings ~inj tbl = - fold_keys tbl [] - (fun acc k -> - match inj.getter tbl k with - | None -> acc - | Some v -> (k, v) :: acc) +type 'a sequence = ('a -> unit) -> unit + +let keys_seq tbl yield = + Hashtbl.iter + (fun x _ -> yield x) + tbl + +let bindings_of ~inj tbl yield = + Hashtbl.iter + (fun k value -> + match inj.get value with + | None -> () + | Some v -> yield (k, v) + ) tbl + +type value = + | Value : ('b injection -> 'b option) -> value + +let bindings tbl yield = + Hashtbl.iter + (fun x y -> yield (x, Value (fun inj -> inj.get y))) + tbl diff --git a/misc/mixtbl.mli b/misc/mixtbl.mli index 4681c1b9..6e714c64 100644 --- a/misc/mixtbl.mli +++ b/misc/mixtbl.mli @@ -58,28 +58,33 @@ type 'a t (** A hash table containing values of different types. The type parameter ['a] represents the type of the keys. *) -type ('a, 'b) injection -(** An accessor for values of type 'b in the table. Values put - in the table using an injection can only be retrieved using this - very same injection. *) +type 'b injection +(** An accessor for values of type 'b in any table. Values put + in the table using a key can only be retrieved using this + very same key. *) val create : int -> 'a t (** [create n] creates a hash table of initial size [n]. *) -val access : unit -> ('a, 'b) injection +val create_inj : unit -> 'b injection (** Return a value that works for a given type of values. This function is - normally called once for each type of value. Several injections may be + normally called once for each type of value. 
Several keys may be created for the same type, but a value set with a given setter can only be - retrieved with the matching getter. The same injection can be reused + retrieved with the matching getter. The same key can be reused across multiple tables (although not in a thread-safe way). *) -val get : inj:('a, 'b) injection -> 'a t -> 'a -> 'b option +val get : inj:'b injection -> 'a t -> 'a -> 'b option (** Get the value corresponding to this key, if it exists and - belongs to the same injection *) + belongs to the same key *) -val set : inj:('a, 'b) injection -> 'a t -> 'a -> 'b -> unit +val set : inj:'b injection -> 'a t -> 'a -> 'b -> unit (** Bind the key to the value, using [inj] *) +val find : inj:'b injection -> 'a t -> 'a -> 'b +(** Find the value for the given key, which must be of the right type. + raises Not_found if either the key is not found, or if its value + doesn't belong to the right type *) + val length : 'a t -> int (** Number of bindings *) @@ -92,22 +97,27 @@ val remove : 'a t -> 'a -> unit val copy : 'a t -> 'a t (** Copy of the table *) -val mem : inj:('a, _) injection -> 'a t -> 'a -> bool +val mem : inj:_ injection-> 'a t -> 'a -> bool (** Is the given key in the table, with the right type? *) -val find : inj:('a, 'b) injection -> 'a t -> 'a -> 'b -(** Find the value for the given key, which must be of the right type. 
- raises Not_found if either the key is not found, or if its value - doesn't belong to the right type *) - val iter_keys : 'a t -> ('a -> unit) -> unit (** Iterate on the keys of this table *) val fold_keys : 'a t -> 'b -> ('b -> 'a -> 'b) -> 'b (** Fold over the keys *) -val keys : 'a t -> 'a list -(** List of the keys *) +(** {2 Iterators} *) -val bindings : inj:('a, 'b) injection -> 'a t -> ('a * 'b) list +type 'a sequence = ('a -> unit) -> unit + +val keys_seq : 'a t -> 'a sequence +(** All the keys *) + +val bindings_of : inj:'b injection -> 'a t -> ('a * 'b) sequence (** All the bindings that come from the corresponding injection *) + +type value = + | Value : ('b injection -> 'b option) -> value + +val bindings : 'a t -> ('a * value) sequence +(** Iterate on all bindings *) diff --git a/tests/test_mixtbl.ml b/tests/test_mixtbl.ml index 6e517417..bbb5b28f 100644 --- a/tests/test_mixtbl.ml +++ b/tests/test_mixtbl.ml @@ -1,14 +1,15 @@ open OUnit open Containers_misc +open CCFun let example () = - let inj_int = Mixtbl.access () in + let inj_int = Mixtbl.create_inj () in let tbl = Mixtbl.create 10 in OUnit.assert_equal None (Mixtbl.get ~inj:inj_int tbl "a"); Mixtbl.set inj_int tbl "a" 1; OUnit.assert_equal (Some 1) (Mixtbl.get ~inj:inj_int tbl "a"); - let inj_string = Mixtbl.access () in + let inj_string = Mixtbl.create_inj () in Mixtbl.set inj_string tbl "b" "Hello"; OUnit.assert_equal (Some "Hello") (Mixtbl.get inj_string tbl "b"); OUnit.assert_equal None (Mixtbl.get inj_string tbl "a"); @@ -19,7 +20,7 @@ let example () = () let test_length () = - let inj_int = Mixtbl.access () in + let inj_int = Mixtbl.create_inj () in let tbl = Mixtbl.create 5 in Mixtbl.set ~inj:inj_int tbl "foo" 1; Mixtbl.set ~inj:inj_int tbl "bar" 2; @@ -32,8 +33,8 @@ let test_length () = () let test_clear () = - let inj_int = Mixtbl.access () in - let inj_str = Mixtbl.access () in + let inj_int = Mixtbl.create_inj () in + let inj_str = Mixtbl.create_inj () in let tbl = Mixtbl.create 5 in 
Mixtbl.set ~inj:inj_int tbl "foo" 1; Mixtbl.set ~inj:inj_int tbl "bar" 2; @@ -44,8 +45,8 @@ let test_clear () = () let test_mem () = - let inj_int = Mixtbl.access () in - let inj_str = Mixtbl.access () in + let inj_int = Mixtbl.create_inj () in + let inj_str = Mixtbl.create_inj () in let tbl = Mixtbl.create 5 in Mixtbl.set ~inj:inj_int tbl "foo" 1; Mixtbl.set ~inj:inj_int tbl "bar" 2; @@ -59,27 +60,27 @@ let test_mem () = () let test_keys () = - let inj_int = Mixtbl.access () in - let inj_str = Mixtbl.access () in + let inj_int = Mixtbl.create_inj () in + let inj_str = Mixtbl.create_inj () in let tbl = Mixtbl.create 5 in Mixtbl.set ~inj:inj_int tbl "foo" 1; Mixtbl.set ~inj:inj_int tbl "bar" 2; Mixtbl.set ~inj:inj_str tbl "baaz" "hello"; - let l = Mixtbl.keys tbl in + let l = Mixtbl.keys_seq tbl |> CCSequence.to_list in OUnit.assert_equal ["baaz"; "bar"; "foo"] (List.sort compare l); () let test_bindings () = - let inj_int = Mixtbl.access () in - let inj_str = Mixtbl.access () in + let inj_int = Mixtbl.create_inj () in + let inj_str = Mixtbl.create_inj () in let tbl = Mixtbl.create 5 in Mixtbl.set ~inj:inj_int tbl "foo" 1; Mixtbl.set ~inj:inj_int tbl "bar" 2; Mixtbl.set ~inj:inj_str tbl "baaz" "hello"; Mixtbl.set ~inj:inj_str tbl "str" "rts"; - let l_int = Mixtbl.bindings tbl ~inj:inj_int in + let l_int = Mixtbl.bindings_of tbl ~inj:inj_int |> CCSequence.to_list in OUnit.assert_equal ["bar", 2; "foo", 1] (List.sort compare l_int); - let l_str = Mixtbl.bindings tbl ~inj:inj_str in + let l_str = Mixtbl.bindings_of tbl ~inj:inj_str |> CCSequence.to_list in OUnit.assert_equal ["baaz", "hello"; "str", "rts"] (List.sort compare l_str); ()