diff --git a/src/core/CCHashtbl.ml b/src/core/CCHashtbl.ml index cc169335..832d8682 100644 --- a/src/core/CCHashtbl.ml +++ b/src/core/CCHashtbl.ml @@ -103,6 +103,26 @@ let update tbl ~f ~k = assert_equal None (get tbl 1); *) +let get_or_add tbl ~f ~k = + try Hashtbl.find tbl k + with Not_found -> + let v = f k in + Hashtbl.add tbl k v; + v + +(*$R + let tbl = Hashtbl.create 32 in + let v1 = get_or_add tbl ~k:1 ~f:(fun _ -> "1") in + assert_equal "1" v1; + assert_equal (Some "1") (get tbl 1); + let v2 = get_or_add tbl ~k:2 ~f:(fun _ ->"2") in + assert_equal "2" v2; + assert_equal (Some "2") (get tbl 2); + assert_equal "2" (get_or_add tbl ~k:2 ~f:(fun _ -> assert false)); + assert_equal 2 (Hashtbl.length tbl); + () +*) + let print pp_k pp_v fmt m = Format.fprintf fmt "@[tbl {@,"; let first = ref true in @@ -188,7 +208,10 @@ module type S = sig (** List of bindings (order unspecified) *) val of_list : (key * 'a) list -> 'a t - (** From the given list of bindings, added in order *) + (** Build a table from the given list of bindings [k_i -> v_i], + added in order using {!add}. If a key occurs several times, + it will be added several times, and the visible binding + will be the last one. *) val update : 'a t -> f:(key -> 'a option -> 'a option) -> k:key -> unit (** [update tbl ~f ~k] updates key [k] by calling [f k (Some v)] if @@ -198,6 +221,13 @@ module type S = sig using {!Hashtbl.replace} @since 0.14 *) + val get_or_add : 'a t -> f:(key -> 'a) -> k:key -> 'a + (** [get_or_add tbl ~k ~f] finds and returns the binding of [k] + in [tbl], if it exists. If it does not exist, then [f k] + is called to obtain a new binding [v]; [k -> v] is added + to [tbl] and [v] is returned. + @since NEXT_RELEASE *) + val print : key printer -> 'a printer -> 'a t printer (** Printer for tables @since 0.13 *) @@ -277,6 +307,13 @@ module Make(X : Hashtbl.HashedType) | Some _, Some v' -> replace tbl k v' | Some _, None -> remove tbl k + let get_or_add tbl ~f ~k = + try find tbl k + with Not_found -> + let v = f k in + add tbl k v; + v + let to_seq tbl k = iter (fun key v -> k (key,v)) tbl let add_seq tbl seq = seq (fun (k,v) -> add tbl k v) diff --git a/src/core/CCHashtbl.mli b/src/core/CCHashtbl.mli index 36efe93e..54163e81 100644 --- a/src/core/CCHashtbl.mli +++ b/src/core/CCHashtbl.mli @@ -92,6 +92,13 @@ val update : ('a, 'b) Hashtbl.t -> f:('a -> 'b option -> 'b option) -> k:'a -> u using {!Hashtbl.replace} @since 0.14 *) +val get_or_add : ('a, 'b) Hashtbl.t -> f:('a -> 'b) -> k:'a -> 'b +(** [get_or_add tbl ~k ~f] finds and returns the binding of [k] + in [tbl], if it exists. If it does not exist, then [f k] + is called to obtain a new binding [v]; [k -> v] is added + to [tbl] and [v] is returned. + @since NEXT_RELEASE *) + val print : 'a printer -> 'b printer -> ('a, 'b) Hashtbl.t printer (** Printer for table @since 0.13 *) @@ -181,6 +188,13 @@ module type S = sig using {!Hashtbl.replace} @since 0.14 *) + val get_or_add : 'a t -> f:(key -> 'a) -> k:key -> 'a + (** [get_or_add tbl ~k ~f] finds and returns the binding of [k] + in [tbl], if it exists. If it does not exist, then [f k] + is called to obtain a new binding [v]; [k -> v] is added + to [tbl] and [v] is returned. + @since NEXT_RELEASE *) + val print : key printer -> 'a printer -> 'a t printer (** Printer for tables @since 0.13 *)