From d31a84bab4c38ad494b8f2066045a677124fc0ee Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Tue, 30 May 2023 23:51:52 -0400 Subject: [PATCH] add combinators to Fut --- src/moonpool.ml | 151 +++++++++++++++++++++++++++++++++++++++++++++++ src/moonpool.mli | 53 +++++++++++++++++ 2 files changed, 204 insertions(+) diff --git a/src/moonpool.ml b/src/moonpool.ml index 8a2ef573..77275306 100644 --- a/src/moonpool.ml +++ b/src/moonpool.ml @@ -175,6 +175,16 @@ module Fut = struct let[@inline] return x : _ t = of_result (Ok x) let[@inline] fail e bt : _ t = of_result (Error (e, bt)) + let[@inline] is_resolved self : bool = + match A.get self.st with + | Done _ -> true + | Waiting _ -> false + + let[@inline] peek self : _ option = + match A.get self.st with + | Done x -> Some x + | Waiting _ -> None + let on_result (self : _ t) (f : _ waiter) : unit = while let st = A.get self.st in @@ -209,4 +219,145 @@ module Fut = struct do () done + + (* ### combinators ### *) + + let spawn ~on f : _ t = + let fut, promise = make () in + + let task () = + let res = + try Ok (f ()) + with e -> + let bt = Printexc.get_raw_backtrace () in + Error (e, bt) + in + fulfill promise res + in + + Pool.run on task; + fut + + let map ?on ~f fut : _ t = + let map_res r = + match r with + | Ok x -> + (try Ok (f x) + with e -> + let bt = Printexc.get_raw_backtrace () in + Error (e, bt)) + | Error e_bt -> Error e_bt + in + + match peek fut with + | Some r -> of_result (map_res r) + | None -> + let fut2, promise = make () in + on_result fut (fun r -> + let map_and_fulfill () = + let res = map_res r in + fulfill promise res + in + + match on with + | None -> map_and_fulfill () + | Some on -> Pool.run on map_and_fulfill); + + fut2 + + let bind ?on ~f fut : _ t = + let apply_f_to_res r : _ t = + match r with + | Ok x -> + (try f x + with e -> + let bt = Printexc.get_raw_backtrace () in + fail e bt) + | Error (e, bt) -> fail e bt + in + + match peek fut with + | Some r -> apply_f_to_res r + | None -> + let fut2, promise = make () in + on_result fut (fun r -> + let bind_and_fulfill () = + let f_res_fut = apply_f_to_res r in + (* forward result *) + on_result f_res_fut (fun r -> fulfill promise r) + in + + match on with + | None -> bind_and_fulfill () + | Some on -> Pool.run on bind_and_fulfill); + + fut2 + + let peek_ok_assert_ (self : 'a t) : 'a = + match A.get self.st with + | Done (Ok x) -> x + | _ -> assert false + + let join_container_ ~iter ~map ~len cont : _ t = + let fut, promise = make () in + let missing = A.make (len cont) in + + (* callback called when a future in [a] is resolved *) + let on_res = function + | Ok _ -> + let n = A.fetch_and_add missing (-1) in + if n = 1 then ( + (* last future, we know they all succeeded, so resolve [fut] *) + let res = map peek_ok_assert_ cont in + fulfill promise (Ok res) + ) + | Error e_bt -> + (* immediately cancel all other [on_res] *) + let n = A.exchange missing 0 in + if n > 0 then + (* we're the only one to set to 0, so we can fulfill [fut] + with an error. *) + fulfill promise (Error e_bt) + in + + iter (fun fut -> on_result fut on_res) cont; + fut + + let join_array (a : _ t array) : _ array t = + match Array.length a with + | 0 -> return [||] + | 1 -> map ?on:None a.(1) ~f:(fun x -> [| x |]) + | _ -> join_container_ ~len:Array.length ~map:Array.map ~iter:Array.iter a + + let join_list (l : _ t list) : _ list t = + match l with + | [] -> return [] + | [ x ] -> map ?on:None x ~f:(fun x -> [ x ]) + | _ -> join_container_ ~len:List.length ~map:List.map ~iter:List.iter l + + let for_ ~on n f : unit t = + let futs = Array.init n (fun i -> spawn ~on (fun () -> f i)) in + join_container_ + ~len:(fun () -> n) + ~iter:(fun f () -> Array.iter f futs) + ~map:(fun _f () -> ()) + () + + (* ### blocking ### *) + + let wait_block (self : 'a t) : 'a or_error = + match peek self with + | Some x -> + (* fast path *) + x + | None -> + (* use queue only once *) + let q = S_queue.create () in + on_result self (fun r -> S_queue.push q r); + S_queue.pop q + + let wait_block_exn self = + match wait_block self with + | Ok x -> x + | Error (e, bt) -> Printexc.raise_with_backtrace e bt end diff --git a/src/moonpool.mli b/src/moonpool.mli index 273eef63..96a0f2c8 100644 --- a/src/moonpool.mli +++ b/src/moonpool.mli @@ -61,4 +61,57 @@ module Fut : sig (** Already settled future, with a failure *) val of_result : 'a or_error -> 'a t + + val is_resolved : _ t -> bool + (** [is_resolved fut] is [true] iff [fut] is resolved. *) + + val peek : 'a t -> 'a or_error option + (** [peek fut] returns [Some r] if [fut] is currently resolved with [r], + and [None] if [fut] is not resolved yet. *) + + (** {2 Combinators} *) + + val spawn : on:Pool.t -> (unit -> 'a) -> 'a t + (** [spaw ~on f] runs [f()] on the given pool, and return a future that will + hold its result. *) + + val map : ?on:Pool.t -> f:('a -> 'b) -> 'a t -> 'b t + (** [map ?on ~f fut] returns a new future [fut2] that resolves + with [f x] if [fut] resolved with [x]; + and fails with [e] if [fut] fails with [e] or [f x] raises [e]. + @param on if provided, [f] runs on the given pool *) + + val bind : ?on:Pool.t -> f:('a -> 'b t) -> 'a t -> 'b t + (** [map ?on ~f fut] returns a new future [fut2] that resolves + like the future [f x] if [fut] resolved with [x]; + and fails with [e] if [fut] fails with [e] or [f x] raises [e]. + @param on if provided, [f] runs on the given pool *) + + val join_array : 'a t array -> 'a array t + (** Wait for all the futures in the array. Fails if any future fails. *) + + val join_list : 'a t list -> 'a list t + (** Wait for all the futures in the list. Fails if any future fails. *) + + val for_ : on:Pool.t -> int -> (int -> unit) -> unit t + (** [for_ ~on n f] runs [f 0], [f 1], …, [f (n-1)] on the pool, and returns + a future that resolves when all the tasks have resolved, or fails + as soon as one task has failed. *) + + (** {2 Blocking} *) + + val wait_block : 'a t -> 'a or_error + (** [wait_block fut] blocks the current thread until [fut] is resolved, + and returns its value. + + A word of warning: this can easily cause deadlocks. A good rule to avoid + deadlocks is to run this from outside of any pool, or to have an acyclic order + between pools where [wait_block] is only called from a pool on futures evaluated + in a pool that comes lower in the hierarchy. + If this rule is broken, it is possible for all threads in a pool to wait for + futures that can only make progress on these same threads, hence the deadlock. + *) + + val wait_block_exn : 'a t -> 'a + (** Same as {!wait_block} but re-raises the exception if the future failed. *) end