cleaner system to specify hash/eq/cmp for operations in CCLinq;

use Map to implement most binary operations, including join
This commit is contained in:
Simon Cruanes 2014-06-14 02:16:49 +02:00
parent 2492ee48a6
commit 4550a1c2c2
2 changed files with 265 additions and 189 deletions

View file

@ -31,12 +31,138 @@ type 'a equal = 'a -> 'a -> bool
type 'a ord = 'a -> 'a -> int
type 'a hash = 'a -> int
(* TODO: add CCVector as a collection *)
let _id x = x
module Map = struct
type ('a, 'b) t = {
is_empty : unit -> bool;
size : unit -> int; (** Number of keys *)
get : 'a -> 'b option;
fold : 'c. ('c -> 'a -> 'b -> 'c) -> 'c -> 'c;
to_seq : ('a * 'b) sequence;
}
let get m x = m.get x
let mem m x = match m.get x with
| None -> false
| Some _ -> true
let to_seq m = m.to_seq
let size m = m.size ()
type ('a, 'b) build = {
mutable cur : ('a, 'b) t;
add : 'a -> 'b -> unit;
update : 'a -> ('b option -> 'b option) -> unit;
}
let build_get b = b.cur
let add b x y = b.add x y
let update b f = b.update f
(* careful to use this map linearly *)
let make_hash (type key) ?(eq=(=)) ?(hash=Hashtbl.hash) () =
let module H = Hashtbl.Make(struct
type t = key
let equal = eq
let hash = hash
end) in
(* build table *)
let tbl = H.create 32 in
let cur = {
is_empty = (fun () -> H.length tbl = 0);
size = (fun () -> H.length tbl);
get = (fun k ->
try Some (H.find tbl k)
with Not_found -> None);
fold = (fun f acc -> H.fold (fun k v acc -> f acc k v) tbl acc);
to_seq = (fun k -> H.iter (fun key v -> k (key,v)) tbl);
} in
{ cur;
add = (fun k v -> H.replace tbl k v);
update = (fun k f ->
match (try f (Some (H.find tbl k)) with Not_found -> f None) with
| None -> H.remove tbl k
| Some v' -> H.replace tbl k v');
}
let make_cmp (type key) ?(cmp=Pervasives.compare) () =
let module M = CCSequence.Map.Make(struct
type t = key
let compare = cmp
end) in
let map = ref M.empty in
let cur = {
is_empty = (fun () -> M.is_empty !map);
size = (fun () -> M.cardinal !map);
get = (fun k ->
try Some (M.find k !map)
with Not_found -> None);
fold = (fun f acc ->
M.fold
(fun key set acc -> f acc key set) !map acc
);
to_seq = (fun k -> M.to_seq !map k);
} in
{
cur;
add = (fun k v -> map := M.add k v !map);
update = (fun k f ->
match (try f (Some (M.find k !map)) with Not_found -> f None) with
| None -> map := M.remove k !map
| Some v' -> map := M.add k v' !map);
}
type 'a build_method =
| FromCmp of 'a ord
| FromHash of 'a equal * 'a hash
| Default
let make ?(build=Default) () = match build with
| Default -> make_hash ()
| FromCmp cmp -> make_cmp ~cmp ()
| FromHash (eq,hash) -> make_hash ~eq ~hash ()
let multimap_of_seq ?(build=make ()) seq =
seq (fun (k,v) ->
build.update k (function
| None -> Some [v]
| Some l -> Some (v::l)));
build.cur
let count_of_seq ?(build=make ()) seq =
seq (fun x ->
build.update x
(function
| None -> Some 1
| Some n -> Some (n+1)));
build.cur
let get_exn m x =
match m.get x with
| None -> raise Not_found
| Some x -> x
let to_list m = m.to_seq |> CCSequence.to_rev_list
end
type 'a search_result =
| SearchContinue
| SearchStop of 'a
type ('a,'b,'key,'c) join_descr = {
join_key1 : 'a -> 'key;
join_key2 : 'b -> 'key;
join_merge : 'key -> 'a -> 'b -> 'c option;
join_build : 'key Map.build_method;
}
type ('a,'b) group_join_descr = {
gjoin_proj : 'b -> 'a;
gjoin_build : 'a Map.build_method;
}
module Coll = struct
type 'a t =
| Seq : 'a sequence -> 'a t
@ -168,131 +294,87 @@ module Coll = struct
assert (eq x y);
true
with Not_found -> false
let do_join ~join c1 c2 =
let build1 =
to_seq c1
|> CCSequence.map (fun x -> join.join_key1 x, x)
|> Map.multimap_of_seq ~build:(Map.make ~build:join.join_build ())
in
let l = CCSequence.fold
(fun acc y ->
let key = join.join_key2 y in
match Map.get build1 key with
| None -> acc
| Some l1 ->
List.fold_left
(fun acc x -> match join.join_merge key x y with
| None -> acc
| Some res -> res::acc
) acc l1
) [] (to_seq c2)
in
of_list l
let do_group_join ~gjoin c1 c2 =
let build = Map.make ~build:gjoin.gjoin_build () in
to_seq c1 (fun x -> Map.add build x []);
to_seq c2
(fun y ->
(* project [y] into some element of [c1] *)
let x = gjoin.gjoin_proj y in
Map.update build x
(function
| None -> None (* [x] not present, ignore! *)
| Some l -> Some (y::l)
)
);
Map.build_get build
let do_product c1 c2 =
let s1 = to_seq c1 and s2 = to_seq c2 in
of_seq (CCSequence.product s1 s2)
let do_union ~build c1 c2 =
let build = Map.make ~build () in
to_seq c1 (fun x -> Map.add build x ());
to_seq c2 (fun x -> Map.add build x ());
Map.to_seq (Map.build_get build)
|> CCSequence.map fst
|> of_seq
type inter_status =
| InterLeft
| InterDone (* already output *)
let do_inter ~build c1 c2 =
let build = Map.make ~build () in
let l = ref [] in
to_seq c1 (fun x -> Map.add build x InterLeft);
to_seq c2 (fun x ->
Map.update build x
(function
| None -> Some InterDone
| Some InterDone as foo -> foo
| Some InterLeft ->
l := x :: !l;
Some InterDone
)
);
of_list !l
let do_diff ~build c1 c2 =
let build = Map.make ~build () in
to_seq c2 (fun x -> Map.add build x ());
let map = Map.build_get build in
(* output elements of [c1] not in [map] *)
to_seq c1
|> CCSequence.filter (fun x -> not (Map.mem map x))
|> of_seq
end
type 'a collection = 'a Coll.t
module Map = struct
type ('a, 'b) t = {
is_empty : unit -> bool;
size : unit -> int; (** Number of keys *)
get : 'a -> 'b option;
fold : 'c. ('c -> 'a -> 'b -> 'c) -> 'c -> 'c;
to_seq : ('a * 'b) sequence;
}
type ('a, 'b) build = {
mutable cur : ('a, 'b) t;
add : 'a -> 'b -> unit;
update : 'a -> ('b option -> 'b option) -> unit;
}
(* careful to use this map linearly *)
let make_hash (type key) ?(eq=(=)) ?(hash=Hashtbl.hash) () =
let module H = Hashtbl.Make(struct
type t = key
let equal = eq
let hash = hash
end) in
(* build table *)
let tbl = H.create 32 in
let cur = {
is_empty = (fun () -> H.length tbl = 0);
size = (fun () -> H.length tbl);
get = (fun k ->
try Some (H.find tbl k)
with Not_found -> None);
fold = (fun f acc -> H.fold (fun k v acc -> f acc k v) tbl acc);
to_seq = (fun k -> H.iter (fun key v -> k (key,v)) tbl);
} in
{ cur;
add = (fun k v -> H.replace tbl k v);
update = (fun k f ->
match (try f (Some (H.find tbl k)) with Not_found -> f None) with
| None -> H.remove tbl k
| Some v' -> H.replace tbl k v');
}
let make_cmp (type key) ?(cmp=Pervasives.compare) () =
let module M = CCSequence.Map.Make(struct
type t = key
let compare = cmp
end) in
let map = ref M.empty in
let cur = {
is_empty = (fun () -> M.is_empty !map);
size = (fun () -> M.cardinal !map);
get = (fun k ->
try Some (M.find k !map)
with Not_found -> None);
fold = (fun f acc ->
M.fold
(fun key set acc -> f acc key set) !map acc
);
to_seq = (fun k -> M.to_seq !map k);
} in
{
cur;
add = (fun k v -> map := M.add k v !map);
update = (fun k f ->
match (try f (Some (M.find k !map)) with Not_found -> f None) with
| None -> map := M.remove k !map
| Some v' -> map := M.add k v' !map);
}
type 'a key_info = {
eq : 'a equal option;
cmp : 'a ord option;
hash : 'a hash option;
}
let default_key_info = {
eq=None; cmp=None; hash=None;
}
let make_info info =
match info with
| { hash=None; _}
| { eq=None; _} ->
begin match info.cmp with
| None -> make_cmp ()
| Some cmp -> make_cmp ~cmp ()
end
| {eq=Some eq; hash=Some hash; _} -> make_hash ~eq ~hash ()
let multimap build seq =
seq (fun (k,v) ->
build.update k (function
| None -> Some [v]
| Some l -> Some (v::l)));
build.cur
let multimap_cmp ?cmp seq =
let build = make_cmp ?cmp () in
multimap build seq
let count build seq =
seq (fun x ->
build.update x
(function
| None -> Some 1
| Some n -> Some (n+1)));
build.cur
let get m x = m.get x
let get_exn m x =
match m.get x with
| None -> raise Not_found
| Some x -> x
let size m = m.size ()
let to_seq m = m.to_seq
let to_list m = m.to_seq |> CCSequence.to_rev_list
end
(** {2 Query operators} *)
type (_,_) safety =
@ -320,21 +402,9 @@ type (_, _) unary =
> -> ('a collection, 'b) unary
| Contains : 'a equal * 'a -> ('a collection, bool) unary
| Get : ('b,'c) safety * 'a -> (('a,'b) Map.t, 'c) unary
| GroupBy : 'b ord * ('a -> 'b)
| GroupBy : 'b Map.build_method * ('a -> 'b)
-> ('a collection, ('b,'a list) Map.t) unary
| Count : 'a ord -> ('a collection, ('a, int) Map.t) unary
type ('a,'b,'key,'c) join_descr = {
join_key1 : 'a -> 'key;
join_key2 : 'b -> 'key;
join_merge : 'key -> 'a -> 'b -> 'c option;
join_key : 'key Map.key_info;
}
type ('a,'b) group_join_descr = {
gjoin_proj : 'b -> 'a;
gjoin_key : 'a Map.key_info;
}
| Count : 'a Map.build_method -> ('a collection, ('a, int) Map.t) unary
type set_op =
| Union
@ -345,10 +415,11 @@ type (_, _, _) binary =
| Join : ('a, 'b, 'key, 'c) join_descr
-> ('a collection, 'b collection, 'c collection) binary
| GroupJoin : ('a, 'b) group_join_descr
-> ('a collection, 'b collection, ('a, 'b collection) Map.t) binary
-> ('a collection, 'b collection, ('a, 'b list) Map.t) binary
| Product : ('a collection, 'b collection, ('a*'b) collection) binary
| Append : ('a collection, 'a collection, 'a collection) binary
| SetOp : set_op * 'a ord -> ('a collection, 'a collection, 'a collection) binary
| SetOp : set_op * 'a Map.build_method
-> ('a collection, 'a collection, 'a collection) binary
(* type of queries that return a 'a *)
and 'a t =
@ -457,46 +528,25 @@ let _do_unary : type a b. (a,b) unary -> a -> b
| Search obj -> Coll.search obj c
| Get (Safe, k) -> Map.get c k
| Get (Unsafe, k) -> Map.get_exn c k
| GroupBy (cmp,f) ->
| GroupBy (build,f) ->
Coll.to_seq c
|> CCSequence.map (fun x -> f x, x)
|> Map.multimap_cmp ~cmp
|> Map.multimap_of_seq ~build:(Map.make ~build ())
| Contains (eq, x) -> Coll.contains ~eq x c
| Count cmp ->
| Count build ->
Coll.to_seq c
|> Map.count (Map.make_cmp ~cmp ())
(* TODO: join of two collections *)
let _do_join ~join c1 c2 =
let _build = Map.make_info join.join_key in
assert false
(* TODO *)
let _do_group_join ~gjoin c1 c2 =
assert false
let _do_product c1 c2 =
let s1 = Coll.to_seq c1 and s2 = Coll.to_seq c2 in
Coll.of_seq (CCSequence.product s1 s2)
|> Map.count_of_seq ~build:(Map.make ~build ())
let _do_binary : type a b c. (a, b, c) binary -> a -> b -> c
= fun b c1 c2 -> match b with
| Join join -> _do_join ~join c1 c2
| GroupJoin gjoin -> _do_group_join ~gjoin c1 c2
| Product -> _do_product c1 c2
| Join join -> Coll.do_join ~join c1 c2
| GroupJoin gjoin -> Coll.do_group_join ~gjoin c1 c2
| Product -> Coll.do_product c1 c2
| Append ->
Coll.of_seq (CCSequence.append (Coll.to_seq c1) (Coll.to_seq c2))
| SetOp (Inter,cmp) ->
(* use a join *)
_do_join ~join:{
join_key1=_id;
join_key2=_id;
join_merge=(fun _ a b -> Some a);
join_key=Map.({default_key_info with cmp=Some cmp; })
} c1 c2
| SetOp (Union,cmp) -> failwith "union: not implemented" (* TODO *)
| SetOp (Diff,cmp) -> failwith "diff: not implemented" (* TODO *)
| SetOp (Inter,build) -> Coll.do_inter ~build c1 c2
| SetOp (Union,build) -> Coll.do_union ~build c1 c2
| SetOp (Diff,build) -> Coll.do_diff ~build c1 c2
let rec _run : type a. opt:bool -> a t -> a
= fun ~opt q -> match q with
@ -564,14 +614,29 @@ let map_iter_flatten q =
let map_to_list q =
Unary (GeneralMap Map.to_list, q)
let group_by ?(cmp=Pervasives.compare) f q =
Unary (GroupBy (cmp,f), q)
(* choose a build method from the optional arguments *)
let _make_build ?cmp ?eq ?hash () =
let _maybe default o = match o with
| Some x -> x
| None -> default
in
match eq, hash with
| Some _, _
| _, Some _ ->
Map.FromHash ( _maybe (=) eq, _maybe Hashtbl.hash hash)
| _ ->
match cmp with
| Some f -> Map.FromCmp f
| _ -> Map.Default
let group_by' ?cmp f q =
let group_by ?cmp ?eq ?hash f q =
Unary (GroupBy (_make_build ?cmp ?eq ?hash (),f), q)
let group_by' ?cmp ?eq ?hash f q =
map_iter (group_by ?cmp f q)
let count ?(cmp=Pervasives.compare) () q =
Unary (Count cmp, q)
let count ?cmp ?eq ?hash () q =
Unary (Count (_make_build ?cmp ?eq ?hash ()), q)
let count' ?cmp () q =
map_iter (count ?cmp () q)
@ -643,18 +708,20 @@ let find_map f q =
(** {6 Binary Operators} *)
let join ?cmp ?eq ?hash join_key1 join_key2 ~merge q1 q2 =
let join_build = _make_build ?eq ?hash ?cmp () in
let j = {
join_key1;
join_key2;
join_merge=merge;
join_key = Map.({ eq; cmp; hash; });
join_build;
} in
Binary (Join j, q1, q2)
let group_join ?cmp ?eq ?hash gjoin_proj q1 q2 =
let gjoin_build = _make_build ?eq ?hash ?cmp () in
let j = {
gjoin_proj;
gjoin_key = Map.({ eq; cmp; hash; });
gjoin_build;
} in
Binary (GroupJoin j, q1, q2)
@ -662,14 +729,17 @@ let product q1 q2 = Binary (Product, q1, q2)
let append q1 q2 = Binary (Append, q1, q2)
let inter ?(cmp=Pervasives.compare) () q1 q2 =
Binary (SetOp (Inter, cmp), q1, q2)
let inter ?cmp ?eq ?hash () q1 q2 =
let build = _make_build ?cmp ?eq ?hash () in
Binary (SetOp (Inter, build), q1, q2)
let union ?(cmp=Pervasives.compare) () q1 q2 =
Binary (SetOp (Union, cmp), q1, q2)
let union ?cmp ?eq ?hash () q1 q2 =
let build = _make_build ?cmp ?eq ?hash () in
Binary (SetOp (Union, build), q1, q2)
let diff ?(cmp=Pervasives.compare) () q1 q2 =
Binary (SetOp (Diff, cmp), q1, q2)
let diff ?cmp ?eq ?hash () q1 q2 =
let build = _make_build ?cmp ?eq ?hash () in
Binary (SetOp (Diff, build), q1, q2)
let fst q = map fst q
let snd q = map snd q

View file

@ -27,7 +27,12 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
(** {1 LINQ-like operations on collections}
The purpose it to provide powerful combinators to express iteration,
transformation and combination of collections of items.
transformation and combination of collections of items. This module depends
on several other modules, including {!CCList} and {!CCSequence}.
Functions and operations are assumed to be referentially transparent, i.e.
they should not rely on external side effects, they should not rely on
the order of execution.
{[
@ -161,17 +166,18 @@ val map_to_list : ('a,'b) Map.t t -> ('a*'b) list t
(** {6 Aggregation} *)
val group_by : ?cmp:'b ord ->
val group_by : ?cmp:'b ord -> ?eq:'b equal -> ?hash:'b hash ->
('a -> 'b) -> 'a collection t -> ('b,'a list) Map.t t
(** [group_by f] takes a collection [c] as input, and returns
a multimap [m] such that for each [x] in [c],
[x] occurs in [m] under the key [f x]. In other words, [f] is used
to obtain a key from [x], and [x] is added to the multimap using this key. *)
val group_by' : ?cmp:'b ord ->
val group_by' : ?cmp:'b ord -> ?eq:'b equal -> ?hash:'b hash ->
('a -> 'b) -> 'a collection t -> ('b * 'a list) collection t
val count : ?cmp:'a ord -> unit -> 'a collection t -> ('a, int) Map.t t
val count : ?cmp:'a ord -> ?eq:'a equal -> ?hash:'a hash ->
unit -> 'a collection t -> ('a, int) Map.t t
(** [count c] returns a map from elements of [c] to the number
of time those elements occur. *)
@ -228,7 +234,7 @@ val join : ?cmp:'key ord -> ?eq:'key equal -> ?hash:'key hash ->
val group_join : ?cmp:'a ord -> ?eq:'a equal -> ?hash:'a hash ->
('b -> 'a) -> 'a collection t -> 'b collection t ->
('a, 'b collection) Map.t t
('a, 'b list) Map.t t
(** [group_join key2] associates to every element [x] of
the first collection, all the elements [y] of the second
collection such that [eq x (key y)] *)
@ -239,17 +245,17 @@ val product : 'a collection t -> 'b collection t -> ('a * 'b) collection t
val append : 'a collection t -> 'a collection t -> 'a collection t
(** Append two collections together *)
val inter : ?cmp:'a ord -> unit ->
val inter : ?cmp:'a ord -> ?eq:'a equal -> ?hash:'a hash -> unit ->
'a collection t -> 'a collection t -> 'a collection t
(** Intersection of two collections. Each element will occur at most once
in the result *)
val union : ?cmp:'a ord -> unit ->
val union : ?cmp:'a ord -> ?eq:'a equal -> ?hash:'a hash -> unit ->
'a collection t -> 'a collection t -> 'a collection t
(** Union of two collections. Each element will occur at most once
in the result *)
val diff : ?cmp:'a ord -> unit ->
val diff : ?cmp:'a ord -> ?eq:'a equal -> ?hash:'a hash -> unit ->
'a collection t -> 'a collection t -> 'a collection t
(** Set difference *)