diff --git a/sequence.ml b/sequence.ml index 4e0afaf..738fc28 100644 --- a/sequence.ml +++ b/sequence.ml @@ -29,3 +29,23 @@ let map f seq = let filter p seq = let seq_fun' k = seq.seq_fun (fun x -> if p x then k x) in { seq_fun=seq_fun'; } + +(** Concatenate two sequences *) +let concat s1 s2 = + let seq_fun k = s1.seq_fun k; s2.seq_fun k in + { seq_fun; } + +(** Take at most [n] elements from the sequence *) +let take n seq = + let count = ref 0 in + let seq_fun k = seq.seq_fun + (fun x -> + if !count < n then begin incr count; k x end) + in { seq_fun; } + +(** Drop the [n] first elements of the sequence *) +let drop n seq = + let count = ref 0 in + let seq_fun k = seq.seq_fun + (fun x -> if !count >= n then k x else incr count) + in { seq_fun; } diff --git a/sequence.mli b/sequence.mli index 7c4f7f4..b613f7c 100644 --- a/sequence.mli +++ b/sequence.mli @@ -4,9 +4,13 @@ type 'a sequence (** Sequence abstract iterator type *) +(** {2 Build a sequence} *) + val from_iter : (('a -> unit) -> unit) -> 'a sequence (** Build a sequence from a iter function *) +(** {2 Use a sequence} *) + val iter : ('a -> unit) -> 'a sequence -> unit (** Consume the sequence, passing all its arguments to the function *) @@ -18,3 +22,12 @@ val map : ('a -> 'b) -> 'a sequence -> 'b sequence val filter : ('a -> bool) -> 'a sequence -> 'a sequence (** Filter on elements of the sequence *) + +val concat : 'a sequence -> 'a sequence -> 'a sequence + (** Concatenate two sequences *) + +val take : int -> 'a sequence -> 'a sequence + (** Take at most [n] elements from the sequence *) + +val drop : int -> 'a sequence -> 'a sequence + (** Drop the [n] first elements of the sequence *) diff --git a/tests.ml b/tests.ml index 11307b5..206ac87 100644 --- a/tests.ml +++ b/tests.ml @@ -23,6 +23,7 @@ let rec pp_list ?(sep=", ") pp_item formatter = function let _ = let l = [0;1;2;3;4;5;6] in let l' = list_of_seq (Sequence.filter (fun x -> x mod 2 = 0) (seq_of_list l)) in - Format.printf "l=@[[%a]@]; l'=@[[%a]@]@." - (pp_list Format.pp_print_int) l - (pp_list Format.pp_print_int) l' + let l'' = list_of_seq (Sequence.take 3 (Sequence.drop 1 (seq_of_list l))) in + Format.printf "l=@[[%a]@]@." (pp_list Format.pp_print_int) l; + Format.printf "l'=@[[%a]@]@." (pp_list Format.pp_print_int) l'; + Format.printf "l''=@[[%a]@]@." (pp_list Format.pp_print_int) l'';