add {join_by,join_all_by,group_join_by}

This commit is contained in:
Simon Cruanes 2017-02-02 21:54:33 +01:00
parent a41a997c8f
commit 6c7d59042a
2 changed files with 139 additions and 0 deletions

View file

@ -15,6 +15,9 @@ type (+'a, +'b) t2 = ('a -> 'b -> unit) -> unit
let pp_ilist = Q.Print.(list int) let pp_ilist = Q.Print.(list int)
*) *)
(** Binary predicate over ['a]; the name indicates it is meant to be used
    as an equality test. *)
type 'a equal = 'a -> 'a -> bool
(** Function mapping a value of type ['a] to an integer; the name indicates
    it is meant to be used as a hash function. *)
type 'a hash = 'a -> int
(** Build a sequence from a iter function *) (** Build a sequence from a iter function *)
let from_iter f = f let from_iter f = f
@ -521,6 +524,96 @@ let join ~join_row s1 s2 k =
OUnit.assert_equal ["1 = 1"; "2 = 2"] (to_list s); OUnit.assert_equal ["1 = 1"; "2 = 2"] (to_list s);
*) *)
(** [join_by ?eq ?hash f1 f2 ~merge c1 c2] combines the elements of [c1]
    and [c2] that share a key.

    Every element [x] of [c1] is first indexed under [f1 x] in a hash table
    built from [eq] (default: structural equality) and [hash] (default:
    [Hashtbl.hash]); several elements may live under the same key. Then,
    for every element [y] of [c2], each indexed [x] whose key matches
    [f2 y] is passed to [merge (f2 y) x y]: a [Some z] answer is recorded,
    a [None] answer is dropped.

    Both [c1] and [c2] are traversed exactly once, as soon as they are
    supplied. The returned sequence only replays the recorded values, most
    recently produced first. *)
let join_by (type a) ?(eq=(=)) ?(hash=Hashtbl.hash) f1 f2 ~merge c1 c2 =
  let module Key = struct
    type t = a
    let equal = eq
    let hash = hash
  end in
  let module Index = Hashtbl.Make(Key) in
  let index = Index.create 32 in
  (* [add], not [replace]: elements of [c1] with equal keys must all stay. *)
  c1 (fun left -> Index.add index (f1 left) left);
  let acc = ref [] in
  let record = function
    | None -> ()
    | Some merged -> acc := merged :: !acc
  in
  c2
    (fun right ->
       let k = f2 right in
       List.iter
         (fun left -> record (merge k left right))
         (Index.find_all index k));
  fun yield -> List.iter yield !acc
(** Per-key accumulator used by [join_all_by]: [ja_left] collects the
    elements of the first sequence that map to the key, [ja_right] those of
    the second sequence. Both lists grow by prepending. *)
type ('a, 'b) join_all_cell = {
  mutable ja_left: 'a list;
  mutable ja_right: 'b list;
}

(** [join_all_by ?eq ?hash f1 f2 ~merge c1 c2] groups the elements of [c1]
    (by [f1]) and of [c2] (by [f2]) under their keys, then calls
    [merge key lefts rights] once for every key seen in either sequence.
    [Some z] answers are recorded, [None] answers are dropped.

    Keys are compared with [eq] (default: structural equality) and hashed
    with [hash] (default: [Hashtbl.hash]). [lefts] and [rights] list the
    matching elements with the most recently seen one first; either list
    may be empty.

    Both sequences are traversed exactly once, as soon as they are
    supplied; the returned sequence only replays the recorded values. *)
let join_all_by (type a) ?(eq=(=)) ?(hash=Hashtbl.hash) f1 f2 ~merge c1 c2 =
  let module Key = struct
    type t = a
    let equal = eq
    let hash = hash
  end in
  let module Cells = Hashtbl.Make(Key) in
  let cells = Cells.create 32 in
  let lookup k =
    try Some (Cells.find cells k)
    with Not_found -> None
  in
  (* first pass: file every element of [c1] on the left side of its cell *)
  c1
    (fun left ->
       let k = f1 left in
       match lookup k with
       | Some cell -> cell.ja_left <- left :: cell.ja_left
       | None -> Cells.add cells k {ja_left=[left]; ja_right=[]});
  (* second pass: same for [c2], on the right side *)
  c2
    (fun right ->
       let k = f2 right in
       match lookup k with
       | Some cell -> cell.ja_right <- right :: cell.ja_right
       | None -> Cells.add cells k {ja_left=[]; ja_right=[right]});
  let acc = ref [] in
  Cells.iter
    (fun k {ja_left; ja_right} ->
       match merge k ja_left ja_right with
       | Some merged -> acc := merged :: !acc
       | None -> ())
    cells;
  fun yield -> List.iter yield !acc
(** [group_join_by ?eq ?hash f c1 c2] associates to every element [x] of
    [c1] the list of elements [y] of [c2] such that [f y] equals [x].

    Elements of [c1] are the keys of a hash table built from [eq] (default:
    structural equality) and [hash] (default: [Hashtbl.hash]); elements of
    [c1] that are equal share a single entry. An element [y] of [c2] whose
    projection [f y] is not in the table is ignored. Elements of [c1]
    without any match are paired with [[]].

    Both sequences are traversed exactly once, as soon as they are
    supplied; the returned sequence iterates over the resulting table. *)
