mirror of
https://github.com/c-cube/ocaml-containers.git
synced 2025-12-05 19:00:31 -05:00
336 lines
8.7 KiB
OCaml
336 lines
8.7 KiB
OCaml
(* Reference implementations of some structures, for comparison purposes *)
|
|
|
|
(** Persistent hashtable built on top of an imperative [Hashtbl].

    Every version of the table is a mutable cell holding either the concrete
    hashtable or a single-binding difference relative to another version.
    Reading a version first "reroots" it, i.e. mutates the shared hashtable
    (and reverses the differences along the way) so that the version being
    read holds the concrete table. *)
module PersistentHashtbl (H : Hashtbl.HashedType) = struct
  module Table = Hashtbl.Make (H)
  (** Imperative hashtable *)

  type key = H.t

  type 'a t = 'a zipper ref

  and 'a zipper =
    | Table of 'a Table.t  (** Concrete table *)
    | Add of key * 'a * 'a t  (** Add key *)
    | Replace of key * 'a * 'a t  (** Replace key by value *)
    | Remove of key * 'a t  (** As the table, but without given key *)

  (** [create i] is a fresh empty table; [i] is passed to [Table.create]. *)
  let create i = ref (Table (Table.create i))

  (** [empty ()] is a fresh empty table with a default initial size. *)
  let empty () = create 11

  (* pass continuation to get a tailrec rerooting.
     In every non-[Table] case we first reroot the parent [t'], then turn the
     parent into the inverse difference pointing back at [t], apply the
     difference to the shared table, and make [t] the holder of the table. *)
  let rec _reroot t k =
    match !t with
    | Table tbl -> k tbl (* done *)
    | Add (key, v, t') ->
      _reroot t' (fun tbl ->
          t' := Remove (key, t);
          Table.add tbl key v;
          t := Table tbl;
          k tbl)
    | Replace (key, v, t') ->
      _reroot t' (fun tbl ->
          (* remember the parent's value so the change can be undone *)
          let v' = Table.find tbl key in
          t' := Replace (key, v', t);
          t := Table tbl;
          Table.replace tbl key v;
          k tbl)
    | Remove (key, t') ->
      _reroot t' (fun tbl ->
          (* remember the parent's value so the change can be undone *)
          let v = Table.find tbl key in
          t' := Add (key, v, t);
          t := Table tbl;
          Table.remove tbl key;
          k tbl)

  (* Reroot: modify the zipper so that the current node is a proper
     hashtable, and return the hashtable *)
  let reroot t =
    match !t with
    | Table tbl -> tbl
    | _ -> _reroot t (fun x -> x)

  (** [is_empty t] is [true] iff [t] has no binding. *)
  let is_empty t = Table.length (reroot t) = 0

  (** [find t k] is the value bound to [k] in [t].
      @raise Not_found if [k] is not bound in [t]. *)
  let find t k = Table.find (reroot t) k

  (*$R
    let h = H.of_seq my_seq in
    OUnit.assert_equal "a" (H.find h 1);
    OUnit.assert_raises Not_found (fun () -> H.find h 5);
    let h' = H.replace h 5 "e" in
    OUnit.assert_equal "a" (H.find h' 1);
    OUnit.assert_equal "e" (H.find h' 5);
    OUnit.assert_equal "a" (H.find h 1);
    OUnit.assert_raises Not_found (fun () -> H.find h 5);
  *)

  (*$R
    let n = 10000 in
    let seq = Sequence.map (fun i -> i, string_of_int i) Sequence.(0--n) in
    let h = H.of_seq seq in
    Sequence.iter
      (fun (k,v) ->
        OUnit.assert_equal ~printer:(fun x -> x) v (H.find h k))
      seq;
    OUnit.assert_raises Not_found (fun () -> H.find h (n+1));
  *)

  (*$QR
    _list_int_int
      (fun l ->
        let h = H.of_list l in
        List.for_all
          (fun (k,v) ->
            try
              H.find h k = v
            with Not_found -> false)
          l
      )
  *)

  (** [get_exn k t] is [find t k] with the arguments flipped.
      @raise Not_found if [k] is not bound in [t]. *)
  let get_exn k t = find t k

  (** [get k t] is [Some v] if [k] is bound to [v] in [t], [None] otherwise. *)
  let get k t = try Some (find t k) with Not_found -> None

  (** [mem t k] is [true] iff [k] is bound in [t]. *)
  let mem t k = Table.mem (reroot t) k

  (** [length t] is the number of bindings of [t]. *)
  let length t = Table.length (reroot t)

  (*$R
    let h = H.of_seq
      Sequence.(map (fun i -> i, string_of_int i)
        (0 -- 200)) in
    OUnit.assert_equal 201 (H.length h);
  *)

  (*$QR
    _list_int_int (fun l ->
      let h = H.of_list l in
      H.length h = List.length l
    )
  *)

  (** [replace t k v] is a new version in which [k] is bound to [v].
      [t] stays valid and still denotes its previous content: it is turned
      into a difference relative to the returned version. *)
  let replace t k v =
    let tbl = reroot t in
    (* create the new hashtable *)
    let t' = ref (Table tbl) in
    (* update [t] to point to the new hashtable *)
    (try
       let v' = Table.find tbl k in
       t := Replace (k, v', t')
     with Not_found -> t := Remove (k, t'));
    (* modify the underlying hashtable *)
    Table.replace tbl k v;
    t'

  (** [remove t k] is a version without any binding for [k]. If [k] is not
      bound in [t], [t] itself is returned. *)
  let remove t k =
    let tbl = reroot t in
    try
      let v' = Table.find tbl k in
      (* value present, make a new hashtable without this value *)
      let t' = ref (Table tbl) in
      t := Add (k, v', t');
      Table.remove tbl k;
      t'
    with Not_found ->
      (* not member, nothing to do *)
      t

  (*$R
    let h = H.of_seq my_seq in
    OUnit.assert_equal (H.find h 2) "b";
    OUnit.assert_equal (H.find h 3) "c";
    OUnit.assert_equal (H.find h 4) "d";
    OUnit.assert_equal (H.length h) 4;
    let h = H.remove h 2 in
    OUnit.assert_equal (H.find h 3) "c";
    OUnit.assert_equal (H.length h) 3;
    OUnit.assert_raises Not_found (fun () -> H.find h 2)
  *)

  (*$R
    let open Sequence.Infix in
    let n = 10000 in
    let seq = Sequence.map (fun i -> i, string_of_int i) (0 -- n) in
    let h = H.of_seq seq in
    OUnit.assert_equal (n+1) (H.length h);
    let h = Sequence.fold (fun h i -> H.remove h i) h (0 -- 500) in
    OUnit.assert_equal (n-500) (H.length h);
    OUnit.assert_bool "is_empty" (H.is_empty (H.create 16));
  *)

  (*$QR
    _list_int_int (fun l ->
      let h = H.of_list l in
      let h = List.fold_left (fun h (k,_) -> H.remove h k) h l in
      H.is_empty h)
  *)

  (** [update t k f] calls [f] on the current binding of [k] in [t] (as an
      option). A result of [None] removes the binding, [Some v'] binds [k]
      to [v']. When [k] is unbound and [f] returns [None], [t] is returned. *)
  let update t k f =
    let v = get k t in
    match v, f v with
    | None, None -> t (* no change *)
    | Some _, None -> remove t k
    | _, Some v' -> replace t k v'

  (** [copy t] is a version with the same bindings as [t], backed by a fresh
      copy of the underlying hashtable. *)
  let copy t =
    let tbl = reroot t in
    (* no one will point to the new [t] *)
    ref (Table (Table.copy tbl))

  (** [iter t f] calls [f k v] on every binding of [t].
      The traversal runs directly over the shared hashtable, so [f] should
      not read or modify another version derived from the same table:
      rerooting that version mutates the table being traversed. *)
  let iter t f =
    let tbl = reroot t in
    Table.iter f tbl

  (** [fold f acc t] folds [f acc k v] over the bindings of [t]. The same
      restriction on [f] as for {!iter} applies. *)
  let fold f acc t =
    let tbl = reroot t in
    Table.fold (fun k v acc -> f acc k v) tbl acc

  (** [map f t] is a fresh, independent table binding each key [k] of [t]
      to [f k v]. *)
  let map f t =
    let tbl = reroot t in
    let res = Table.create (Table.length tbl) in
    Table.iter (fun k v -> Table.replace res k (f k v)) tbl;
    ref (Table res)

  (** [filter p t] is a fresh, independent table holding the bindings of [t]
      that satisfy [p]. *)
  let filter p t =
    let tbl = reroot t in
    let res = Table.create (Table.length tbl) in
    Table.iter (fun k v -> if p k v then Table.replace res k v) tbl;
    ref (Table res)

  (** [filter_map f t] is a fresh, independent table binding [k] to [v']
      for every binding [(k, v)] of [t] such that [f k v = Some v']. *)
  let filter_map f t =
    let tbl = reroot t in
    let res = Table.create (Table.length tbl) in
    Table.iter
      (fun k v ->
        match f k v with
        | None -> ()
        | Some v' -> Table.replace res k v')
      tbl;
    ref (Table res)

  (* used by [for_all] and [exists] to leave the iteration early *)
  exception ExitPTbl

  (** [for_all p t] is [true] iff every binding of [t] satisfies [p]. *)
  let for_all p t =
    try
      iter t (fun k v -> if not (p k v) then raise ExitPTbl);
      true
    with ExitPTbl -> false

  (** [exists p t] is [true] iff at least one binding of [t] satisfies [p]. *)
  let exists p t =
    try
      iter t (fun k v -> if p k v then raise ExitPTbl);
      false
    with ExitPTbl -> true

  (** [merge f t1 t2] is a fresh, independent table. For every key [k] bound
      in [t1] or [t2], [f k v1 v2] is called with the optional values of [k]
      in [t1] and [t2]; if it returns [Some v'], [k] is bound to [v'] in the
      result, otherwise [k] is left out. *)
  let merge f t1 t2 =
    let tbl = Table.create (max (length t1) (length t2)) in
    (* Snapshot both sides before doing any lookup: [t1] and [t2] may be
       versions of the same underlying hashtable, in which case looking a key
       up in one of them reroots (mutates) the table the other one would
       still be iterating over. [List.rev] keeps the traversal order. *)
    let bindings t = List.rev (fold (fun acc k v -> (k, v) :: acc) [] t) in
    let l1 = bindings t1 in
    let l2 = bindings t2 in
    List.iter
      (fun (k, v1) ->
        match f k (Some v1) (get k t2) with
        | None -> ()
        | Some v' -> Table.replace tbl k v')
      l1;
    List.iter
      (fun (k, v2) ->
        if not (mem t1 k) then (
          match f k None (Some v2) with
          | None -> ()
          (* store what [f] returned, not the raw value from [t2] *)
          | Some v' -> Table.replace tbl k v'))
      l2;
    ref (Table tbl)

  (*$R
    let t1 = H.of_list [1, "a"; 2, "b1"] in
    let t2 = H.of_list [2, "b2"; 3, "c"] in
    let t = H.merge
      (fun _ v1 v2 -> match v1, v2 with
        | None, _ -> v2
        | _ , None -> v1
        | Some s1, Some s2 -> if s1 < s2 then Some s1 else Some s2)
      t1 t2
    in
    OUnit.assert_equal ~printer:string_of_int 3 (H.length t);
    OUnit.assert_equal "a" (H.find t 1);
    OUnit.assert_equal "b1" (H.find t 2);
    OUnit.assert_equal "c" (H.find t 3);
  *)

  (** [add_seq init seq] is [init] extended with every [(k, v)] produced by
      the iterator [seq], later pairs overriding earlier ones. *)
  let add_seq init seq =
    let tbl = ref init in
    seq (fun (k, v) -> tbl := replace !tbl k v);
    !tbl

  (** [of_seq seq] is a table holding the pairs produced by [seq]. *)
  let of_seq seq = add_seq (empty ()) seq

  (** [add_list init l] is [init] extended with the pairs of [l], in order. *)
  let add_list init l = add_seq init (fun k -> List.iter k l)

  (*$QR
    _list_int_int (fun l ->
      let l1, l2 = List.partition (fun (x,_) -> x mod 2 = 0) l in
      let h1 = H.of_list l1 in
      let h2 = H.add_list h1 l2 in
      List.for_all
        (fun (k,v) -> H.find h2 k = v)
        l
      &&
      List.for_all
        (fun (k,v) -> H.find h1 k = v)
        l1
      &&
      List.length l1 = H.length h1
      &&
      List.length l = H.length h2
    )
  *)

  (** [of_list l] is a table holding the pairs of [l]. *)
  let of_list l = add_list (empty ()) l

  (** [to_list t] is the list of bindings of [t], in unspecified order. *)
  let to_list t =
    let tbl = reroot t in
    Table.fold (fun k v acc -> (k, v) :: acc) tbl []

  (*$R
    let h = H.of_seq my_seq in
    let l = Sequence.to_list (H.to_seq h) in
    OUnit.assert_equal my_list (List.sort compare l)
  *)

  (** [to_seq t k] calls [k (key, value)] on every binding of [t]. The same
      restriction on [k] as for the callback of {!iter} applies. *)
  let to_seq t k =
    let tbl = reroot t in
    Table.iter (fun x y -> k (x, y)) tbl

  (*$R
    let h = H.of_seq my_seq in
    OUnit.assert_equal "b" (H.find h 2);
    OUnit.assert_equal "a" (H.find h 1);
    OUnit.assert_raises Not_found (fun () -> H.find h 42);
  *)

  (** [equal eq t1 t2] is [true] iff [t1] and [t2] bind the same keys and
      the values bound to each key are related by [eq]. *)
  let equal eq t1 t2 =
    length t1 = length t2
    &&
    (* Iterate over a snapshot of [t1]: looking keys up in [t2] may reroot
       the hashtable that [t1] shares with it. *)
    List.for_all
      (fun (k, v) ->
        match get k t2 with
        | None -> false
        | Some v' -> eq v v')
      (to_list t1)

  (** [pp pp_k pp_v buf t] appends a textual rendering of [t] to [buf],
      using [pp_k] and [pp_v] to print keys and values. *)
  let pp pp_k pp_v buf t =
    Buffer.add_string buf "{";
    let first = ref true in
    iter t (fun k v ->
        (* separator before every binding but the first *)
        if !first then first := false else Buffer.add_string buf ", ";
        Printf.bprintf buf "%a -> %a" pp_k k pp_v v);
    Buffer.add_string buf "}"

  (** [print pp_k pp_v fmt t] pretty-prints [t] on the formatter [fmt],
      using [pp_k] and [pp_v] to print keys and values. *)
  let print pp_k pp_v fmt t =
    Format.pp_print_string fmt "{";
    let first = ref true in
    iter t (fun k v ->
        (* separator (with a break hint) before every binding but the first *)
        if !first then first := false
        else (
          Format.pp_print_string fmt ", ";
          Format.pp_print_cut fmt ());
        Format.fprintf fmt "%a -> %a" pp_k k pp_v v);
    Format.pp_print_string fmt "}"
end
|