add CCHashtbl.Poly and fix issue in Containers (close #46)

This commit is contained in:
Simon Cruanes 2017-02-11 13:28:47 +01:00
parent edc488b909
commit 8ddacbb028
3 changed files with 185 additions and 175 deletions

View file

@ -10,131 +10,135 @@ type 'a printer = Format.formatter -> 'a -> unit
(** {2 Polymorphic tables} *) (** {2 Polymorphic tables} *)
let get tbl x = module Poly = struct
try Some (Hashtbl.find tbl x) let get tbl x =
with Not_found -> None try Some (Hashtbl.find tbl x)
with Not_found -> None
let get_or tbl x ~default = let get_or tbl x ~default =
try Hashtbl.find tbl x try Hashtbl.find tbl x
with Not_found -> default with Not_found -> default
(*$= (*$=
"c" (let tbl = of_list [1,"a"; 2,"b"] in get_or tbl 3 ~default:"c") "c" (let tbl = of_list [1,"a"; 2,"b"] in get_or tbl 3 ~default:"c")
"b" (let tbl = of_list [1,"a"; 2,"b"] in get_or tbl 2 ~default:"c") "b" (let tbl = of_list [1,"a"; 2,"b"] in get_or tbl 2 ~default:"c")
*) *)
let keys tbl k = Hashtbl.iter (fun key _ -> k key) tbl let keys tbl k = Hashtbl.iter (fun key _ -> k key) tbl
let values tbl k = Hashtbl.iter (fun _ v -> k v) tbl let values tbl k = Hashtbl.iter (fun _ v -> k v) tbl
let keys_list tbl = Hashtbl.fold (fun k _ a -> k::a) tbl [] let keys_list tbl = Hashtbl.fold (fun k _ a -> k::a) tbl []
let values_list tbl = Hashtbl.fold (fun _ v a -> v::a) tbl [] let values_list tbl = Hashtbl.fold (fun _ v a -> v::a) tbl []
let add_list tbl k v = let add_list tbl k v =
let l = try Hashtbl.find tbl k with Not_found -> [] in let l = try Hashtbl.find tbl k with Not_found -> [] in
Hashtbl.replace tbl k (v::l) Hashtbl.replace tbl k (v::l)
let incr ?(by=1) tbl x = let incr ?(by=1) tbl x =
let n = get_or tbl x ~default:0 in let n = get_or tbl x ~default:0 in
if n+by <= 0 if n+by <= 0
then Hashtbl.remove tbl x
else Hashtbl.replace tbl x (n+by)
let decr ?(by=1) tbl x =
try
let n = Hashtbl.find tbl x in
if n-by <= 0
then Hashtbl.remove tbl x then Hashtbl.remove tbl x
else Hashtbl.replace tbl x (n-by) else Hashtbl.replace tbl x (n+by)
with Not_found -> ()
let map_list f h = let decr ?(by=1) tbl x =
Hashtbl.fold try
(fun x y acc -> f x y :: acc) let n = Hashtbl.find tbl x in
h [] if n-by <= 0
then Hashtbl.remove tbl x
else Hashtbl.replace tbl x (n-by)
with Not_found -> ()
(*$T let map_list f h =
of_list [1,"a"; 2,"b"] |> map_list (fun x y -> string_of_int x ^ y) \ Hashtbl.fold
|> List.sort Pervasives.compare = ["1a"; "2b"] (fun x y acc -> f x y :: acc)
*) h []
let to_seq tbl k = Hashtbl.iter (fun key v -> k (key,v)) tbl (*$T
of_list [1,"a"; 2,"b"] |> map_list (fun x y -> string_of_int x ^ y) \
|> List.sort Pervasives.compare = ["1a"; "2b"]
*)
let add_seq tbl seq = seq (fun (k,v) -> Hashtbl.add tbl k v) let to_seq tbl k = Hashtbl.iter (fun key v -> k (key,v)) tbl
let of_seq seq = let add_seq tbl seq = seq (fun (k,v) -> Hashtbl.add tbl k v)
let tbl = Hashtbl.create 32 in
add_seq tbl seq;
tbl
let add_seq_count tbl seq = seq (fun k -> incr tbl k) let of_seq seq =
let tbl = Hashtbl.create 32 in
add_seq tbl seq;
tbl
let of_seq_count seq = let add_seq_count tbl seq = seq (fun k -> incr tbl k)
let tbl = Hashtbl.create 32 in
add_seq_count tbl seq;
tbl
let to_list tbl = let of_seq_count seq =
Hashtbl.fold let tbl = Hashtbl.create 32 in
(fun k v l -> (k,v) :: l) add_seq_count tbl seq;
tbl [] tbl
let of_list l = let to_list tbl =
let tbl = Hashtbl.create 32 in Hashtbl.fold
List.iter (fun (k,v) -> Hashtbl.add tbl k v) l; (fun k v l -> (k,v) :: l)
tbl tbl []
let update tbl ~f ~k = let of_list l =
let v = get tbl k in let tbl = Hashtbl.create 32 in
match v, f k v with List.iter (fun (k,v) -> Hashtbl.add tbl k v) l;
| None, None -> () tbl
| None, Some v' -> Hashtbl.add tbl k v'
| Some _, Some v' -> Hashtbl.replace tbl k v'
| Some _, None -> Hashtbl.remove tbl k
(*$R let update tbl ~f ~k =
let tbl = Hashtbl.create 32 in let v = get tbl k in
update tbl ~k:1 ~f:(fun _ _ -> Some "1"); match v, f k v with
assert_equal (Some "1") (get tbl 1); | None, None -> ()
update tbl ~k:2 ~f:(fun _ v->match v with Some _ -> assert false | None -> Some "2"); | None, Some v' -> Hashtbl.add tbl k v'
assert_equal (Some "2") (get tbl 2); | Some _, Some v' -> Hashtbl.replace tbl k v'
assert_equal 2 (Hashtbl.length tbl); | Some _, None -> Hashtbl.remove tbl k
update tbl ~k:1 ~f:(fun _ _ -> None);
assert_equal None (get tbl 1);
*)
let get_or_add tbl ~f ~k = (*$R
try Hashtbl.find tbl k let tbl = Hashtbl.create 32 in
with Not_found -> update tbl ~k:1 ~f:(fun _ _ -> Some "1");
let v = f k in assert_equal (Some "1") (get tbl 1);
Hashtbl.add tbl k v; update tbl ~k:2 ~f:(fun _ v->match v with Some _ -> assert false | None -> Some "2");
v assert_equal (Some "2") (get tbl 2);
assert_equal 2 (Hashtbl.length tbl);
update tbl ~k:1 ~f:(fun _ _ -> None);
assert_equal None (get tbl 1);
*)
(*$R let get_or_add tbl ~f ~k =
let tbl = Hashtbl.create 32 in try Hashtbl.find tbl k
let v1 = get_or_add tbl ~k:1 ~f:(fun _ -> "1") in with Not_found ->
assert_equal "1" v1; let v = f k in
assert_equal (Some "1") (get tbl 1); Hashtbl.add tbl k v;
let v2 = get_or_add tbl ~k:2 ~f:(fun _ ->"2") in v
assert_equal "2" v2;
assert_equal (Some "2") (get tbl 2);
assert_equal "2" (get_or_add tbl ~k:2 ~f:(fun _ -> assert false));
assert_equal 2 (Hashtbl.length tbl);
()
*)
let print pp_k pp_v fmt m = (*$R
Format.fprintf fmt "@[<hov2>tbl {@,"; let tbl = Hashtbl.create 32 in
let first = ref true in let v1 = get_or_add tbl ~k:1 ~f:(fun _ -> "1") in
Hashtbl.iter assert_equal "1" v1;
(fun k v -> assert_equal (Some "1") (get tbl 1);
if !first then first := false else Format.pp_print_string fmt ", "; let v2 = get_or_add tbl ~k:2 ~f:(fun _ ->"2") in
pp_k fmt k; assert_equal "2" v2;
Format.pp_print_string fmt " -> "; assert_equal (Some "2") (get tbl 2);
pp_v fmt v; assert_equal "2" (get_or_add tbl ~k:2 ~f:(fun _ -> assert false));
Format.pp_print_cut fmt () assert_equal 2 (Hashtbl.length tbl);
) m; ()
Format.fprintf fmt "}@]" *)
let print pp_k pp_v fmt m =
Format.fprintf fmt "@[<hov2>tbl {@,";
let first = ref true in
Hashtbl.iter
(fun k v ->
if !first then first := false else Format.pp_print_string fmt ", ";
pp_k fmt k;
Format.pp_print_string fmt " -> ";
pp_v fmt v;
Format.pp_print_cut fmt ()
) m;
Format.fprintf fmt "}@]"
end
include Poly
(** {2 Functor} *) (** {2 Functor} *)

View file

@ -12,96 +12,102 @@ type 'a printer = Format.formatter -> 'a -> unit
(** {2 Polymorphic tables} *) (** {2 Polymorphic tables} *)
val get : ('a,'b) Hashtbl.t -> 'a -> 'b option (** This sub-module contains the extension of the standard polymorphic hashtbl. *)
(** Safe version of {!Hashtbl.find} *)
val get_or : ('a,'b) Hashtbl.t -> 'a -> default:'b -> 'b module Poly : sig
(** [get_or tbl k ~default] returns the value associated to [k] if present, val get : ('a,'b) Hashtbl.t -> 'a -> 'b option
and returns [default] otherwise (if [k] doesn't belong in [tbl]) (** Safe version of {!Hashtbl.find} *)
@since 0.16 *)
val keys : ('a,'b) Hashtbl.t -> 'a sequence val get_or : ('a,'b) Hashtbl.t -> 'a -> default:'b -> 'b
(** Iterate on keys (similar order as {!Hashtbl.iter}) *) (** [get_or tbl k ~default] returns the value associated to [k] if present,
and returns [default] otherwise (if [k] doesn't belong in [tbl])
@since 0.16 *)
val values : ('a,'b) Hashtbl.t -> 'b sequence val keys : ('a,'b) Hashtbl.t -> 'a sequence
(** Iterate on values in the table *) (** Iterate on keys (similar order as {!Hashtbl.iter}) *)
val keys_list : ('a, 'b) Hashtbl.t -> 'a list val values : ('a,'b) Hashtbl.t -> 'b sequence
(** [keys_list t] is the list of keys in [t]. (** Iterate on values in the table *)
@since 0.8 *)
val values_list : ('a, 'b) Hashtbl.t -> 'b list val keys_list : ('a, 'b) Hashtbl.t -> 'a list
(** [values_list t] is the list of values in [t]. (** [keys_list t] is the list of keys in [t].
@since 0.8 *) @since 0.8 *)
val map_list : ('a -> 'b -> 'c) -> ('a, 'b) Hashtbl.t -> 'c list val values_list : ('a, 'b) Hashtbl.t -> 'b list
(** Map on a hashtable's items, collect into a list *) (** [values_list t] is the list of values in [t].
@since 0.8 *)
val incr : ?by:int -> ('a, int) Hashtbl.t -> 'a -> unit val map_list : ('a -> 'b -> 'c) -> ('a, 'b) Hashtbl.t -> 'c list
(** [incr ?by tbl x] increments or initializes the counter associated with [x]. (** Map on a hashtable's items, collect into a list *)
If [get tbl x = None], then after update, [get tbl x = Some 1];
otherwise, if [get tbl x = Some n], now [get tbl x = Some (n+1)].
@param by if specified, the int value is incremented by [by] rather than 1
@since 0.16 *)
val decr : ?by:int -> ('a, int) Hashtbl.t -> 'a -> unit val incr : ?by:int -> ('a, int) Hashtbl.t -> 'a -> unit
(** Same as {!incr} but substract 1 (or the value of [by]). (** [incr ?by tbl x] increments or initializes the counter associated with [x].
If the value reaches 0, the key is removed from the table. If [get tbl x = None], then after update, [get tbl x = Some 1];
This does nothing if the key is not already present in the table. otherwise, if [get tbl x = Some n], now [get tbl x = Some (n+1)].
@since 0.16 *) @param by if specified, the int value is incremented by [by] rather than 1
@since 0.16 *)
val to_seq : ('a,'b) Hashtbl.t -> ('a * 'b) sequence val decr : ?by:int -> ('a, int) Hashtbl.t -> 'a -> unit
(** Iterate on bindings in the table *) (** Same as {!incr} but substract 1 (or the value of [by]).
If the value reaches 0, the key is removed from the table.
This does nothing if the key is not already present in the table.
@since 0.16 *)
val add_list : ('a, 'b list) Hashtbl.t -> 'a -> 'b -> unit val to_seq : ('a,'b) Hashtbl.t -> ('a * 'b) sequence
(** [add_list tbl x y] adds [y] to the list [x] is bound to. If [x] is (** Iterate on bindings in the table *)
not bound, it becomes bound to [[y]].
@since 0.16 *)
val add_seq : ('a,'b) Hashtbl.t -> ('a * 'b) sequence -> unit val add_list : ('a, 'b list) Hashtbl.t -> 'a -> 'b -> unit
(** Add the corresponding pairs to the table, using {!Hashtbl.add}. (** [add_list tbl x y] adds [y] to the list [x] is bound to. If [x] is
@since 0.16 *) not bound, it becomes bound to [[y]].
@since 0.16 *)
val of_seq : ('a * 'b) sequence -> ('a,'b) Hashtbl.t val add_seq : ('a,'b) Hashtbl.t -> ('a * 'b) sequence -> unit
(** From the given bindings, added in order *) (** Add the corresponding pairs to the table, using {!Hashtbl.add}.
@since 0.16 *)
val add_seq_count : ('a, int) Hashtbl.t -> 'a sequence -> unit val of_seq : ('a * 'b) sequence -> ('a,'b) Hashtbl.t
(** [add_seq_count tbl seq] increments the count of each element of [seq] (** From the given bindings, added in order *)
by calling {!incr}. This is useful for counting how many times each
element of [seq] occurs.
@since 0.16 *)
val of_seq_count : 'a sequence -> ('a, int) Hashtbl.t val add_seq_count : ('a, int) Hashtbl.t -> 'a sequence -> unit
(** Similar to {!add_seq_count}, but allocates a new table and returns it (** [add_seq_count tbl seq] increments the count of each element of [seq]
@since 0.16 *) by calling {!incr}. This is useful for counting how many times each
element of [seq] occurs.
@since 0.16 *)
val to_list : ('a,'b) Hashtbl.t -> ('a * 'b) list val of_seq_count : 'a sequence -> ('a, int) Hashtbl.t
(** List of bindings (order unspecified) *) (** Similar to {!add_seq_count}, but allocates a new table and returns it
@since 0.16 *)
val of_list : ('a * 'b) list -> ('a,'b) Hashtbl.t val to_list : ('a,'b) Hashtbl.t -> ('a * 'b) list
(** Build a table from the given list of bindings [k_i -> v_i], (** List of bindings (order unspecified) *)
added in order using {!add}. If a key occurs several times,
it will be added several times, and the visible binding
will be the last one. *)
val update : ('a, 'b) Hashtbl.t -> f:('a -> 'b option -> 'b option) -> k:'a -> unit val of_list : ('a * 'b) list -> ('a,'b) Hashtbl.t
(** [update tbl ~f ~k] updates key [k] by calling [f k (Some v)] if (** Build a table from the given list of bindings [k_i -> v_i],
[k] was mapped to [v], or [f k None] otherwise; if the call added in order using {!add}. If a key occurs several times,
returns [None] then [k] is removed/stays removed, if the call it will be added several times, and the visible binding
returns [Some v'] then the binding [k -> v'] is inserted will be the last one. *)
using {!Hashtbl.replace}
@since 0.14 *)
val get_or_add : ('a, 'b) Hashtbl.t -> f:('a -> 'b) -> k:'a -> 'b val update : ('a, 'b) Hashtbl.t -> f:('a -> 'b option -> 'b option) -> k:'a -> unit
(** [get_or_add tbl ~k ~f] finds and returns the binding of [k] (** [update tbl ~f ~k] updates key [k] by calling [f k (Some v)] if
in [tbl], if it exists. If it does not exist, then [f k] [k] was mapped to [v], or [f k None] otherwise; if the call
is called to obtain a new binding [v]; [k -> v] is added returns [None] then [k] is removed/stays removed, if the call
to [tbl] and [v] is returned. returns [Some v'] then the binding [k -> v'] is inserted
@since NEXT_RELEASE *) using {!Hashtbl.replace}
@since 0.14 *)
val print : 'a printer -> 'b printer -> ('a, 'b) Hashtbl.t printer val get_or_add : ('a, 'b) Hashtbl.t -> f:('a -> 'b) -> k:'a -> 'b
(** Printer for table (** [get_or_add tbl ~k ~f] finds and returns the binding of [k]
@since 0.13 *) in [tbl], if it exists. If it does not exist, then [f k]
is called to obtain a new binding [v]; [k -> v] is added
to [tbl] and [v] is returned.
@since NEXT_RELEASE *)
val print : 'a printer -> 'b printer -> ('a, 'b) Hashtbl.t printer
(** Printer for table
@since 0.13 *)
end
include module type of Poly
(** {2 Functor} *) (** {2 Functor} *)

View file

@ -44,7 +44,7 @@ module Hashtbl = struct
and module Make = Hashtbl.Make and module Make = Hashtbl.Make
and type ('a,'b) t = ('a,'b) Hashtbl.t and type ('a,'b) t = ('a,'b) Hashtbl.t
) )
(* still unable to include CCHashtbl itself, for the polymorphic functions *) include CCHashtbl.Poly
module type S' = CCHashtbl.S module type S' = CCHashtbl.S
module Make' = CCHashtbl.Make module Make' = CCHashtbl.Make
end end