add CCHashtbl.Poly and fix issue in Containers (close #46)

This commit is contained in:
Simon Cruanes 2017-02-11 13:28:47 +01:00
parent edc488b909
commit 8ddacbb028
3 changed files with 185 additions and 175 deletions

View file

@ -10,37 +10,38 @@ type 'a printer = Format.formatter -> 'a -> unit
(** {2 Polymorphic tables} *) (** {2 Polymorphic tables} *)
let get tbl x = module Poly = struct
let get tbl x =
try Some (Hashtbl.find tbl x) try Some (Hashtbl.find tbl x)
with Not_found -> None with Not_found -> None
let get_or tbl x ~default = let get_or tbl x ~default =
try Hashtbl.find tbl x try Hashtbl.find tbl x
with Not_found -> default with Not_found -> default
(*$= (*$=
"c" (let tbl = of_list [1,"a"; 2,"b"] in get_or tbl 3 ~default:"c") "c" (let tbl = of_list [1,"a"; 2,"b"] in get_or tbl 3 ~default:"c")
"b" (let tbl = of_list [1,"a"; 2,"b"] in get_or tbl 2 ~default:"c") "b" (let tbl = of_list [1,"a"; 2,"b"] in get_or tbl 2 ~default:"c")
*) *)
let keys tbl k = Hashtbl.iter (fun key _ -> k key) tbl let keys tbl k = Hashtbl.iter (fun key _ -> k key) tbl
let values tbl k = Hashtbl.iter (fun _ v -> k v) tbl let values tbl k = Hashtbl.iter (fun _ v -> k v) tbl
let keys_list tbl = Hashtbl.fold (fun k _ a -> k::a) tbl [] let keys_list tbl = Hashtbl.fold (fun k _ a -> k::a) tbl []
let values_list tbl = Hashtbl.fold (fun _ v a -> v::a) tbl [] let values_list tbl = Hashtbl.fold (fun _ v a -> v::a) tbl []
let add_list tbl k v = let add_list tbl k v =
let l = try Hashtbl.find tbl k with Not_found -> [] in let l = try Hashtbl.find tbl k with Not_found -> [] in
Hashtbl.replace tbl k (v::l) Hashtbl.replace tbl k (v::l)
let incr ?(by=1) tbl x = let incr ?(by=1) tbl x =
let n = get_or tbl x ~default:0 in let n = get_or tbl x ~default:0 in
if n+by <= 0 if n+by <= 0
then Hashtbl.remove tbl x then Hashtbl.remove tbl x
else Hashtbl.replace tbl x (n+by) else Hashtbl.replace tbl x (n+by)
let decr ?(by=1) tbl x = let decr ?(by=1) tbl x =
try try
let n = Hashtbl.find tbl x in let n = Hashtbl.find tbl x in
if n-by <= 0 if n-by <= 0
@ -48,43 +49,43 @@ let decr ?(by=1) tbl x =
else Hashtbl.replace tbl x (n-by) else Hashtbl.replace tbl x (n-by)
with Not_found -> () with Not_found -> ()
let map_list f h = let map_list f h =
Hashtbl.fold Hashtbl.fold
(fun x y acc -> f x y :: acc) (fun x y acc -> f x y :: acc)
h [] h []
(*$T (*$T
of_list [1,"a"; 2,"b"] |> map_list (fun x y -> string_of_int x ^ y) \ of_list [1,"a"; 2,"b"] |> map_list (fun x y -> string_of_int x ^ y) \
|> List.sort Pervasives.compare = ["1a"; "2b"] |> List.sort Pervasives.compare = ["1a"; "2b"]
*) *)
let to_seq tbl k = Hashtbl.iter (fun key v -> k (key,v)) tbl let to_seq tbl k = Hashtbl.iter (fun key v -> k (key,v)) tbl
let add_seq tbl seq = seq (fun (k,v) -> Hashtbl.add tbl k v) let add_seq tbl seq = seq (fun (k,v) -> Hashtbl.add tbl k v)
let of_seq seq = let of_seq seq =
let tbl = Hashtbl.create 32 in let tbl = Hashtbl.create 32 in
add_seq tbl seq; add_seq tbl seq;
tbl tbl
let add_seq_count tbl seq = seq (fun k -> incr tbl k) let add_seq_count tbl seq = seq (fun k -> incr tbl k)
let of_seq_count seq = let of_seq_count seq =
let tbl = Hashtbl.create 32 in let tbl = Hashtbl.create 32 in
add_seq_count tbl seq; add_seq_count tbl seq;
tbl tbl
let to_list tbl = let to_list tbl =
Hashtbl.fold Hashtbl.fold
(fun k v l -> (k,v) :: l) (fun k v l -> (k,v) :: l)
tbl [] tbl []
let of_list l = let of_list l =
let tbl = Hashtbl.create 32 in let tbl = Hashtbl.create 32 in
List.iter (fun (k,v) -> Hashtbl.add tbl k v) l; List.iter (fun (k,v) -> Hashtbl.add tbl k v) l;
tbl tbl
let update tbl ~f ~k = let update tbl ~f ~k =
let v = get tbl k in let v = get tbl k in
match v, f k v with match v, f k v with
| None, None -> () | None, None -> ()
@ -92,7 +93,7 @@ let update tbl ~f ~k =
| Some _, Some v' -> Hashtbl.replace tbl k v' | Some _, Some v' -> Hashtbl.replace tbl k v'
| Some _, None -> Hashtbl.remove tbl k | Some _, None -> Hashtbl.remove tbl k
(*$R (*$R
let tbl = Hashtbl.create 32 in let tbl = Hashtbl.create 32 in
update tbl ~k:1 ~f:(fun _ _ -> Some "1"); update tbl ~k:1 ~f:(fun _ _ -> Some "1");
assert_equal (Some "1") (get tbl 1); assert_equal (Some "1") (get tbl 1);
@ -101,16 +102,16 @@ let update tbl ~f ~k =
assert_equal 2 (Hashtbl.length tbl); assert_equal 2 (Hashtbl.length tbl);
update tbl ~k:1 ~f:(fun _ _ -> None); update tbl ~k:1 ~f:(fun _ _ -> None);
assert_equal None (get tbl 1); assert_equal None (get tbl 1);
*) *)
let get_or_add tbl ~f ~k = let get_or_add tbl ~f ~k =
try Hashtbl.find tbl k try Hashtbl.find tbl k
with Not_found -> with Not_found ->
let v = f k in let v = f k in
Hashtbl.add tbl k v; Hashtbl.add tbl k v;
v v
(*$R (*$R
let tbl = Hashtbl.create 32 in let tbl = Hashtbl.create 32 in
let v1 = get_or_add tbl ~k:1 ~f:(fun _ -> "1") in let v1 = get_or_add tbl ~k:1 ~f:(fun _ -> "1") in
assert_equal "1" v1; assert_equal "1" v1;
@ -121,9 +122,9 @@ let get_or_add tbl ~f ~k =
assert_equal "2" (get_or_add tbl ~k:2 ~f:(fun _ -> assert false)); assert_equal "2" (get_or_add tbl ~k:2 ~f:(fun _ -> assert false));
assert_equal 2 (Hashtbl.length tbl); assert_equal 2 (Hashtbl.length tbl);
() ()
*) *)
let print pp_k pp_v fmt m = let print pp_k pp_v fmt m =
Format.fprintf fmt "@[<hov2>tbl {@,"; Format.fprintf fmt "@[<hov2>tbl {@,";
let first = ref true in let first = ref true in
Hashtbl.iter Hashtbl.iter
@ -135,6 +136,9 @@ let print pp_k pp_v fmt m =
Format.pp_print_cut fmt () Format.pp_print_cut fmt ()
) m; ) m;
Format.fprintf fmt "}@]" Format.fprintf fmt "}@]"
end
include Poly
(** {2 Functor} *) (** {2 Functor} *)

