From 309424a58fa0677e384984020a55858f2225222d Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Wed, 21 Jun 2023 12:28:01 -0400 Subject: [PATCH] move fork join into its own module --- src/fork_join.ml | 88 +++++++++++++++++++++++++++++++++++++++++++++++ src/fork_join.mli | 15 ++++++++ src/moonpool.ml | 1 + src/moonpool.mli | 1 + src/pool.ml | 87 ---------------------------------------------- src/pool.mli | 14 -------- 6 files changed, 105 insertions(+), 101 deletions(-) create mode 100644 src/fork_join.ml create mode 100644 src/fork_join.mli diff --git a/src/fork_join.ml b/src/fork_join.ml new file mode 100644 index 00000000..8f859928 --- /dev/null +++ b/src/fork_join.ml @@ -0,0 +1,88 @@ +module A = Atomic_ + +module State_ = struct + type 'a single_res = + | St_none + | St_some of 'a + | St_fail of exn * Printexc.raw_backtrace + + type ('a, 'b) t = { + mutable suspension: + ((unit, exn * Printexc.raw_backtrace) result -> unit) option; + (** suspended caller *) + left: 'a single_res; + right: 'b single_res; + } + + let get_exn (self : _ t A.t) = + match A.get self with + | { left = St_fail (e, bt); _ } | { right = St_fail (e, bt); _ } -> + Printexc.raise_with_backtrace e bt + | { left = St_some x; right = St_some y; _ } -> x, y + | _ -> assert false + + let check_if_state_complete_ (self : _ t) : unit = + match self.left, self.right, self.suspension with + | St_some _, St_some _, Some f -> f (Ok ()) + | St_fail (e, bt), _, Some f | _, St_fail (e, bt), Some f -> + f (Error (e, bt)) + | _ -> () + + let set_left_ (self : _ t A.t) (x : _ single_res) = + while + let old_st = A.get self in + let new_st = { old_st with left = x } in + if A.compare_and_set self old_st new_st then ( + check_if_state_complete_ new_st; + false + ) else + true + do + Domain_.relax () + done + + let set_right_ (self : _ t A.t) (y : _ single_res) = + while + let old_st = A.get self in + let new_st = { old_st with right = y } in + if A.compare_and_set self old_st new_st then ( + check_if_state_complete_ new_st; + false + ) else + true + do + Domain_.relax () + done +end + +let both f g : _ * _ = + let open State_ in + let st = A.make { suspension = None; left = St_none; right = St_none } in + + let start_tasks ~run () : unit = + run (fun () -> + try + let res = f () in + set_left_ st (St_some res) + with e -> + let bt = Printexc.get_raw_backtrace () in + set_left_ st (St_fail (e, bt))); + + run (fun () -> + try + let res = g () in + set_right_ st (St_some res) + with e -> + let bt = Printexc.get_raw_backtrace () in + set_right_ st (St_fail (e, bt))) + in + + Suspend_.suspend + { + Suspend_types_.handle = + (fun ~run suspension -> + (* nothing else is started, no race condition possible *) + (A.get st).suspension <- Some suspension; + start_tasks ~run ()); + }; + get_exn st diff --git a/src/fork_join.mli b/src/fork_join.mli new file mode 100644 index 00000000..8cf36306 --- /dev/null +++ b/src/fork_join.mli @@ -0,0 +1,15 @@ +(** Fork-join primitives. + + @since 0.3 *) + +val both : (unit -> 'a) -> (unit -> 'b) -> 'a * 'b +(** [both f g] runs [f()] and [g()], potentially in parallel, + and returns their result when both are done. + If any of [f()] and [g()] fails, then the whole computation fails. + + This must be run from within the pool: for example, inside {!Pool.run} + or inside a {!Fut.spawn} computation. + This is because it relies on an effect handler to be installed. + + @since 0.3 + {b NOTE} this is only available on OCaml 5. *) diff --git a/src/moonpool.ml b/src/moonpool.ml index 37fabde7..99fe4521 100644 --- a/src/moonpool.ml +++ b/src/moonpool.ml @@ -5,5 +5,6 @@ let start_thread_on_some_domain f x = module Atomic = Atomic_ module Blocking_queue = Bb_queue module Chan = Chan +module Fork_join = Fork_join module Fut = Fut module Pool = Pool diff --git a/src/moonpool.mli b/src/moonpool.mli index e276b11e..1aab2165 100644 --- a/src/moonpool.mli +++ b/src/moonpool.mli @@ -13,6 +13,7 @@ val start_thread_on_some_domain : ('a -> unit) -> 'a -> Thread.t module Fut = Fut module Chan = Chan +module Fork_join = Fork_join (** A simple blocking queue. diff --git a/src/pool.ml b/src/pool.ml index 3518e33e..b9f511b1 100644 --- a/src/pool.ml +++ b/src/pool.ml @@ -227,90 +227,3 @@ let shutdown_ ~wait (self : t) : unit = let shutdown_without_waiting (self : t) : unit = shutdown_ self ~wait:false let shutdown (self : t) : unit = shutdown_ self ~wait:true - -module Fork_join_ = struct - type 'a single_res = - | St_none - | St_some of 'a - | St_fail of exn * Printexc.raw_backtrace - - type ('a, 'b) t = { - mutable suspension: - ((unit, exn * Printexc.raw_backtrace) result -> unit) option; - (** suspended caller *) - left: 'a single_res; - right: 'b single_res; - } - - let get_exn (self : _ t A.t) = - match A.get self with - | { left = St_fail (e, bt); _ } | { right = St_fail (e, bt); _ } -> - Printexc.raise_with_backtrace e bt - | { left = St_some x; right = St_some y; _ } -> x, y - | _ -> assert false - - let check_if_state_complete_ (self : _ t) : unit = - match self.left, self.right, self.suspension with - | St_some _, St_some _, Some f -> f (Ok ()) - | St_fail (e, bt), _, Some f | _, St_fail (e, bt), Some f -> - f (Error (e, bt)) - | _ -> () - - let set_left_ (self : _ t A.t) (x : _ single_res) = - while - let old_st = A.get self in - let new_st = { old_st with left = x } in - if A.compare_and_set self old_st new_st then ( - check_if_state_complete_ new_st; - false - ) else - true - do - Domain_.relax () - done - - let set_right_ (self : _ t A.t) (y : _ single_res) = - while - let old_st = A.get self in - let new_st = { old_st with right = y } in - if A.compare_and_set self old_st new_st then ( - check_if_state_complete_ new_st; - false - ) else - true - do - Domain_.relax () - done -end - -let fork_join f g : _ * _ = - let open Fork_join_ in - let st = A.make { suspension = None; left = St_none; right = St_none } in - - let start_tasks ~run () : unit = - run (fun () -> - try - let res = f () in - set_left_ st (St_some res) - with e -> - let bt = Printexc.get_raw_backtrace () in - set_left_ st (St_fail (e, bt))); - - run (fun () -> - try - let res = g () in - set_right_ st (St_some res) - with e -> - let bt = Printexc.get_raw_backtrace () in - set_right_ st (St_fail (e, bt))) - in - - Suspend_.suspend - { - Suspend_types_.handle = - (fun ~run suspension -> - (* nothing else is started, no race condition possible *) - (A.get st).suspension <- Some suspension; - start_tasks ~run ()); - }; - get_exn st diff --git a/src/pool.mli b/src/pool.mli index bde9aff6..ecf82288 100644 --- a/src/pool.mli +++ b/src/pool.mli @@ -91,17 +91,3 @@ val run_wait_block : t -> (unit -> 'a) -> 'a {b NOTE} be careful with deadlocks (see notes in {!Fut.wait_block}). @since 0.3 *) - -(** {2 Fork-join computations} *) - -val fork_join : (unit -> 'a) -> (unit -> 'b) -> 'a * 'b -(** [fork_join f g] runs [f()] and [g()], potentially in parallel, - and returns their result when both are done. - If any of [f()] and [g()] fails, then the whole computation fails. - - This must be run from within the pool, inside {!run} - (or inside a {!Fut.spawn} computation). - This is because it relies on an effect handler to be installed. - - @since 0.3 - {b NOTE} this is only available on OCaml 5. *)