From 0c04df58b011de710ee621466fb4043750770954 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Fri, 11 Mar 2016 21:16:34 +0100 Subject: [PATCH] update CCHet to not use Obj.magic; add test --- src/data/CCHet.ml | 124 ++++++++++++++++++++++++++++++++++------------ 1 file changed, 91 insertions(+), 33 deletions(-) diff --git a/src/data/CCHet.ml b/src/data/CCHet.ml index ed46f1d3..ff86f672 100644 --- a/src/data/CCHet.ml +++ b/src/data/CCHet.ml @@ -3,37 +3,80 @@ (** {1 Associative containers with Heterogenerous Values} *) +(*$R + let k1 : int Key.t = Key.create() in + let k2 : int Key.t = Key.create() in + let k3 : string Key.t = Key.create() in + let k4 : float Key.t = Key.create() in + + let tbl = Tbl.create () in + + Tbl.add tbl k1 1; + Tbl.add tbl k2 2; + Tbl.add tbl k3 "k3"; + + assert_equal (Some 1) (Tbl.find tbl k1); + assert_equal (Some 2) (Tbl.find tbl k2); + assert_equal (Some "k3") (Tbl.find tbl k3); + assert_equal None (Tbl.find tbl k4); + assert_equal 3 (Tbl.length tbl); + + Tbl.add tbl k1 10; + assert_equal (Some 10) (Tbl.find tbl k1); + assert_equal 3 (Tbl.length tbl); + assert_equal None (Tbl.find tbl k4); + + Tbl.add tbl k4 0.0; + assert_equal (Some 0.0) (Tbl.find tbl k4); + + () + + +*) + type 'a sequence = ('a -> unit) -> unit type 'a gen = unit -> 'a option +module type KEY_IMPL = sig + type t + exception Store of t + val id : int +end + module Key = struct - type 'a t = int + type 'a t = (module KEY_IMPL with type t = 'a) - let create = - let _n = ref 0 in - fun () -> - incr _n; - !_n + let _n = ref 0 - let id a = a + let create (type k) () = + incr _n; + let id = !_n in + let module K = struct + type t = k + let id = id + exception Store of k + end in + (module K : KEY_IMPL with type t = k) + + let id (type k) (module K : KEY_IMPL with type t = k) = K.id let equal : type a b. a t -> b t -> bool - = fun a b -> - let ia = (a : a t :> int) in - let ib = (b : b t :> int) in - ia=ib - - (* XXX: the only ugly part *) - (* [cast_res k1 k2 v2] casts [v2] into a value of type [a] if [k1=k2] *) - let cast_res_ : type a b. a t -> b t -> b -> a - = fun k1 k2 v2 -> - if k1=k2 then Obj.magic v2 else raise Not_found + = fun (module K1) (module K2) -> K1.id = K2.id end type pair = | Pair : 'a Key.t * 'a -> pair +type exn_pair = + | E_pair : 'a Key.t * exn -> exn_pair + +let pair_of_e_pair (E_pair (k,e)) = + let module K = (val k) in + match e with + | K.Store v -> Pair (k,v) + | _ -> assert false + module Tbl = struct module M = Hashtbl.Make(struct type t = int @@ -41,33 +84,38 @@ module Tbl = struct let hash (i:int) = Hashtbl.hash i end) - type t = pair M.t + type t = exn_pair M.t let create ?(size=16) () = M.create size let mem t k = M.mem t (Key.id k) let find_exn (type a) t (k : a Key.t) : a = - let Pair (k', v) = M.find t (Key.id k) in - Key.cast_res_ k k' v + let module K = (val k) in + let E_pair (_, v) = M.find t K.id in + match v with + | K.Store v -> v + | _ -> assert false let find t k = try Some (find_exn t k) with Not_found -> None let add_pair_ t p = - let Pair (k,_) = p in - M.replace t (Key.id k) p + let Pair (k,v) = p in + let module K = (val k) in + let p = E_pair (k, K.Store v) in + M.replace t K.id p let add t k v = add_pair_ t (Pair (k,v)) let length t = M.length t - let iter f t = M.iter (fun _ pair -> f pair) t + let iter f t = M.iter (fun _ pair -> f (pair_of_e_pair pair)) t let to_seq t yield = iter yield t - let to_list t = M.fold (fun _ p l -> p::l) t [] + let to_list t = M.fold (fun _ p l -> pair_of_e_pair p::l) t [] let add_list t l = List.iter (add_pair_ t) l @@ -90,35 +138,45 @@ module Map = struct let compare (i:int) j = Pervasives.compare i j end) - type t = pair M.t + type t = exn_pair M.t let empty = M.empty let mem k t = M.mem (Key.id k) t let find_exn (type a) (k : a Key.t) t : a = - let Pair (k', v) = M.find (Key.id k) t in - Key.cast_res_ k k' v + let module K = (val k) in + let E_pair (_, e) = M.find K.id t in + match e with + | K.Store v -> v + | _ -> assert false let find k t = try Some (find_exn k t) with Not_found -> None - let add_pair_ p t = - let Pair (k,_) = p in - M.add (Key.id k) p t + let add_e_pair_ p t = + let E_pair ((module K),_) = p in + M.add K.id p t - let add k v t = add_pair_ (Pair (k,v)) t + let add_pair_ p t = + let Pair ((module K) as k,v) = p in + let p = E_pair (k, K.Store v) in + M.add K.id p t + + let add (type a) (k : a Key.t) v t = + let module K = (val k) in + add_e_pair_ (E_pair (k, K.Store v)) t let cardinal t = M.cardinal t let length = cardinal - let iter f t = M.iter (fun _ pair -> f pair) t + let iter f t = M.iter (fun _ p -> f (pair_of_e_pair p)) t let to_seq t yield = iter yield t - let to_list t = M.fold (fun _ p l -> p::l) t [] + let to_list t = M.fold (fun _ p l -> pair_of_e_pair p::l) t [] let add_list t l = List.fold_right add_pair_ l t