breaking: move fork-join into sub-library moonpool.forkjoin

This commit is contained in:
Simon Cruanes 2024-02-02 20:31:27 -05:00
parent 0f1f39380f
commit 223f22a0d9
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
11 changed files with 373 additions and 37 deletions

View file

@ -29,6 +29,7 @@ module Ws_pool = Ws_pool
module Private = struct module Private = struct
module Ws_deque_ = Ws_deque_ module Ws_deque_ = Ws_deque_
module Suspend_ = Suspend_ module Suspend_ = Suspend_
module Domain_ = Domain_
let num_domains = Domain_pool_.n_domains let num_domains = Domain_pool_.n_domains
end end

View file

@ -204,7 +204,8 @@ module Private : sig
{b NOTE}: this is not stable for now. *) {b NOTE}: this is not stable for now. *)
module Domain = Domain_ module Domain_ = Domain_
(** Utils for domains *)
val num_domains : unit -> int val num_domains : unit -> int
(** Number of domains in the backing domain pool *) (** Number of domains in the backing domain pool *)

12
src/forkjoin/dune Normal file
View file

@ -0,0 +1,12 @@
; Fork-join primitives, published as the [moonpool.forkjoin] sub-library.
(library
(name moonpool_forkjoin)
(public_name moonpool.forkjoin)
(synopsis "Fork-join parallelism for moonpool")
; [-open Moonpool] makes the [Moonpool] module opened in every source file
; of this library.
(flags :standard -open Moonpool)
; Every source file goes through the project's own preprocessor executable
; before compilation (presumably what interprets the [@@@ifge 5.0] /
; [@@@endif] guards -- TODO confirm against src/cpp).
(preprocess
(action
(run %{project_root}/src/cpp/cpp.exe %{input-file})))
(libraries
moonpool moonpool.private))

View file

@ -0,0 +1,222 @@
[@@@ifge 5.0]
module A = Moonpool.Atomic
module Suspend_ = Moonpool.Private.Suspend_
module Domain_ = Moonpool_private.Domain_
(* Lock-free state shared by the two branches of a single [both] call.
   The state lives in an atomic cell; each branch publishes its result with
   a compare-and-set, and whichever branch finishes last moves the cell to
   [Both_solved]. *)
module State_ = struct
  (* an exception paired with the backtrace captured where it was caught *)
  type error = exn * Printexc.raw_backtrace
  type 'a or_error = ('a, error) result

  (* ['a] is the result type of the left branch, ['b] of the right one *)
  type ('a, 'b) t =
    | Init (* no branch has published a result yet *)
    | Left_solved of 'a or_error (* only the left branch is done *)
    | Right_solved of 'b or_error * Suspend_.suspension
      (* only the right branch is done; the stored suspension is resumed
         by [set_left_] once the left result is published *)
    | Both_solved of 'a or_error * 'b or_error

  (* Read the final results. Must only be called once the state is
     [Both_solved]. If a branch failed, its exception is re-raised with its
     backtrace; when both failed, the left one is the one raised. *)
  let get_exn_ (self : _ t A.t) =
    match A.get self with
    | Both_solved (Ok a, Ok b) -> a, b
    | Both_solved (Error (exn, bt), _) | Both_solved (_, Error (exn, bt)) ->
      Printexc.raise_with_backtrace exn bt
    | Init | Left_solved _ | Right_solved _ -> assert false

  (* Publish the result of the left branch. If the right branch already
     finished, it left a suspension behind, which is resumed here. *)
  let rec set_left_ (self : _ t A.t) (left : _ or_error) =
    let old_st = A.get self in
    match old_st with
    | Init ->
      let new_st = Left_solved left in
      if not (A.compare_and_set self old_st new_st) then (
        Domain_.relax ();
        set_left_ self left
      )
    | Right_solved (right, cont) ->
      let new_st = Both_solved (left, right) in
      if not (A.compare_and_set self old_st new_st) then (
        Domain_.relax ();
        set_left_ self left
      ) else
        cont (Ok ())
    | Left_solved _ | Both_solved _ -> assert false

  (* Publish the result of the right branch. If the left branch is not done
     yet, the caller is suspended through [Suspend_.suspend] and its
     suspension is stored in the state for [set_left_] to resume. *)
  let rec set_right_ (self : _ t A.t) (right : _ or_error) : unit =
    let old_st = A.get self in
    match old_st with
    | Left_solved left ->
      let new_st = Both_solved (left, right) in
      if not (A.compare_and_set self old_st new_st) then (
        (* back off before retrying, like [set_left_] does *)
        Domain_.relax ();
        set_right_ self right
      )
    | Init ->
      (* we are first arrived, we suspend until the left computation is done *)
      Suspend_.suspend
        {
          Suspend_.handle =
            (fun ~name:_ ~run:_ suspension ->
              (* the loop condition is [true] only when the compare-and-set
                 lost a race, in which case the state is read again *)
              while
                let old_st = A.get self in
                match old_st with
                | Init ->
                  not
                    (A.compare_and_set self old_st
                       (Right_solved (right, suspension)))
                | Left_solved left ->
                  (* other thread is done, no risk of race condition *)
                  A.set self (Both_solved (left, right));
                  suspension (Ok ());
                  false
                | Right_solved _ | Both_solved _ -> assert false
              do
                Domain_.relax ()
              done);
        }
    | Right_solved _ | Both_solved _ -> assert false
