move fork join into its own module

This commit is contained in:
Simon Cruanes 2023-06-21 12:28:01 -04:00
parent 009855ce0d
commit 309424a58f
6 changed files with 105 additions and 101 deletions

88
src/fork_join.ml Normal file
View file

@ -0,0 +1,88 @@
module A = Atomic_
module State_ = struct
type 'a single_res =
| St_none
| St_some of 'a
| St_fail of exn * Printexc.raw_backtrace
type ('a, 'b) t = {
mutable suspension:
((unit, exn * Printexc.raw_backtrace) result -> unit) option;
(** suspended caller *)
left: 'a single_res;
right: 'b single_res;
}
let get_exn (self : _ t A.t) =
match A.get self with
| { left = St_fail (e, bt); _ } | { right = St_fail (e, bt); _ } ->
Printexc.raise_with_backtrace e bt
| { left = St_some x; right = St_some y; _ } -> x, y
| _ -> assert false
let check_if_state_complete_ (self : _ t) : unit =
match self.left, self.right, self.suspension with
| St_some _, St_some _, Some f -> f (Ok ())
| St_fail (e, bt), _, Some f | _, St_fail (e, bt), Some f ->
f (Error (e, bt))
| _ -> ()
let set_left_ (self : _ t A.t) (x : _ single_res) =
while
let old_st = A.get self in
let new_st = { old_st with left = x } in
if A.compare_and_set self old_st new_st then (
check_if_state_complete_ new_st;
false
) else
true
do
Domain_.relax ()
done
let set_right_ (self : _ t A.t) (y : _ single_res) =
while
let old_st = A.get self in
let new_st = { old_st with right = y } in
if A.compare_and_set self old_st new_st then (
check_if_state_complete_ new_st;
false
) else
true
do
Domain_.relax ()
done
end
let both f g : _ * _ =
let open State_ in
let st = A.make { suspension = None; left = St_none; right = St_none } in
let start_tasks ~run () : unit =
run (fun () ->
try
let res = f () in
set_left_ st (St_some res)
with e ->
let bt = Printexc.get_raw_backtrace () in
set_left_ st (St_fail (e, bt)));
run (fun () ->
try
let res = g () in
set_right_ st (St_some res)
with e ->
let bt = Printexc.get_raw_backtrace () in
set_right_ st (St_fail (e, bt)))
in
Suspend_.suspend
{
Suspend_types_.handle =
(fun ~run suspension ->
(* nothing else is started, no race condition possible *)
(A.get st).suspension <- Some suspension;
start_tasks ~run ());
};
get_exn st

15
src/fork_join.mli Normal file
View file

@ -0,0 +1,15 @@
(** Fork-join primitives.
@since 0.3 *)
val both : (unit -> 'a) -> (unit -> 'b) -> 'a * 'b
(** [both f g] runs [f()] and [g()], potentially in parallel,
and returns their result when both are done.
If any of [f()] and [g()] fails, then the whole computation fails.
This must be run from within the pool: for example, inside {!Pool.run}
or inside a {!Fut.spawn} computation.
This is because it relies on an effect handler to be installed.
@since 0.3
{b NOTE} this is only available on OCaml 5. *)

View file

@ -5,5 +5,6 @@ let start_thread_on_some_domain f x =
module Atomic = Atomic_
module Blocking_queue = Bb_queue
module Chan = Chan
module Fork_join = Fork_join
module Fut = Fut
module Pool = Pool

View file

@ -13,6 +13,7 @@ val start_thread_on_some_domain : ('a -> unit) -> 'a -> Thread.t
module Fut = Fut
module Chan = Chan
module Fork_join = Fork_join
(** A simple blocking queue.

View file

@ -227,90 +227,3 @@ let shutdown_ ~wait (self : t) : unit =
let shutdown_without_waiting (self : t) : unit = shutdown_ self ~wait:false
let shutdown (self : t) : unit = shutdown_ self ~wait:true
module Fork_join_ = struct
type 'a single_res =
| St_none
| St_some of 'a
| St_fail of exn * Printexc.raw_backtrace
type ('a, 'b) t = {
mutable suspension:
((unit, exn * Printexc.raw_backtrace) result -> unit) option;
(** suspended caller *)
left: 'a single_res;
right: 'b single_res;
}
let get_exn (self : _ t A.t) =
match A.get self with
| { left = St_fail (e, bt); _ } | { right = St_fail (e, bt); _ } ->
Printexc.raise_with_backtrace e bt
| { left = St_some x; right = St_some y; _ } -> x, y
| _ -> assert false
let check_if_state_complete_ (self : _ t) : unit =
match self.left, self.right, self.suspension with
| St_some _, St_some _, Some f -> f (Ok ())
| St_fail (e, bt), _, Some f | _, St_fail (e, bt), Some f ->
f (Error (e, bt))
| _ -> ()
let set_left_ (self : _ t A.t) (x : _ single_res) =
while
let old_st = A.get self in
let new_st = { old_st with left = x } in
if A.compare_and_set self old_st new_st then (
check_if_state_complete_ new_st;
false
) else
true
do
Domain_.relax ()
done
let set_right_ (self : _ t A.t) (y : _ single_res) =
while
let old_st = A.get self in
let new_st = { old_st with right = y } in
if A.compare_and_set self old_st new_st then (
check_if_state_complete_ new_st;
false
) else
true
do
Domain_.relax ()
done
end
let fork_join f g : _ * _ =
let open Fork_join_ in
let st = A.make { suspension = None; left = St_none; right = St_none } in
let start_tasks ~run () : unit =
run (fun () ->
try
let res = f () in
set_left_ st (St_some res)
with e ->
let bt = Printexc.get_raw_backtrace () in
set_left_ st (St_fail (e, bt)));
run (fun () ->
try
let res = g () in
set_right_ st (St_some res)
with e ->
let bt = Printexc.get_raw_backtrace () in
set_right_ st (St_fail (e, bt)))
in
Suspend_.suspend
{
Suspend_types_.handle =
(fun ~run suspension ->
(* nothing else is started, no race condition possible *)
(A.get st).suspension <- Some suspension;
start_tasks ~run ());
};
get_exn st

View file

@ -91,17 +91,3 @@ val run_wait_block : t -> (unit -> 'a) -> 'a
{b NOTE} be careful with deadlocks (see notes in {!Fut.wait_block}).
@since 0.3 *)
(** {2 Fork-join computations} *)
val fork_join : (unit -> 'a) -> (unit -> 'b) -> 'a * 'b
(** [fork_join f g] runs [f()] and [g()], potentially in parallel,
and returns their result when both are done.
If any of [f()] and [g()] fails, then the whole computation fails.
This must be run from within the pool, inside {!run}
(or inside a {!Fut.spawn} computation).
This is because it relies on an effect handler to be installed.
@since 0.3
{b NOTE} this is only available on OCaml 5. *)