diff --git a/sequence.ml b/sequence.ml index 6c3f2f0..e29eb2b 100644 --- a/sequence.ml +++ b/sequence.ml @@ -274,6 +274,18 @@ let product outer inner = outer (fun x -> inner (fun y -> k (x,y)))) +(** [join ~join_row a b] combines every element of [a] with every + element of [b] using [join_row]. If [join_row] returns None, then + the two elements do not combine. Assume that [b] allows for multiple + iterations. *) +let join ~join_row s1 s2 = + fun k -> + s1 (fun a -> + s2 (fun b -> + match join_row a b with + | None -> () + | Some c -> k c)) (* yield the combination of [a] and [b] *) + (** [unfoldr f b] will apply [f] to [b]. If it yields [Some (x,b')] then [x] is returned and unfoldr recurses with [b']. *) diff --git a/sequence.mli b/sequence.mli index 1ac41ba..a8cf345 100644 --- a/sequence.mli +++ b/sequence.mli @@ -147,6 +147,12 @@ val product : 'a t -> 'b t -> ('a * 'b) t by calling [persistent] on it, so that it can be traversed several times (outer loop of the product) *) +val join : join_row:('a -> 'b -> 'c option) -> 'a t -> 'b t -> 'c t + (** [join ~join_row a b] combines every element of [a] with every + element of [b] using [join_row]. If [join_row] returns None, then + the two elements do not combine. Assume that [b] allows for multiple + iterations. *) + val unfoldr : ('b -> ('a * 'b) option) -> 'b -> 'a t (** [unfoldr f b] will apply [f] to [b]. If it yields [Some (x,b')] then [x] is returned diff --git a/tests/test_sequence.ml b/tests/test_sequence.ml index 0d7a928..fbb0a8d 100644 --- a/tests/test_sequence.ml +++ b/tests/test_sequence.ml @@ -88,6 +88,7 @@ let test_persistent () = let printer = pp_ilist in let stream = Stream.from (fun i -> if i < 5 then Some i else None) in let seq = S.of_stream stream in + (* consume seq into a persistent version of itself *) let seq' = S.persistent seq in OUnit.assert_equal ~printer [] (seq |> S.to_list); OUnit.assert_equal ~printer [0;1;2;3;4] (seq' |> S.to_list); @@ -127,6 +128,16 @@ let test_product () = "b",0; "b", 1; "b", 2; "c",0; "c", 1; "c", 2;] s +let test_join () = + let s1 = (1 -- 3) in + let s2 = S.of_list ["1"; "2"] in + let join_row i j = + if string_of_int i = j then Some (string_of_int i ^ " = " ^ j) else None + in + let s = S.join ~join_row s1 s2 in + OUnit.assert_equal ["1 = 1"; "2 = 2"] (S.to_list s); + () + let test_scan () = 1 -- 5 |> S.scan (+) 0 @@ -181,6 +192,7 @@ let suite = "test_group" >:: test_group; "test_uniq" >:: test_uniq; "test_product" >:: test_product; + "test_join" >:: test_join; "test_scan" >:: test_scan; "test_drop" >:: test_drop; "test_rev" >:: test_rev;