end
(* [both f g] schedules [f ()] on the current runner, evaluates [g ()] in the
   calling task, and returns both results once the two are done. If either
   function raised, the exception is re-raised with its backtrace (the one
   from [f] takes precedence when both failed).
   Raises [Invalid_argument] when there is no current runner. *)
let both f g : _ * _ =
  let module ST = State_ in
  let st = A.make ST.Init in
  let runner =
    match Runner.get_current_runner () with
    | None -> invalid_arg "Fork_join.both must be run from within a runner"
    | Some r -> r
  in
  (* start computing [f] in the background *)
  Runner.run_async runner (fun () ->
      (* Only [f ()] is guarded by the handler. The result is then published
         exactly once: calling [set_left_] a second time from an exception
         handler would trip its [assert false] and hide the real failure. *)
      let res_left =
        try Ok (f ())
        with exn ->
          let bt = Printexc.get_raw_backtrace () in
          Error (exn, bt)
      in
      ST.set_left_ st res_left);
  let res_right =
    try Ok (g ())
    with exn ->
      let bt = Printexc.get_raw_backtrace () in
      Error (exn, bt)
  in
  (* publish our own result; this waits for the left one if necessary *)
  ST.set_right_ st res_right;
  ST.get_exn_ st
(* Same as [both], but the pair of results is dropped. *)
let both_ignore f g =
  let (_ : _ * _) = both f g in
  ()
(* Parallel loop over [0 .. n-1]. The range is cut into consecutive chunks
   and [f low high] (bounds inclusive) is run as one task per chunk; the
   caller is suspended and resumed when every chunk is done, or as soon as
   one chunk has raised. Nothing happens when [n <= 0]. *)