let group_join_by (type a) ?(eq=(=)) ?(hash=Hashtbl.hash) f c1 c2 =
  let module Key = struct
    type t = a
    let equal = eq
    let hash = hash
  end in
  let module Groups = Hashtbl.Make(Key) in
  let groups = Groups.create 32 in
  (* every element of [c1] starts with an empty group *)
  c1 (fun x -> Groups.replace groups x []);
  let attach y =
    (* project [y] onto an element of [c1]; unknown projections are skipped *)
    let k = f y in
    try Groups.replace groups k (y :: Groups.find groups k)
    with Not_found -> ()
  in
  c2 attach;
  fun yield -> Groups.iter (fun k ys -> yield (k, ys)) groups
(*$=
['a', ["abc"; "attic"]; \
'b', ["barbary"; "boom"; "bop"]; \
'c', []] \
(group_join_by (fun s->s.[0]) \
(of_str "abc") \
(of_list ["abc"; "boom"; "attic"; "deleted"; "barbary"; "bop"]) \
|> map (fun (c,l)->c,List.sort Pervasives.compare l) \
|> sort |> to_list)
*)
let rec unfoldr f b k = match f b with let rec unfoldr f b k = match f b with
| None -> () | None -> ()
| Some (x, b') -> | Some (x, b') ->

View file

@ -39,6 +39,9 @@ type +'a sequence = 'a t
type (+'a, +'b) t2 = ('a -> 'b -> unit) -> unit type (+'a, +'b) t2 = ('a -> 'b -> unit) -> unit
(** Sequence of pairs of values of type ['a] and ['b]. *) (** Sequence of pairs of values of type ['a] and ['b]. *)
(** Equality predicate on values of type ['a]; this is the type of the
    optional [?eq] argument of the [*_by] join functions. *)
type 'a equal = 'a -> 'a -> bool
(** Hash function on values of type ['a]; this is the type of the optional
    [?hash] argument of the [*_by] join functions. *)
type 'a hash = 'a -> int
(** {2 Build a sequence} *) (** {2 Build a sequence} *)
val from_iter : (('a -> unit) -> unit) -> 'a t val from_iter : (('a -> unit) -> unit) -> 'a t
@ -285,6 +288,49 @@ val join : join_row:('a -> 'b -> 'c option) -> 'a t -> 'b t -> 'c t
the two elements do not combine. Assume that [b] allows for multiple the two elements do not combine. Assume that [b] allows for multiple
iterations. *) iterations. *)
val join_by : ?eq:'key equal -> ?hash:'key hash ->
('a -> 'key) -> ('b -> 'key) ->
merge:('key -> 'a -> 'b -> 'c option) ->
'a t ->
'b t ->
'c t
(** [join_by key1 key2 ~merge a b] is a binary operation
    that takes two sequences [a] and [b], projects their
    elements resp. with [key1] and [key2], and combines
    values [(x,y)] from [(a,b)] with the same [key]
    using [merge key x y]. If [merge] returns [None], the combination
    of values is discarded.
    Both [a] and [b] are traversed once, when they are passed to
    [join_by]; the order of the resulting sequence should not be
    relied upon.
    @param eq equality on keys (default: structural equality)
    @param hash hash function on keys (default: [Hashtbl.hash])
    @since NEXT_RELEASE *)
val join_all_by : ?eq:'key equal -> ?hash:'key hash ->
('a -> 'key) -> ('b -> 'key) ->
merge:('key -> 'a list -> 'b list -> 'c option) ->
'a t ->
'b t ->
'c t
(** [join_all_by key1 key2 ~merge a b] is a binary operation
    that takes two sequences [a] and [b], projects their
    elements resp. with [key1] and [key2], and, for each key [k]
    occurring in at least one of them:
    - computes the list [l1] of elements of [a] that map to [k]
    - computes the list [l2] of elements of [b] that map to [k]
    - calls [merge k l1 l2]. If [merge] returns [None], the combination
      of values is discarded, otherwise it returns [Some c]
      and [c] is inserted in the result.

    Either [l1] or [l2] may be empty (when [k] occurs in only one of the
    sequences), and neither list preserves the order of the input
    sequence: the most recently seen element comes first.
    @param eq equality on keys (default: structural equality)
    @param hash hash function on keys (default: [Hashtbl.hash])
    @since NEXT_RELEASE *)
val group_join_by : ?eq:'a equal -> ?hash:'a hash ->
('b -> 'a) ->
'a t ->
'b t ->
('a * 'b list) t
(** [group_join_by key2] associates to every element [x] of
    the first sequence, all the elements [y] of the second
    sequence such that [eq x (key2 y)]. Elements of the first
    sequence without corresponding values in the second one
    are mapped to [[]]. Elements of the first sequence that are
    equal according to [eq] are merged into a single pair, and
    elements [y] whose projection [key2 y] does not occur in the
    first sequence are dropped.
    @param eq equality on elements of the first sequence
      (default: structural equality)
    @param hash hash function on elements of the first sequence
      (default: [Hashtbl.hash])
    @since NEXT_RELEASE *)
val unfoldr : ('b -> ('a * 'b) option) -> 'b -> 'a t val unfoldr : ('b -> ('a * 'b) option) -> 'b -> 'a t
(** [unfoldr f b] will apply [f] to [b]. If it (** [unfoldr f b] will apply [f] to [b]. If it
yields [Some (x,b')] then [x] is returned yields [Some (x,b')] then [x] is returned