This commit is contained in:
Simon Cruanes 2024-02-02 20:25:11 -05:00
parent b0fe279f42
commit 37c42b68bc
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
8 changed files with 12 additions and 335 deletions

View file

@ -3,4 +3,4 @@
(preprocess
(action
(run %{project_root}/src/cpp/cpp.exe %{input-file})))
(libraries moonpool unix trace trace-tef domainslib))
(libraries moonpool moonpool.forkjoin unix trace trace-tef domainslib))

View file

@ -1,5 +1,6 @@
open Moonpool
module Trace = Trace_core
module FJ = Moonpool_forkjoin
let ( let@ ) = ( @@ )
@ -25,7 +26,7 @@ let fib_fj ~on x : int Fut.t =
fib_direct x
else (
let n1, n2 =
Fork_join.both (fun () -> fib_rec (x - 1)) (fun () -> fib_rec (x - 2))
FJ.both (fun () -> fib_rec (x - 1)) (fun () -> fib_rec (x - 2))
in
n1 + n2
)

View file

@ -1,6 +1,7 @@
(* compute Pi *)
open Moonpool
module FJ = Moonpool_forkjoin
let ( let@ ) = ( @@ )
let j = ref 0
@ -76,7 +77,7 @@ let run_fork_join ~kind num_steps : float =
let global_sum = Lock.create 0. in
Ws_pool.run_wait_block ~name:"pi.fj" pool (fun () ->
Fork_join.for_
FJ.for_
~chunk_size:(3 + (num_steps / num_tasks))
num_steps
(fun low high ->

View file

@ -1,220 +0,0 @@
[@@@ifge 5.0]
module A = Atomic_
module State_ = struct
type error = exn * Printexc.raw_backtrace
type 'a or_error = ('a, error) result
type ('a, 'b) t =
| Init
| Left_solved of 'a or_error
| Right_solved of 'b or_error * Suspend_.suspension
| Both_solved of 'a or_error * 'b or_error
let get_exn_ (self : _ t A.t) =
match A.get self with
| Both_solved (Ok a, Ok b) -> a, b
| Both_solved (Error (exn, bt), _) | Both_solved (_, Error (exn, bt)) ->
Printexc.raise_with_backtrace exn bt
| _ -> assert false
let rec set_left_ (self : _ t A.t) (left : _ or_error) =
let old_st = A.get self in
match old_st with
| Init ->
let new_st = Left_solved left in
if not (A.compare_and_set self old_st new_st) then (
Domain_.relax ();
set_left_ self left
)
| Right_solved (right, cont) ->
let new_st = Both_solved (left, right) in
if not (A.compare_and_set self old_st new_st) then (
Domain_.relax ();
set_left_ self left
) else
cont (Ok ())
| Left_solved _ | Both_solved _ -> assert false
let rec set_right_ (self : _ t A.t) (right : _ or_error) : unit =
let old_st = A.get self in
match old_st with
| Left_solved left ->
let new_st = Both_solved (left, right) in
if not (A.compare_and_set self old_st new_st) then set_right_ self right
| Init ->
(* we are first arrived, we suspend until the left computation is done *)
Suspend_.suspend
{
Suspend_.handle =
(fun ~name:_ ~run:_ suspension ->
while
let old_st = A.get self in
match old_st with
| Init ->
not
(A.compare_and_set self old_st
(Right_solved (right, suspension)))
| Left_solved left ->
(* other thread is done, no risk of race condition *)
A.set self (Both_solved (left, right));
suspension (Ok ());
false
| Right_solved _ | Both_solved _ -> assert false
do
()
done);
}
| Right_solved _ | Both_solved _ -> assert false
end
let both f g : _ * _ =
let module ST = State_ in
let st = A.make ST.Init in
let runner =
match Runner.get_current_runner () with
| None -> invalid_arg "Fork_join.both must be run from within a runner"
| Some r -> r
in
(* start computing [f] in the background *)
Runner.run_async runner (fun () ->
try
let res = f () in
ST.set_left_ st (Ok res)
with exn ->
let bt = Printexc.get_raw_backtrace () in
ST.set_left_ st (Error (exn, bt)));
let res_right =
try Ok (g ())
with exn ->
let bt = Printexc.get_raw_backtrace () in
Error (exn, bt)
in
ST.set_right_ st res_right;
ST.get_exn_ st
let both_ignore f g = ignore (both f g : _ * _)
let for_ ?chunk_size n (f : int -> int -> unit) : unit =
if n > 0 then (
let has_failed = A.make false in
let missing = A.make n in
let chunk_size =
match chunk_size with
| Some cs -> max 1 (min n cs)
| None ->
(* guess: try to have roughly one task per core *)
max 1 (1 + (n / Domain_pool_.n_domains ()))
in
let start_tasks ~name ~run (suspension : Suspend_.suspension) =
let task_for ~offset ~len_range =
match f offset (offset + len_range - 1) with
| () ->
if A.fetch_and_add missing (-len_range) = len_range then
(* all tasks done successfully *)
run ~name (fun () -> suspension (Ok ()))
| exception exn ->
let bt = Printexc.get_raw_backtrace () in
if not (A.exchange has_failed true) then
(* first one to fail, and [missing] must be >= 2
because we're not decreasing it. *)
run ~name (fun () -> suspension (Error (exn, bt)))
in
let i = ref 0 in
while !i < n do
let offset = !i in
let len_range = min chunk_size (n - offset) in
assert (offset + len_range <= n);
run ~name (fun () -> task_for ~offset ~len_range);
i := !i + len_range
done
in
Suspend_.suspend
{
Suspend_.handle =
(fun ~name ~run suspension ->
(* run tasks, then we'll resume [suspension] *)
start_tasks ~run ~name suspension);
}
)
let all_array ?chunk_size (fs : _ array) : _ array =
let len = Array.length fs in
let arr = Array.make len None in
(* parallel for *)
for_ ?chunk_size len (fun low high ->
for i = low to high do
let x = fs.(i) () in
arr.(i) <- Some x
done);
(* get all results *)
Array.map
(function
| None -> assert false
| Some x -> x)
arr
let all_list ?chunk_size fs : _ list =
Array.to_list @@ all_array ?chunk_size @@ Array.of_list fs
let all_init ?chunk_size n f : _ list =
let arr = Array.make n None in
for_ ?chunk_size n (fun low high ->
for i = low to high do
let x = f i in
arr.(i) <- Some x
done);
(* get all results *)
List.init n (fun i ->
match arr.(i) with
| None -> assert false
| Some x -> x)
let map_array ?chunk_size f arr : _ array =
let n = Array.length arr in
let res = Array.make n None in
for_ ?chunk_size n (fun low high ->
for i = low to high do
res.(i) <- Some (f arr.(i))
done);
(* get all results *)
Array.map
(function
| None -> assert false
| Some x -> x)
res
let map_list ?chunk_size f (l : _ list) : _ list =
let arr = Array.of_list l in
let n = Array.length arr in
let res = Array.make n None in
for_ ?chunk_size n (fun low high ->
for i = low to high do
res.(i) <- Some (f arr.(i))
done);
(* get all results *)
List.init n (fun i ->
match res.(i) with
| None -> assert false
| Some x -> x)
[@@@endif]

View file

@ -1,109 +0,0 @@
(** Fork-join primitives.
{b NOTE} These are only available on OCaml 5.0 and above.
@since 0.3 *)
[@@@ifge 5.0]
val both : (unit -> 'a) -> (unit -> 'b) -> 'a * 'b
(** [both f g] runs [f()] and [g()], potentially in parallel,
and returns their result when both are done.
If any of [f()] and [g()] fails, then the whole computation fails.
This must be run from within the pool: for example, inside {!Pool.run}
or inside a {!Fut.spawn} computation.
This is because it relies on an effect handler to be installed.
@since 0.3
{b NOTE} this is only available on OCaml 5. *)
val both_ignore : (unit -> _) -> (unit -> _) -> unit
(** Same as [both f g |> ignore].
@since 0.3
{b NOTE} this is only available on OCaml 5. *)
val for_ : ?chunk_size:int -> int -> (int -> int -> unit) -> unit
(** [for_ n f] is the parallel version of [for i=0 to n-1 do f i done].
[f] is called with parameters [low] and [high] and must use them like so:
{[ for j = low to high do (* … actual work *) done ]}.
If [chunk_size=1] then [low=high] and the loop is not actually needed.
@param chunk_size controls the granularity of parallelism.
The default chunk size is not specified.
See {!all_array} or {!all_list} for more details.
Example:
{[
let total_sum = Atomic.make 0
let() = for_ ~chunk_size:5 100
(fun low high ->
(* iterate on the range sequentially. The range should have 5 items or less. *)
let local_sum = ref 0 in
for j=low to high do
local_sum := !local_sum + j
done;
ignore (Atomic.fetch_and_add total_sum !local_sum : int)))
let() = assert (Atomic.get total_sum = 4950)
]}
Note how we still compute a local sum sequentially in [(fun low high -> )],
before combining it wholesale into [total_sum]. When the chunk size is large,
this can have a dramatic impact on the synchronization overhead.
When [chunk_size] is not provided, the library will attempt to guess a value
that keeps all cores busy but runs as few tasks as possible to reduce
the synchronization overhead.
Use [~chunk_size:1] if you explicitly want to
run each iteration of the loop in its own task.
@since 0.3
{b NOTE} this is only available on OCaml 5. *)
val all_array : ?chunk_size:int -> (unit -> 'a) array -> 'a array
(** [all_array fs] runs all functions in [fs] in tasks, and waits for
all the results.
@param chunk_size if equal to [n], groups items by [n] to be run in
a single task. Default is [1].
@since 0.3
{b NOTE} this is only available on OCaml 5. *)
val all_list : ?chunk_size:int -> (unit -> 'a) list -> 'a list
(** [all_list fs] runs all functions in [fs] in tasks, and waits for
all the results.
@param chunk_size if equal to [n], groups items by [n] to be run in
a single task. Default is not specified.
This parameter is available since 0.3.
@since 0.3
{b NOTE} this is only available on OCaml 5. *)
val all_init : ?chunk_size:int -> int -> (int -> 'a) -> 'a list
(** [all_init n f] runs functions [f 0], [f 1], … [f (n-1)] in tasks, and waits for
all the results.
@param chunk_size if equal to [n], groups items by [n] to be run in
a single task. Default is not specified.
This parameter is available since 0.3.
@since 0.3
{b NOTE} this is only available on OCaml 5. *)
val map_array : ?chunk_size:int -> ('a -> 'b) -> 'a array -> 'b array
(** [map_array f arr] is like [Array.map f arr], but runs in parallel.
@since 0.3
{b NOTE} this is only available on OCaml 5. *)
val map_list : ?chunk_size:int -> ('a -> 'b) -> 'a list -> 'b list
(** [map_list f l] is like [List.map f l], but runs in parallel.
@since 0.3
{b NOTE} this is only available on OCaml 5. *)
[@@@endif]

View file

@ -18,7 +18,6 @@ module Blocking_queue = Bb_queue
module Bounded_queue = Bounded_queue
module Chan = Chan
module Fifo_pool = Fifo_pool
module Fork_join = Fork_join
module Fut = Fut
module Lock = Lock
module Immediate_runner = Immediate_runner
@ -30,4 +29,6 @@ module Ws_pool = Ws_pool
module Private = struct
module Ws_deque_ = Ws_deque_
module Suspend_ = Suspend_
let num_domains = Domain_pool_.n_domains
end

View file

@ -62,7 +62,6 @@ val await : 'a Fut.t -> 'a
module Lock = Lock
module Fut = Fut
module Chan = Chan
module Fork_join = Fork_join
module Thread_local_storage = Thread_local_storage_
(** A simple blocking queue.
@ -204,4 +203,9 @@ module Private : sig
This is only going to work on OCaml 5.x.
{b NOTE}: this is not stable for now. *)
module Domain = Domain_
val num_domains : unit -> int
(** Number of domains in the backing domain pool *)
end

View file

@ -37,6 +37,5 @@ let with_suspend ~name ~on_suspend ~(run : name:string -> task -> unit)
[@@@else_]
let[@inline] with_suspend ~name:_ ~on_suspend:_ ~run:_ f = f ()
let[@inline] prepare_for_await () = { Dla_.release = ignore; await = ignore }
[@@@endif]