let for_ ?chunk_size n (f : int -> int -> unit) : unit =
  if n > 0 then (
    (* set by the first chunk that raises *)
    let has_failed = A.make false in
    (* number of indices not yet processed successfully *)
    let missing = A.make n in
    let chunk_size =
      match chunk_size with
      | None ->
        (* guess: try to have roughly one task per core *)
        max 1 (1 + (n / Moonpool.Private.num_domains ()))
      | Some cs -> max 1 (min n cs)
    in
    let start_tasks ~name ~run (suspension : Suspend_.suspension) =
      (* a chunk of [len_range] indices completed without raising *)
      let chunk_done len_range =
        if A.fetch_and_add missing (-len_range) = len_range then
          (* all tasks done successfully *)
          run ~name (fun () -> suspension (Ok ()))
      in
      (* a chunk raised [exn]; only the first failure wakes the caller.
         [missing] is not decreased here, so it can never reach zero
         afterwards and the success path above stays disabled. *)
      let chunk_failed exn bt =
        if not (A.exchange has_failed true) then
          run ~name (fun () -> suspension (Error (exn, bt)))
      in
      let task_for ~offset ~len_range =
        match f offset (offset + len_range - 1) with
        | () -> chunk_done len_range
        | exception exn ->
          let bt = Printexc.get_raw_backtrace () in
          chunk_failed exn bt
      in
      (* schedule one task per chunk, walking the range left to right *)
      let rec spawn_from offset =
        if offset < n then (
          let len_range = min chunk_size (n - offset) in
          assert (offset + len_range <= n);
          run ~name (fun () -> task_for ~offset ~len_range);
          spawn_from (offset + len_range)
        )
      in
      spawn_from 0
    in
    Suspend_.suspend
      {
        Suspend_.handle =
          (fun ~name ~run suspension ->
            (* run tasks, then we'll resume [suspension] *)
            start_tasks ~run ~name suspension);
      }
  )
(* Run every thunk of [fs] through [for_] and collect the results in an
   array, at the same indices as the thunks. *)
let all_array ?chunk_size (fs : _ array) : _ array =
  let n = Array.length fs in
  (* one slot per thunk, filled in by the tasks *)
  let slots = Array.make n None in
  let run_range lo hi =
    for idx = lo to hi do
      let res = fs.(idx) () in
      slots.(idx) <- Some res
    done
  in
  for_ ?chunk_size n run_range;
  (* [for_] has returned, so every slot has been filled *)
  let unwrap slot =
    match slot with
    | Some res -> res
    | None -> assert false
  in
  Array.map unwrap slots
(* List version of [all_array]: goes through an array and back. *)
let all_list ?chunk_size fs : _ list =
  let tasks = Array.of_list fs in
  let results = all_array ?chunk_size tasks in
  Array.to_list results
(* Compute [f 0], ..., [f (n-1)] through [for_] and return them as a list,
   in index order. *)
let all_init ?chunk_size n f : _ list =
  (* one slot per index, filled in by the tasks *)
  let slots = Array.make n None in
  let fill_range lo hi =
    for idx = lo to hi do
      let res = f idx in
      slots.(idx) <- Some res
    done
  in
  for_ ?chunk_size n fill_range;
  (* [for_] has returned, so every slot has been filled *)
  let result_at idx =
    match slots.(idx) with
    | Some res -> res
    | None -> assert false
  in
  List.init n result_at
(* Apply [f] to every element of [arr] through [for_]; the result array has
   the image of [arr.(i)] at index [i]. *)
let map_array ?chunk_size f arr : _ array =
  let len = Array.length arr in
  (* one slot per input element, filled in by the tasks *)
  let slots = Array.make len None in
  let map_range lo hi =
    for idx = lo to hi do
      slots.(idx) <- Some (f arr.(idx))
    done
  in
  for_ ?chunk_size len map_range;
  (* [for_] has returned, so every slot has been filled *)
  let unwrap slot =
    match slot with
    | Some y -> y
    | None -> assert false
  in
  Array.map unwrap slots
(* Apply [f] to every element of [l] through [for_], preserving order.
   The list is copied into an array so tasks can index into it. *)
let map_list ?chunk_size f (l : _ list) : _ list =
  let inputs = Array.of_list l in
  let len = Array.length inputs in
  (* one slot per input element, filled in by the tasks *)
  let slots = Array.make len None in
  let map_range lo hi =
    for idx = lo to hi do
      slots.(idx) <- Some (f inputs.(idx))
    done
  in
  for_ ?chunk_size len map_range;
  (* [for_] has returned, so every slot has been filled *)
  let result_at idx =
    match slots.(idx) with
    | Some y -> y
    | None -> assert false
  in
  List.init len result_at
[@@@endif]

View file

@ -0,0 +1,109 @@
(** Fork-join primitives.
{b NOTE} These are only available on OCaml 5.0 and above.
@since 0.3 *)
[@@@ifge 5.0]
val both : (unit -> 'a) -> (unit -> 'b) -> 'a * 'b
(** [both f g] runs [f()] and [g()], potentially in parallel,
    and returns their result when both are done.
    If any of [f()] and [g()] fails, then the whole computation fails:
    once both have terminated, the exception is re-raised along with the
    backtrace that was captured for it. If both fail, the exception
    raised by [f()] is the one that is re-raised.

    This must be run from within the pool: for example, inside {!Pool.run}
    or inside a {!Fut.spawn} computation.
    This is because it relies on an effect handler to be installed.

    @raise Invalid_argument if it is called when there is no current runner.
    @since 0.3
    {b NOTE} this is only available on OCaml 5. *)
val both_ignore : (unit -> _) -> (unit -> _) -> unit
(** Same as [both f g |> ignore]: both results are discarded, but failures
    of [f()] or [g()] are still propagated as in {!both}.
    @since 0.3
    {b NOTE} this is only available on OCaml 5. *)
val for_ : ?chunk_size:int -> int -> (int -> int -> unit) -> unit
(** [for_ n f] is the parallel version of [for i=0 to n-1 do f i done].
    If [n <= 0], nothing is done and [f] is never called.

    [f] is called with parameters [low] and [high] (both bounds are
    inclusive) and must use them like so:
    {[ for j = low to high do (* … actual work *) done ]}.
    If [chunk_size=1] then [low=high] and the loop is not actually needed.

    @param chunk_size controls the granularity of parallelism.
    A value that is provided is clamped to the range [1 .. n].
    The default chunk size is not specified.
    See {!all_array} or {!all_list} for more details.

    Example:
    {[
      let total_sum = Atomic.make 0

      let() = for_ ~chunk_size:5 100
        (fun low high ->
          (* iterate on the range sequentially. The range should have 5 items or less. *)
          let local_sum = ref 0 in
          for j=low to high do
            local_sum := !local_sum + j
          done;
          ignore (Atomic.fetch_and_add total_sum !local_sum : int))

      let() = assert (Atomic.get total_sum = 4950)
    ]}

    Note how we still compute a local sum sequentially in [(fun low high -> )],
    before combining it wholesale into [total_sum]. When the chunk size is large,
    this can have a dramatic impact on the synchronization overhead.

    When [chunk_size] is not provided, the library will attempt to guess a value
    that keeps all cores busy but runs as few tasks as possible to reduce
    the synchronization overhead.

    Use [~chunk_size:1] if you explicitly want to
    run each iteration of the loop in its own task.

    @since 0.3
    {b NOTE} this is only available on OCaml 5. *)
val all_array : ?chunk_size:int -> (unit -> 'a) array -> 'a array
(** [all_array fs] runs all functions in [fs] in tasks, and waits for
    all the results. The result of [fs.(i) ()] is stored at index [i]
    of the returned array.

    @param chunk_size if equal to [n], groups items by [n] to be run in
    a single task. Default is not specified: when omitted, the value is
    chosen by {!for_}.

    @since 0.3
    {b NOTE} this is only available on OCaml 5. *)
val all_list : ?chunk_size:int -> (unit -> 'a) list -> 'a list
(** [all_list fs] runs all functions in [fs] in tasks, and waits for
    all the results. Results are returned in the same order as the
    functions appear in [fs].

    @param chunk_size if equal to [n], groups items by [n] to be run in
    a single task. Default is not specified.
    This parameter is available since 0.3.

    @since 0.3
    {b NOTE} this is only available on OCaml 5. *)
val all_init : ?chunk_size:int -> int -> (int -> 'a) -> 'a list
(** [all_init n f] runs functions [f 0], [f 1], … [f (n-1)] in tasks, and waits for
    all the results. The returned list has length [n] and contains the
    result of [f i] at position [i].

    @param chunk_size if equal to [n], groups items by [n] to be run in
    a single task. Default is not specified.
    This parameter is available since 0.3.

    @raise Invalid_argument if [n] is negative.
    @since 0.3
    {b NOTE} this is only available on OCaml 5. *)
val map_array : ?chunk_size:int -> ('a -> 'b) -> 'a array -> 'b array
(** [map_array f arr] is like [Array.map f arr], but runs in parallel.

    @param chunk_size forwarded to {!for_}, which see for its meaning
    and default.

    @since 0.3
    {b NOTE} this is only available on OCaml 5. *)
val map_list : ?chunk_size:int -> ('a -> 'b) -> 'a list -> 'b list
(** [map_list f l] is like [List.map f l], but runs in parallel.
    The order of the elements is preserved.

    @param chunk_size forwarded to {!for_}, which see for its meaning
    and default.

    @since 0.3
    {b NOTE} this is only available on OCaml 5. *)
[@@@endif]

View file

@ -15,6 +15,7 @@
(>= %{ocaml_version} 5.0)) (>= %{ocaml_version} 5.0))
(libraries (libraries
moonpool moonpool
moonpool.forkjoin
trace trace
trace-tef trace-tef
qcheck-core qcheck-core

View file

@ -1,6 +1,7 @@
[@@@ifge 5.0] [@@@ifge 5.0]
open Moonpool open Moonpool
module FJ = Moonpool_forkjoin
let rec fib_direct x = let rec fib_direct x =
if x <= 1 then if x <= 1 then
@ -14,7 +15,7 @@ let fib ~on x : int Fut.t =
fib_direct x fib_direct x
else ( else (
let n1, n2 = let n1, n2 =
Fork_join.both (fun () -> fib_rec (x - 1)) (fun () -> fib_rec (x - 2)) FJ.both (fun () -> fib_rec (x - 1)) (fun () -> fib_rec (x - 2))
in in
n1 + n2 n1 + n2
) )

View file

@ -3,6 +3,7 @@
let ( let@ ) = ( @@ ) let ( let@ ) = ( @@ )
open Moonpool open Moonpool
module FJ = Moonpool_forkjoin
let rec fib_direct x = let rec fib_direct x =
if x <= 1 then if x <= 1 then
@ -15,9 +16,7 @@ let rec fib x : int =
if x <= 18 then if x <= 18 then
fib_direct x fib_direct x
else ( else (
let n1, n2 = let n1, n2 = FJ.both (fun () -> fib (x - 1)) (fun () -> fib (x - 2)) in
Fork_join.both (fun () -> fib (x - 1)) (fun () -> fib (x - 2))
in
n1 + n2 n1 + n2
) )
@ -32,7 +31,7 @@ let run_test () =
let fut = let fut =
Fut.spawn ~on:pool (fun () -> Fut.spawn ~on:pool (fun () ->
let fibs = Fork_join.all_init 3 (fun _ -> fib 40) in let fibs = FJ.all_init 3 (fun _ -> fib 40) in
fibs) fibs)
in in

View file

@ -4,6 +4,7 @@ let spf = Printf.sprintf
let ( let@ ) = ( @@ ) let ( let@ ) = ( @@ )
open! Moonpool open! Moonpool
module FJ = Moonpool_forkjoin
let pool = Ws_pool.create ~num_threads:4 () let pool = Ws_pool.create ~num_threads:4 ()
@ -11,7 +12,7 @@ let () =
let x = let x =
Ws_pool.run_wait_block pool (fun () -> Ws_pool.run_wait_block pool (fun () ->
let x, y = let x, y =
Fork_join.both FJ.both
(fun () -> (fun () ->
Thread.delay 0.005; Thread.delay 0.005;
1) 1)
@ -26,7 +27,7 @@ let () =
let () = let () =
try try
Ws_pool.run_wait_block pool (fun () -> Ws_pool.run_wait_block pool (fun () ->
Fork_join.both_ignore FJ.both_ignore
(fun () -> Thread.delay 0.005) (fun () -> Thread.delay 0.005)
(fun () -> (fun () ->
Thread.delay 0.02; Thread.delay 0.02;
@ -37,21 +38,20 @@ let () =
let () = let () =
let par_sum = let par_sum =
Ws_pool.run_wait_block pool (fun () -> Ws_pool.run_wait_block pool (fun () ->
Fork_join.all_init 42 (fun i -> i * i) |> List.fold_left ( + ) 0) FJ.all_init 42 (fun i -> i * i) |> List.fold_left ( + ) 0)
in in
let exp_sum = List.init 42 (fun x -> x * x) |> List.fold_left ( + ) 0 in let exp_sum = List.init 42 (fun x -> x * x) |> List.fold_left ( + ) 0 in
assert (par_sum = exp_sum) assert (par_sum = exp_sum)
let () = let () =
Ws_pool.run_wait_block pool (fun () -> Ws_pool.run_wait_block pool (fun () -> FJ.for_ 0 (fun _ _ -> assert false));
Fork_join.for_ 0 (fun _ _ -> assert false));
() ()
let () = let () =
let total_sum = Atomic.make 0 in let total_sum = Atomic.make 0 in
Ws_pool.run_wait_block pool (fun () -> Ws_pool.run_wait_block pool (fun () ->
Fork_join.for_ ~chunk_size:5 100 (fun low high -> FJ.for_ ~chunk_size:5 100 (fun low high ->
(* iterate on the range sequentially. The range should have 5 items or less. *) (* iterate on the range sequentially. The range should have 5 items or less. *)
let local_sum = ref 0 in let local_sum = ref 0 in
for i = low to high do for i = low to high do
@ -64,7 +64,7 @@ let () =
let total_sum = Atomic.make 0 in let total_sum = Atomic.make 0 in
Ws_pool.run_wait_block pool (fun () -> Ws_pool.run_wait_block pool (fun () ->
Fork_join.for_ ~chunk_size:1 100 (fun low high -> FJ.for_ ~chunk_size:1 100 (fun low high ->
assert (low = high); assert (low = high);
ignore (Atomic.fetch_and_add total_sum low : int))); ignore (Atomic.fetch_and_add total_sum low : int)));
assert (Atomic.get total_sum = 4950) assert (Atomic.get total_sum = 4950)
@ -82,7 +82,7 @@ let rec fib_fork_join n =
fib_direct n fib_direct n
else ( else (
let a, b = let a, b =
Fork_join.both FJ.both
(fun () -> fib_fork_join (n - 1)) (fun () -> fib_fork_join (n - 1))
(fun () -> fib_fork_join (n - 2)) (fun () -> fib_fork_join (n - 2))
in in
@ -254,13 +254,13 @@ module Evaluator = struct
| Ret x -> x | Ret x -> x
| Comp_fib n -> fib_fork_join n | Comp_fib n -> fib_fork_join n
| Add (a, b) -> | Add (a, b) ->
let a, b = Fork_join.both (fun () -> eval a) (fun () -> eval b) in let a, b = FJ.both (fun () -> eval a) (fun () -> eval b) in
a + b a + b
| Pipe (a, f) -> eval a |> apply_fun_seq f | Pipe (a, f) -> eval a |> apply_fun_seq f
| Map_arr (chunk_size, f, a, r) -> | Map_arr (chunk_size, f, a, r) ->
let tasks = List.map (fun x () -> eval x) a in let tasks = List.map (fun x () -> eval x) a in
Fork_join.all_list ~chunk_size tasks FJ.all_list ~chunk_size tasks
|> Fork_join.map_list ~chunk_size (apply_fun_seq f) |> FJ.map_list ~chunk_size (apply_fun_seq f)
|> eval_reducer r |> eval_reducer r
in in
@ -290,12 +290,8 @@ let t_for_nested ~min ~chunk_size () =
let l1, l2 = let l1, l2 =
let@ pool = Ws_pool.with_ ~num_threads:min () in let@ pool = Ws_pool.with_ ~num_threads:min () in
let@ () = Ws_pool.run_wait_block pool in let@ () = Ws_pool.run_wait_block pool in
let l1 = let l1 = FJ.map_list ~chunk_size (FJ.map_list ~chunk_size neg) l in
Fork_join.map_list ~chunk_size (Fork_join.map_list ~chunk_size neg) l let l2 = FJ.map_list ~chunk_size (FJ.map_list ~chunk_size neg) l1 in
in
let l2 =
Fork_join.map_list ~chunk_size (Fork_join.map_list ~chunk_size neg) l1
in
l1, l2 l1, l2
in in
@ -313,12 +309,8 @@ let t_map ~chunk_size () =
let@ pool = Ws_pool.with_ ~num_threads:4 () in let@ pool = Ws_pool.with_ ~num_threads:4 () in
let@ () = Ws_pool.run_wait_block pool in let@ () = Ws_pool.run_wait_block pool in
let a1 = let a1 = FJ.map_list ~chunk_size string_of_int l |> Array.of_list in
Fork_join.map_list ~chunk_size string_of_int l |> Array.of_list let a2 = FJ.map_array ~chunk_size string_of_int @@ Array.of_list l in
in
let a2 =
Fork_join.map_array ~chunk_size string_of_int @@ Array.of_list l
in
if a1 <> a2 then Q.Test.fail_reportf "a1=%s, a2=%s" (ppa a1) (ppa a2); if a1 <> a2 then Q.Test.fail_reportf "a1=%s, a2=%s" (ppa a1) (ppa a2);
true) true)

View file

@ -7,6 +7,7 @@ let ( let@ ) = ( @@ )
let ppl = Q.Print.(list @@ list int) let ppl = Q.Print.(list @@ list int)
open! Moonpool open! Moonpool
module FJ = Moonpool_forkjoin
let run ~min () = let run ~min () =
let@ _sp = let@ _sp =
@ -31,17 +32,13 @@ let run ~min () =
let@ () = Ws_pool.run_wait_block pool in let@ () = Ws_pool.run_wait_block pool in
let l1, l2 = let l1, l2 =
Fork_join.both FJ.both
(fun () -> (fun () ->
let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "fj.left" in let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "fj.left" in
Fork_join.map_list ~chunk_size FJ.map_list ~chunk_size (FJ.map_list ~chunk_size neg) l)
(Fork_join.map_list ~chunk_size neg)
l)
(fun () -> (fun () ->
let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "fj.right" in let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "fj.right" in
Fork_join.map_list ~chunk_size FJ.map_list ~chunk_size (FJ.map_list ~chunk_size neg) ref_l1)
(Fork_join.map_list ~chunk_size neg)
ref_l1)
in in
l1, l2 l1, l2
in in

View file

@ -1,6 +1,7 @@
[@@@ifge 5.0] [@@@ifge 5.0]
open Moonpool open Moonpool
module FJ = Moonpool_forkjoin
let rec select_sort arr i len = let rec select_sort arr i len =
if len >= 2 then ( if len >= 2 then (
@ -54,7 +55,7 @@ let rec quicksort arr i len : unit =
) )
done; done;
Fork_join.both_ignore FJ.both_ignore
(fun () -> quicksort arr i (!low - i)) (fun () -> quicksort arr i (!low - i))
(fun () -> quicksort arr !low (len - (!low - i))) (fun () -> quicksort arr !low (len - (!low - i)))
) )