diff --git a/core/CCBatch.ml b/core/CCBatch.ml index d27de6ad..760982e6 100644 --- a/core/CCBatch.ml +++ b/core/CCBatch.ml @@ -1,6 +1,6 @@ (* -copyright (c) 2013-2014, simon cruanes +copyright (c) 2013-2014, Simon Cruanes, Gabriel Radanne all rights reserved. redistribution and use in source and binary forms, with or without @@ -43,12 +43,11 @@ module type S = sig type ('a,'b) op (** Operation that converts an ['a t] into a ['b t] *) - val length : (_,_) op -> int - (** Number of intermediate structures needed to compute this operation *) - val apply : ('a,'b) op -> 'a t -> 'b t - (** Apply the operation to the collection. - @param level the optimization level, default is [OptimBase] *) + (** Apply the operation to the collection. *) + + val apply_fold : ('a, 'b) op -> ('c -> 'b -> 'c) -> 'c -> 'a t -> 'c + (** Apply the operation plus a fold to the collection. *) val apply' : 'a t -> ('a,'b) op -> 'b t (** Flip of {!apply} *) @@ -65,6 +64,8 @@ module type S = sig val flat_map : ('a -> 'b t) -> ('a,'b) op + val extern : ('a t -> 'b t) -> ('a,'b) op + val compose : ('b,'c) op -> ('a,'b) op -> ('a,'c) op val (>>>) : ('a,'b) op -> ('b,'c) op -> ('a,'c) op end @@ -75,12 +76,11 @@ module Make(C : COLLECTION) = struct | Nil : ('a,'a) op | Compose : ('a,'b) base_op * ('b, 'c) op -> ('a, 'c) op and (_,_) base_op = - | Id : ('a, 'a) base_op | Map : ('a -> 'b) -> ('a, 'b) base_op | Filter : ('a -> bool) -> ('a, 'a) base_op | FilterMap : ('a -> 'b option) -> ('a,'b) base_op | FlatMap : ('a -> 'b t) -> ('a,'b) base_op - + | Extern : ('a t -> 'b t) -> ('a,'b) base_op (* associativity: put parenthesis on the right *) let rec _compose : type a b c. (a,b) op -> (b,c) op -> (a,c) op @@ -89,20 +89,19 @@ module Make(C : COLLECTION) = struct | Compose (f1, f2) -> Compose (f1, _compose f2 g) | Nil -> g - - (* After optimization, the op is a list of flatmaps, with maybe something else at the end *) + (* After optimization, the op is a list of flatmaps and external operations, + with maybe something else at the end *) type (_,_) optimized_op = - | Base : ('a,'b) base_op -> ('a,'b) optimized_op - | FlatMapPlus : ('a -> 'b t) * ('b, 'c) optimized_op -> ('a, 'c) optimized_op - + | OptNil : ('a, 'a) optimized_op + | OptBase : ('a,'b) base_op * ('b, 'c) optimized_op -> ('a,'c) optimized_op + | OptFlatMap : ('a -> 'b t) * ('b, 'c) optimized_op -> ('a, 'c) optimized_op + | OptExtern : ('a t -> 'b t) * ('b, 'c) optimized_op -> ('a, 'c) optimized_op (* As compose, but optimize recursively on the way. *) let rec optimize_compose : type a b c. (a,b) base_op -> (b,c) op -> (a,c) optimized_op = fun base_op op -> match base_op, op with - | f, Nil -> Base f - | Id, Compose (f, cont) -> optimize_compose f cont - | f, Compose (Id, cont) -> optimize_compose f cont + | f, Nil -> OptBase (f, OptNil) | Map f, Compose (Map g, cont) -> optimize_compose (Map (fun x -> g (f x))) cont | Map f, Compose (Filter p, cont) -> @@ -151,34 +150,45 @@ module Make(C : COLLECTION) = struct | None -> C.empty | Some y -> f' y)) cont + | FlatMap f, Compose (f', tail) -> + merge_flat_map f (optimize_compose f' tail) + | Extern f, Compose (f', tail) -> + OptExtern (f, optimize_compose f' tail) + | op, Compose (Extern f', cont) -> + OptBase (op, optimize_compose (Extern f') cont) - (* flatmap doesn't compose with anything *) - | FlatMap f, Compose (f', cont) -> - FlatMapPlus (f, optimize_compose f' cont) - - - let rec length : type a b. (a,b) op -> int = function - | Nil -> 0 - | Compose (_, cont) -> 1 + length cont - + and merge_flat_map + : type a b c. (a -> b C.t) -> (b,c) optimized_op -> (a,c) optimized_op = + fun f op -> match op with + | OptNil -> OptFlatMap (f, op) + | OptFlatMap (f', cont) -> + merge_flat_map + (fun x -> + let a = f x in + C.flat_map f' a) + cont + | OptExtern _ -> OptFlatMap (f, op) + | OptBase _ -> OptFlatMap (f, op) (* optimize a batch operation by fusion *) let optimize : type a b. (a,b) op -> (a,b) optimized_op = fun op -> match op with | Compose (a, b) -> optimize_compose a b - | Nil -> Base Id + | Nil -> OptNil let rec apply_optimized : type a b. (a,b) optimized_op -> a t -> b t = fun op a -> match op with - | Base f -> apply_base f a - | FlatMapPlus (f,c) -> apply_optimized c @@ C.flat_map f a + | OptNil -> a + | OptBase (f,c) -> apply_optimized c (apply_base f a) + | OptFlatMap (f,c) -> apply_optimized c (C.flat_map f a) + | OptExtern (f,c) -> apply_optimized c (f a) and apply_base : type a b. (a,b) base_op -> a t -> b t = fun op a -> match op with | Map f -> C.map f a | Filter p -> C.filter p a | FlatMap f -> C.flat_map f a | FilterMap f -> C.filter_map f a - | Id -> a + | Extern f -> f a let fusion_fold : type a b c. (a,b) base_op -> (c -> b -> c) -> c -> a -> c = fun op f' -> match op with @@ -186,14 +196,22 @@ module Make(C : COLLECTION) = struct | Filter p -> (fun z x -> if p x then f' z x else z) | FlatMap f -> (fun z x -> C.fold f' z (f x)) | FilterMap f -> (fun z x -> match f x with Some x' -> f' z x' | None -> z) - | Id -> f' + | Extern _ -> assert false - let rec apply_optimized_with_fold : type a b c. (a,b) optimized_op -> (c -> b -> c) -> c -> a t -> c + let rec apply_optimized_with_fold + : type a b c. (a,b) optimized_op -> (c -> b -> c) -> c -> a t -> c = fun op fold z a -> match op with - | Base f -> C.fold (fusion_fold f fold) z a - | FlatMapPlus (f,c) -> apply_optimized_with_fold c fold z @@ C.flat_map f a - - + | OptNil -> C.fold fold z a + | OptBase (Extern f, OptNil) -> + C.fold fold z (f a) + | OptBase (f,OptNil) -> + (* terminal fold *) + C.fold (fusion_fold f fold) z a + | OptBase (f,c) -> + (* make intermediate collection and continue *) + apply_optimized_with_fold c fold z (apply_base f a) + | OptExtern (f,c) -> apply_optimized_with_fold c fold z (f a) + | OptFlatMap (f,c) -> apply_optimized_with_fold c fold z (C.flat_map f a) (* optimize and run *) let apply op a = @@ -213,6 +231,7 @@ module Make(C : COLLECTION) = struct let filter p = Compose (Filter p, Nil) let filter_map f = Compose (FilterMap f, Nil) let flat_map f = Compose (FlatMap f, Nil) + let extern f = Compose (Extern f, Nil) let compose f g = _compose g f let (>>>) f g = _compose f g diff --git a/core/CCBatch.mli b/core/CCBatch.mli index df10080a..7b04b692 100644 --- a/core/CCBatch.mli +++ b/core/CCBatch.mli @@ -35,6 +35,7 @@ module type COLLECTION = sig type 'a t val empty : 'a t + val fold : ('a -> 'b -> 'a) -> 'a -> 'b t -> 'a val map : ('a -> 'b) -> 'a t -> 'b t val filter : ('a -> bool) -> 'a t -> 'a t val filter_map : ('a -> 'b option) -> 'a t -> 'b t @@ -48,13 +49,10 @@ module type S = sig type ('a,'b) op (** Operation that converts an ['a t] into a ['b t] *) - val length : (_,_) op -> int - (** Number of intermediate structures needed to compute this operation *) - val apply : ('a,'b) op -> 'a t -> 'b t (** Apply the operation to the collection. *) - val apply_width_fold : ('a, 'b) op -> ('c -> 'b -> 'c) -> 'c -> 'a t -> 'c + val apply_fold : ('a, 'b) op -> ('c -> 'b -> 'c) -> 'c -> 'a t -> 'c (** Apply the operation plus a fold to the collection. *) val apply' : 'a t -> ('a,'b) op -> 'b t @@ -72,6 +70,9 @@ module type S = sig val flat_map : ('a -> 'b t) -> ('a,'b) op + val extern : ('a t -> 'b t) -> ('a,'b) op + (** Use a specific function that won't be optimized *) + val compose : ('b,'c) op -> ('a,'b) op -> ('a,'c) op val (>>>) : ('a,'b) op -> ('b,'c) op -> ('a,'c) op end diff --git a/tests/bench_batch.ml b/tests/bench_batch.ml index 90489ec3..daac17f6 100644 --- a/tests/bench_batch.ml +++ b/tests/bench_batch.ml @@ -6,7 +6,6 @@ module type COLL = sig val doubleton : 'a -> 'a -> 'a t val (--) : int -> int -> int t val equal : int t -> int t -> bool - val fold : (int -> int -> int) -> int -> int t -> int end module Make(C : COLL) = struct @@ -31,25 +30,13 @@ module Make(C : COLL) = struct let ops = BA.(filter f1 >>> flat_map f3 >>> filter f1 >>> map f2 >>> flat_map f3 >>> map f4) - let batch_simple a = - let a = BA.apply ~level:BA.OptimNone ops a in - ignore (collect a); - a - let batch a = - let a = BA.apply ~level:BA.OptimBase ops a in - ignore (collect a); - a - - let batch2 a = - let a = BA.apply ~level:BA.OptimMergeFlatMap ops a in + let a = BA.apply ops a in ignore (collect a); a let bench_for ~time n = Printf.printf "\n\nbenchmark for %s of len %d\n" C.name n; - Printf.printf "optimization: from %d to %d\n" - (BA.length ops) (BA.length (BA.optimize ops)); flush stdout; let a = C.(0 -- n) in (* debug @@ -57,21 +44,18 @@ module Make(C : COLL) = struct CCPrint.printf "simple: %a\n" (CCArray.pp CCInt.pp) (batch_simple a); CCPrint.printf "batch: %a\n" (CCArray.pp CCInt.pp) (batch a); *) - assert (C.equal (batch_simple a) (naive a)); - assert (C.equal (batch_simple a) (batch a)); + assert (C.equal (batch a) (naive a)); let res = Benchmark.throughputN time [ C.name ^ "_naive", naive, a - ; C.name ^ "_batch_simple", batch_simple, a ; C.name ^ "_batch", batch, a - ; C.name ^ "_batch_merge", batch2, a ] in Benchmark.tabulate res let bench () = bench_for 1 100; - bench_for 2 100_000; - bench_for 2 1_000_000; + bench_for 4 100_000; + bench_for 4 1_000_000; () end