richer api for Gen

This commit is contained in:
Simon Cruanes 2014-02-01 15:26:07 +01:00
parent 2936595dbb
commit a3b4e28295
2 changed files with 182 additions and 68 deletions

186
gen.ml
View file

@ -53,6 +53,12 @@ module type S = sig
unfolding the ['b] value into a new ['b], and a ['a] which is yielded,
until [None] is returned. *)
val init : ?limit:int -> (int -> 'a) -> 'a t
(** Calls the function, starting from 0, on increasing indices.
If [limit] is provided and is a positive int, iteration will
stop at the limit (excluded).
For instance [init ~limit:4 id] will yield 0, 1, 2, and 3. *)
(** {2 Basic combinators} *)
val is_empty : _ t -> bool
@ -124,12 +130,6 @@ module type S = sig
val filterMap : ('a -> 'b option) -> 'a t -> 'b t
(** Maps some elements to 'b, drop the other ones *)
val zipWith : ('a -> 'b -> 'c) -> 'a t -> 'b t -> 'c t
(** Combine common part of the enums (stops when one is exhausted) *)
val zip : 'a t -> 'b t -> ('a * 'b) t
(** Zip together the common part of the enums *)
val zipIndex : 'a t -> (int * 'a) t
(** Zip elements with their index in the enum *)
@ -146,14 +146,6 @@ module type S = sig
val exists : ('a -> bool) -> 'a t -> bool
(** Is the predicate true for at least one element? *)
val for_all2 : ('a -> 'b -> bool) -> 'a t -> 'b t -> bool
(** Succeeds if all pairs of elements satisfy the predicate.
Ignores elements of an iterator if the other runs dry. *)
val exists2 : ('a -> 'b -> bool) -> 'a t -> 'b t -> bool
(** Succeeds if some pair of elements satisfy the predicate.
Ignores elements of an iterator if the other runs dry. *)
val min : ?lt:('a -> 'a -> bool) -> 'a t -> 'a
(** Minimum element, according to the given comparison function.
@raise Invalid_argument if the generator is empty *)
@ -172,15 +164,44 @@ module type S = sig
val compare : ?cmp:('a -> 'a -> int) -> 'a t -> 'a t -> int
(** Synonym for {! lexico} *)
val find : ('a -> bool) -> 'a t -> 'a option
(** [find p e] returns the first element of [e] to satisfy [p],
or None. *)
val sum : int t -> int
(** Sum of all elements *)
(** {2 Multiple iterators} *)
val map2 : ('a -> 'b -> 'c) -> 'a t -> 'b t -> 'c t
val iter2 : ('a -> 'b -> unit) -> 'a t -> 'b t -> unit
val fold2 : ('acc -> 'a -> 'b -> 'acc) -> 'acc -> 'a t -> 'b t -> 'acc
val for_all2 : ('a -> 'b -> bool) -> 'a t -> 'b t -> bool
(** Succeeds if all pairs of elements satisfy the predicate.
Ignores elements of an iterator if the other runs dry. *)
val exists2 : ('a -> 'b -> bool) -> 'a t -> 'b t -> bool
(** Succeeds if some pair of elements satisfy the predicate.
Ignores elements of an iterator if the other runs dry. *)
val zipWith : ('a -> 'b -> 'c) -> 'a t -> 'b t -> 'c t
(** Combine common part of the enums (stops when one is exhausted) *)
val zip : 'a t -> 'b t -> ('a * 'b) t
(** Zip together the common part of the enums *)
(** {2 Complex combinators} *)
val merge : 'a gen t -> 'a t
(** Pick elements fairly in each sub-generator. The merge of enums
[e1, e2, ... en] picks one element in [e1], then one element in [e2],
then in [e3], ..., then in [en], and then starts again at [e1]. Once
a generator is empty, it is skipped; when they are all empty,
[e1, e2, ... ] picks elements in [e1], [e2],
in [e3], [e1], [e2] .... Once a generator is empty, it is skipped;
when they are all empty, and none remains in the input,
their merge is also empty.
For instance, [merge [1;3;5] [2;4;6]] will be [1;2;3;4;5;6]. *)
For instance, [merge [1;3;5] [2;4;6]] will be, in disorder, [1;2;3;4;5;6]. *)
val intersection : ?cmp:('a -> 'a -> int) -> 'a t -> 'a t -> 'a t
(** Intersection of two sorted sequences. Only elements that occur in both
@ -227,6 +248,11 @@ module type S = sig
val sort_uniq : ?cmp:('a -> 'a -> int) -> 'a t -> 'a t
(** Sort and remove duplicates. The enum must be finite. *)
val chunks : int -> 'a t -> 'a array t
(** [chunks n e] returns a generator of arrays of length [n], composed
of successive elements of [e]. The last array may be smaller
than [n] *)
(* TODO later
val permutations : 'a t -> 'a gen t
(** Permutations of the enum. Each permutation becomes unavailable once
@ -318,11 +344,6 @@ let rec fold f acc gen =
| None -> acc
| Some x -> fold f (f acc x) gen
let rec fold2 f acc e1 e2 =
match e1(), e2() with
| Some x, Some y -> fold2 f (f acc x y) e1 e2
| _ -> acc
let reduce f g =
let acc = match g () with
| None -> raise (Invalid_argument "reduce")
@ -340,6 +361,16 @@ let unfold f acc =
acc := acc';
Some x
let init ?(limit=max_int) f =
let r = ref 0 in
fun () ->
if !r >= limit
then None
else
let x = f !r in
let _ = incr r in
Some x
let rec iter f gen =
match gen() with
| None -> ()
@ -546,16 +577,6 @@ let filterMap f gen =
| (Some _) as res -> res
in next
let zipWith f a b =
let stop = ref false in
fun () ->
if !stop then None
else match a(), b() with
| Some xa, Some xb -> Some (f xa xb)
| _ -> stop:=true; None
let zip a b = zipWith (fun x y -> x,y) a b
let zipIndex gen =
let r = ref ~-1 in
fun () ->
@ -626,16 +647,6 @@ let rec exists p gen =
| None -> false
| Some x -> p x || exists p gen
let rec for_all2 p e1 e2 =
match e1(), e2() with
| Some x, Some y -> p x y && for_all2 p e1 e2
| _ -> true
let rec exists2 p e1 e2 =
match e1(), e2() with
| Some x, Some y -> p x y || exists2 p e1 e2
| _ -> false
let min ?(lt=fun x y -> x < y) gen =
let first = match gen () with
| Some x -> x
@ -672,6 +683,54 @@ let lexico ?(cmp=Pervasives.compare) gen1 gen2 =
let compare ?cmp gen1 gen2 = lexico ?cmp gen1 gen2
let rec find p e = match e () with
| None -> None
| Some x when p x -> Some x
| Some _ -> find p e
let sum e =
let rec sum acc = match e() with
| None -> acc
| Some x -> sum (x+acc)
in sum 0
(** {2 Multiple Iterators} *)
let map2 f e1 e2 =
fun () -> match e1(), e2() with
| Some x, Some y -> Some (f x y)
| _ -> None
let rec iter2 f e1 e2 =
match e1(), e2() with
| Some x, Some y -> f x y; iter2 f e1 e2
| _ -> ()
let rec fold2 f acc e1 e2 =
match e1(), e2() with
| Some x, Some y -> fold2 f (f acc x y) e1 e2
| _ -> acc
let rec for_all2 p e1 e2 =
match e1(), e2() with
| Some x, Some y -> p x y && for_all2 p e1 e2
| _ -> true
let rec exists2 p e1 e2 =
match e1(), e2() with
| Some x, Some y -> p x y || exists2 p e1 e2
| _ -> false
let zipWith f a b =
let stop = ref false in
fun () ->
if !stop then None
else match a(), b() with
| Some xa, Some xb -> Some (f xa xb)
| _ -> stop:=true; None
let zip a b = zipWith (fun x y -> x,y) a b
(** {3 Complex combinators} *)
module MergeState = struct
@ -1036,6 +1095,26 @@ let sort ?(cmp=Pervasives.compare) gen =
let sort_uniq ?(cmp=Pervasives.compare) gen =
uniq ~eq:(fun x y -> cmp x y = 0) (sort ~cmp gen)
let chunks n e =
let rec next () =
match e() with
| None -> None
| Some x ->
let a = Array.make n x in
fill a (n-1)
and fill a i =
(* fill the array. [i] elements remain to fill *)
if i = n
then Some a
else match e() with
| None -> Some (Array.sub a 0 i) (* last array is not full *)
| Some x ->
a.(i) <- x;
fill a (i+1)
in
next
(*
let permutations enum =
failwith "not implemented" (* TODO *)
@ -1146,6 +1225,8 @@ module Restart = struct
let unfold f acc () = unfold f acc
let init ?limit f () = init ?limit f
let cycle enum =
assert (not (is_empty (enum ())));
fun () ->
@ -1160,8 +1241,6 @@ module Restart = struct
let fold f acc e = fold f acc (e ())
let fold2 f acc e1 e2 = fold2 f acc (e1 ()) (e2 ())
let reduce f e = reduce f (e ())
let scan f acc e () = scan f acc (e ())
@ -1170,8 +1249,6 @@ module Restart = struct
let iteri f e = iteri f (e ())
let iter2 f e1 e2 = iter2 f (e1 ()) (e2 ())
let length e = length (e ())
let map f e () = map f (e ())
@ -1221,6 +1298,15 @@ module Restart = struct
let exists2 p e1 e2 =
exists2 p (e1 ()) (e2 ())
let map2 f e1 e2 () =
map2 f (e1()) (e2())
let iter2 f e1 e2 =
iter2 f (e1()) (e2())
let fold2 f acc e1 e2 =
fold2 f acc (e1()) (e2())
let min ?lt e = min ?lt (e ())
let max ?lt e = max ?lt (e ())
@ -1232,6 +1318,10 @@ module Restart = struct
let compare ?cmp e1 e2 = compare ?cmp (e1 ()) (e2 ())
let sum e = sum (e())
let find f e = find f (e())
let merge e () = merge (e ())
let intersection ?cmp e1 e2 () =
@ -1264,6 +1354,8 @@ module Restart = struct
let e' = sort ~cmp e in
uniq ~eq:(fun x y -> cmp x y = 0) e'
let chunks n e () = chunks n (e())
let of_list l () = of_list l
let to_rev_list e = to_rev_list (e ())

64
gen.mli
View file

@ -69,6 +69,12 @@ module type S = sig
unfolding the ['b] value into a new ['b], and a ['a] which is yielded,
until [None] is returned. *)
val init : ?limit:int -> (int -> 'a) -> 'a t
(** Calls the function, starting from 0, on increasing indices.
If [limit] is provided and is a positive int, iteration will
stop at the limit (excluded).
For instance [init ~limit:4 id] will yield 0, 1, 2, and 3. *)
(** {2 Basic combinators} *)
val is_empty : _ t -> bool
@ -77,10 +83,6 @@ module type S = sig
val fold : ('b -> 'a -> 'b) -> 'b -> 'a t -> 'b
(** Fold on the generator, tail-recursively *)
val fold2 : ('c -> 'a -> 'b -> 'c) -> 'c -> 'a t -> 'b t -> 'c
(** Fold on the two enums in parallel. Stops once one of the enums
is exhausted. *)
val reduce : ('a -> 'a -> 'a) -> 'a t -> 'a
(** Fold on non-empty sequences (otherwise raise Invalid_argument) *)
@ -93,9 +95,6 @@ module type S = sig
val iteri : (int -> 'a -> unit) -> 'a t -> unit
(** Iterate on elements with their index in the enum, from 0 *)
val iter2 : ('a -> 'b -> unit) -> 'a t -> 'b t -> unit
(** Iterate on the two sequences. Stops once one of them is exhausted.*)
val length : _ t -> int
(** Length of an enum (linear time) *)
@ -140,12 +139,6 @@ module type S = sig
val filterMap : ('a -> 'b option) -> 'a t -> 'b t
(** Maps some elements to 'b, drop the other ones *)
val zipWith : ('a -> 'b -> 'c) -> 'a t -> 'b t -> 'c t
(** Combine common part of the enums (stops when one is exhausted) *)
val zip : 'a t -> 'b t -> ('a * 'b) t
(** Zip together the common part of the enums *)
val zipIndex : 'a t -> (int * 'a) t
(** Zip elements with their index in the enum *)
@ -162,14 +155,6 @@ module type S = sig
val exists : ('a -> bool) -> 'a t -> bool
(** Is the predicate true for at least one element? *)
val for_all2 : ('a -> 'b -> bool) -> 'a t -> 'b t -> bool
(** Succeeds if all pairs of elements satisfy the predicate.
Ignores elements of an iterator if the other runs dry. *)
val exists2 : ('a -> 'b -> bool) -> 'a t -> 'b t -> bool
(** Succeeds if some pair of elements satisfy the predicate.
Ignores elements of an iterator if the other runs dry. *)
val min : ?lt:('a -> 'a -> bool) -> 'a t -> 'a
(** Minimum element, according to the given comparison function.
@raise Invalid_argument if the generator is empty *)
@ -188,6 +173,38 @@ module type S = sig
val compare : ?cmp:('a -> 'a -> int) -> 'a t -> 'a t -> int
(** Synonym for {! lexico} *)
val find : ('a -> bool) -> 'a t -> 'a option
(** [find p e] returns the first element of [e] to satisfy [p],
or None. *)
val sum : int t -> int
(** Sum of all elements *)
(** {2 Multiple iterators} *)
val map2 : ('a -> 'b -> 'c) -> 'a t -> 'b t -> 'c t
(** Map on the two sequences. Stops once one of them is exhausted.*)
val iter2 : ('a -> 'b -> unit) -> 'a t -> 'b t -> unit
(** Iterate on the two sequences. Stops once one of them is exhausted.*)
val fold2 : ('acc -> 'a -> 'b -> 'acc) -> 'acc -> 'a t -> 'b t -> 'acc
(** Fold the common prefix of the two iterators *)
val for_all2 : ('a -> 'b -> bool) -> 'a t -> 'b t -> bool
(** Succeeds if all pairs of elements satisfy the predicate.
Ignores elements of an iterator if the other runs dry. *)
val exists2 : ('a -> 'b -> bool) -> 'a t -> 'b t -> bool
(** Succeeds if some pair of elements satisfy the predicate.
Ignores elements of an iterator if the other runs dry. *)
val zipWith : ('a -> 'b -> 'c) -> 'a t -> 'b t -> 'c t
(** Combine common part of the enums (stops when one is exhausted) *)
val zip : 'a t -> 'b t -> ('a * 'b) t
(** Zip together the common part of the enums *)
(** {2 Complex combinators} *)
val merge : 'a gen t -> 'a t
@ -243,6 +260,11 @@ module type S = sig
val sort_uniq : ?cmp:('a -> 'a -> int) -> 'a t -> 'a t
(** Sort and remove duplicates. The enum must be finite. *)
val chunks : int -> 'a t -> 'a array t
(** [chunks n e] returns a generator of arrays of length [n], composed
of successive elements of [e]. The last array may be smaller
than [n] *)
(* TODO later
val permutations : 'a t -> 'a gen t
(** Permutations of the enum. Each permutation becomes unavailable once