View file

@ -12,96 +12,102 @@ type 'a printer = Format.formatter -> 'a -> unit
(** {2 Polymorphic tables} *) (** {2 Polymorphic tables} *)
val get : ('a,'b) Hashtbl.t -> 'a -> 'b option (** This sub-module contains the extension of the standard polymorphic hashtbl. *)
(** Safe version of {!Hashtbl.find} *)
val get_or : ('a,'b) Hashtbl.t -> 'a -> default:'b -> 'b module Poly : sig
(** [get_or tbl k ~default] returns the value associated to [k] if present, val get : ('a,'b) Hashtbl.t -> 'a -> 'b option
(** Safe version of {!Hashtbl.find} *)
val get_or : ('a,'b) Hashtbl.t -> 'a -> default:'b -> 'b
(** [get_or tbl k ~default] returns the value associated to [k] if present,
and returns [default] otherwise (if [k] doesn't belong in [tbl]) and returns [default] otherwise (if [k] doesn't belong in [tbl])
@since 0.16 *) @since 0.16 *)
val keys : ('a,'b) Hashtbl.t -> 'a sequence val keys : ('a,'b) Hashtbl.t -> 'a sequence
(** Iterate on keys (similar order as {!Hashtbl.iter}) *) (** Iterate on keys (similar order as {!Hashtbl.iter}) *)
val values : ('a,'b) Hashtbl.t -> 'b sequence val values : ('a,'b) Hashtbl.t -> 'b sequence
(** Iterate on values in the table *) (** Iterate on values in the table *)
val keys_list : ('a, 'b) Hashtbl.t -> 'a list val keys_list : ('a, 'b) Hashtbl.t -> 'a list
(** [keys_list t] is the list of keys in [t]. (** [keys_list t] is the list of keys in [t].
@since 0.8 *) @since 0.8 *)
val values_list : ('a, 'b) Hashtbl.t -> 'b list val values_list : ('a, 'b) Hashtbl.t -> 'b list
(** [values_list t] is the list of values in [t]. (** [values_list t] is the list of values in [t].
@since 0.8 *) @since 0.8 *)
val map_list : ('a -> 'b -> 'c) -> ('a, 'b) Hashtbl.t -> 'c list val map_list : ('a -> 'b -> 'c) -> ('a, 'b) Hashtbl.t -> 'c list
(** Map on a hashtable's items, collect into a list *) (** Map on a hashtable's items, collect into a list *)
val incr : ?by:int -> ('a, int) Hashtbl.t -> 'a -> unit val incr : ?by:int -> ('a, int) Hashtbl.t -> 'a -> unit
(** [incr ?by tbl x] increments or initializes the counter associated with [x]. (** [incr ?by tbl x] increments or initializes the counter associated with [x].
If [get tbl x = None], then after update, [get tbl x = Some 1]; If [get tbl x = None], then after update, [get tbl x = Some 1];
otherwise, if [get tbl x = Some n], now [get tbl x = Some (n+1)]. otherwise, if [get tbl x = Some n], now [get tbl x = Some (n+1)].
@param by if specified, the int value is incremented by [by] rather than 1 @param by if specified, the int value is incremented by [by] rather than 1
@since 0.16 *) @since 0.16 *)
val decr : ?by:int -> ('a, int) Hashtbl.t -> 'a -> unit val decr : ?by:int -> ('a, int) Hashtbl.t -> 'a -> unit
(** Same as {!incr} but substract 1 (or the value of [by]). (** Same as {!incr} but substract 1 (or the value of [by]).
If the value reaches 0, the key is removed from the table. If the value reaches 0, the key is removed from the table.
This does nothing if the key is not already present in the table. This does nothing if the key is not already present in the table.
@since 0.16 *) @since 0.16 *)
val to_seq : ('a,'b) Hashtbl.t -> ('a * 'b) sequence val to_seq : ('a,'b) Hashtbl.t -> ('a * 'b) sequence
(** Iterate on bindings in the table *) (** Iterate on bindings in the table *)
val add_list : ('a, 'b list) Hashtbl.t -> 'a -> 'b -> unit val add_list : ('a, 'b list) Hashtbl.t -> 'a -> 'b -> unit
(** [add_list tbl x y] adds [y] to the list [x] is bound to. If [x] is (** [add_list tbl x y] adds [y] to the list [x] is bound to. If [x] is
not bound, it becomes bound to [[y]]. not bound, it becomes bound to [[y]].
@since 0.16 *) @since 0.16 *)
val add_seq : ('a,'b) Hashtbl.t -> ('a * 'b) sequence -> unit val add_seq : ('a,'b) Hashtbl.t -> ('a * 'b) sequence -> unit
(** Add the corresponding pairs to the table, using {!Hashtbl.add}. (** Add the corresponding pairs to the table, using {!Hashtbl.add}.
@since 0.16 *) @since 0.16 *)
val of_seq : ('a * 'b) sequence -> ('a,'b) Hashtbl.t val of_seq : ('a * 'b) sequence -> ('a,'b) Hashtbl.t
(** From the given bindings, added in order *) (** From the given bindings, added in order *)
val add_seq_count : ('a, int) Hashtbl.t -> 'a sequence -> unit val add_seq_count : ('a, int) Hashtbl.t -> 'a sequence -> unit
(** [add_seq_count tbl seq] increments the count of each element of [seq] (** [add_seq_count tbl seq] increments the count of each element of [seq]
by calling {!incr}. This is useful for counting how many times each by calling {!incr}. This is useful for counting how many times each
element of [seq] occurs. element of [seq] occurs.
@since 0.16 *) @since 0.16 *)
val of_seq_count : 'a sequence -> ('a, int) Hashtbl.t val of_seq_count : 'a sequence -> ('a, int) Hashtbl.t
(** Similar to {!add_seq_count}, but allocates a new table and returns it (** Similar to {!add_seq_count}, but allocates a new table and returns it
@since 0.16 *) @since 0.16 *)
val to_list : ('a,'b) Hashtbl.t -> ('a * 'b) list val to_list : ('a,'b) Hashtbl.t -> ('a * 'b) list
(** List of bindings (order unspecified) *) (** List of bindings (order unspecified) *)
val of_list : ('a * 'b) list -> ('a,'b) Hashtbl.t val of_list : ('a * 'b) list -> ('a,'b) Hashtbl.t
(** Build a table from the given list of bindings [k_i -> v_i], (** Build a table from the given list of bindings [k_i -> v_i],
added in order using {!add}. If a key occurs several times, added in order using {!add}. If a key occurs several times,
it will be added several times, and the visible binding it will be added several times, and the visible binding
will be the last one. *) will be the last one. *)
val update : ('a, 'b) Hashtbl.t -> f:('a -> 'b option -> 'b option) -> k:'a -> unit val update : ('a, 'b) Hashtbl.t -> f:('a -> 'b option -> 'b option) -> k:'a -> unit
(** [update tbl ~f ~k] updates key [k] by calling [f k (Some v)] if (** [update tbl ~f ~k] updates key [k] by calling [f k (Some v)] if
[k] was mapped to [v], or [f k None] otherwise; if the call [k] was mapped to [v], or [f k None] otherwise; if the call
returns [None] then [k] is removed/stays removed, if the call returns [None] then [k] is removed/stays removed, if the call
returns [Some v'] then the binding [k -> v'] is inserted returns [Some v'] then the binding [k -> v'] is inserted
using {!Hashtbl.replace} using {!Hashtbl.replace}
@since 0.14 *) @since 0.14 *)
val get_or_add : ('a, 'b) Hashtbl.t -> f:('a -> 'b) -> k:'a -> 'b val get_or_add : ('a, 'b) Hashtbl.t -> f:('a -> 'b) -> k:'a -> 'b
(** [get_or_add tbl ~k ~f] finds and returns the binding of [k] (** [get_or_add tbl ~k ~f] finds and returns the binding of [k]
in [tbl], if it exists. If it does not exist, then [f k] in [tbl], if it exists. If it does not exist, then [f k]
is called to obtain a new binding [v]; [k -> v] is added is called to obtain a new binding [v]; [k -> v] is added
to [tbl] and [v] is returned. to [tbl] and [v] is returned.
@since NEXT_RELEASE *) @since NEXT_RELEASE *)
val print : 'a printer -> 'b printer -> ('a, 'b) Hashtbl.t printer val print : 'a printer -> 'b printer -> ('a, 'b) Hashtbl.t printer
(** Printer for table (** Printer for table
@since 0.13 *) @since 0.13 *)
end
include module type of Poly
(** {2 Functor} *) (** {2 Functor} *)

View file

@ -44,7 +44,7 @@ module Hashtbl = struct
and module Make = Hashtbl.Make and module Make = Hashtbl.Make
and type ('a,'b) t = ('a,'b) Hashtbl.t and type ('a,'b) t = ('a,'b) Hashtbl.t
) )
(* still unable to include CCHashtbl itself, for the polymorphic functions *) include CCHashtbl.Poly
module type S' = CCHashtbl.S module type S' = CCHashtbl.S
module Make' = CCHashtbl.Make module Make' = CCHashtbl.Make
end end