diff --git a/src/Sequence.ml b/src/Sequence.ml index 485a7b2..da7daa8 100644 --- a/src/Sequence.ml +++ b/src/Sequence.ml @@ -15,6 +15,9 @@ type (+'a, +'b) t2 = ('a -> 'b -> unit) -> unit let pp_ilist = Q.Print.(list int) *) +type 'a equal = 'a -> 'a -> bool +type 'a hash = 'a -> int + (** Build a sequence from a iter function *) let from_iter f = f @@ -521,6 +524,96 @@ let join ~join_row s1 s2 k = OUnit.assert_equal ["1 = 1"; "2 = 2"] (to_list s); *) +let join_by (type a) ?(eq=(=)) ?(hash=Hashtbl.hash) f1 f2 ~merge c1 c2 = + let module Tbl = Hashtbl.Make(struct + type t = a + let equal = eq + let hash = hash + end) in + let tbl = Tbl.create 32 in + c1 + (fun x -> + let key = f1 x in + Tbl.add tbl key x); + let res = ref [] in + c2 + (fun y -> + let key = f2 y in + let xs = Tbl.find_all tbl key in + List.iter + (fun x -> match merge key x y with + | None -> () + | Some z -> res := z :: !res) + xs); + fun yield -> List.iter yield !res + +type ('a, 'b) join_all_cell = { + mutable ja_left: 'a list; + mutable ja_right: 'b list; +} + +let join_all_by (type a) ?(eq=(=)) ?(hash=Hashtbl.hash) f1 f2 ~merge c1 c2 = + let module Tbl = Hashtbl.Make(struct + type t = a + let equal = eq + let hash = hash + end) in + let tbl = Tbl.create 32 in + (* build the map [key -> cell] *) + c1 + (fun x -> + let key = f1 x in + try + let c = Tbl.find tbl key in + c.ja_left <- x :: c.ja_left + with Not_found -> + Tbl.add tbl key {ja_left=[x]; ja_right=[]}); + c2 + (fun y -> + let key = f2 y in + try + let c = Tbl.find tbl key in + c.ja_right <- y :: c.ja_right + with Not_found -> + Tbl.add tbl key {ja_left=[]; ja_right=[y]}); + let res = ref [] in + Tbl.iter + (fun key cell -> match merge key cell.ja_left cell.ja_right with + | None -> () + | Some z -> res := z :: !res) + tbl; + fun yield -> List.iter yield !res + +let group_join_by (type a) ?(eq=(=)) ?(hash=Hashtbl.hash) f c1 c2 = + let module Tbl = Hashtbl.Make(struct + type t = a + let equal = eq + let hash = hash + end) in + let tbl = Tbl.create 32 in + c1 (fun x -> Tbl.replace tbl x []); + c2 + (fun y -> + (* project [y] into some element of [c1] *) + let key = f y in + try + let l = Tbl.find tbl key in + Tbl.replace tbl key (y :: l) + with Not_found -> ()); + fun yield -> Tbl.iter (fun k l -> yield (k,l)) tbl + +(*$= + ['a', ["abc"; "attic"]; \ + 'b', ["barbary"; "boom"; "bop"]; \ + 'c', []] \ + (group_join_by (fun s->s.[0]) \ + (of_str "abc") \ + (of_list ["abc"; "boom"; "attic"; "deleted"; "barbary"; "bop"]) \ + |> map (fun (c,l)->c,List.sort Pervasives.compare l) \ + |> sort |> to_list) +*) + + let rec unfoldr f b k = match f b with | None -> () | Some (x, b') -> diff --git a/src/Sequence.mli b/src/Sequence.mli index 15d394a..1d6509a 100644 --- a/src/Sequence.mli +++ b/src/Sequence.mli @@ -39,6 +39,9 @@ type +'a sequence = 'a t type (+'a, +'b) t2 = ('a -> 'b -> unit) -> unit (** Sequence of pairs of values of type ['a] and ['b]. *) +type 'a equal = 'a -> 'a -> bool +type 'a hash = 'a -> int + (** {2 Build a sequence} *) val from_iter : (('a -> unit) -> unit) -> 'a t @@ -285,6 +288,49 @@ val join : join_row:('a -> 'b -> 'c option) -> 'a t -> 'b t -> 'c t the two elements do not combine. Assume that [b] allows for multiple iterations. *) +val join_by : ?eq:'key equal -> ?hash:'key hash -> + ('a -> 'key) -> ('b -> 'key) -> + merge:('key -> 'a -> 'b -> 'c option) -> + 'a t -> + 'b t -> + 'c t +(** [join key1 key2 ~merge] is a binary operation + that takes two sequences [a] and [b], projects their + elements resp. with [key1] and [key2], and combine + values [(x,y)] from [(a,b)] with the same [key] + using [merge]. If [merge] returns [None], the combination + of values is discarded. + @since NEXT_RELEASE *) + +val join_all_by : ?eq:'key equal -> ?hash:'key hash -> + ('a -> 'key) -> ('b -> 'key) -> + merge:('key -> 'a list -> 'b list -> 'c option) -> + 'a t -> + 'b t -> + 'c t +(** [join_all_by key1 key2 ~merge] is a binary operation + that takes two sequences [a] and [b], projects their + elements resp. with [key1] and [key2], and, for each key [k] + occurring in at least one of them: + - compute the list [l1] of elements of [a] that map to [k] + - compute the list [l2] of elements of [b] that map to [k] + - call [merge k l1 l2]. If [merge] returns [None], the combination + of values is discarded, otherwise it returns [Some c] + and [c] is inserted in the result. + @since NEXT_RELEASE *) + +val group_join_by : ?eq:'a equal -> ?hash:'a hash -> + ('b -> 'a) -> + 'a t -> + 'b t -> + ('a * 'b list) t +(** [group_join_by key2] associates to every element [x] of + the first sequence, all the elements [y] of the second + sequence such that [eq x (key y)]. Elements of the first + sequences without corresponding values in the second one + are mapped to [[]] + @since NEXT_RELEASE *) + val unfoldr : ('b -> ('a * 'b) option) -> 'b -> 'a t (** [unfoldr f b] will apply [f] to [b]. If it yields [Some (x,b')] then [x] is returned