add Fork_join.all_{list,init}

primitives to fork-join over n tasks
This commit is contained in:
Simon Cruanes 2023-06-23 23:18:52 -04:00
parent 45838d9607
commit 43487ebe49
4 changed files with 97 additions and 1 deletions

View file

@ -86,3 +86,45 @@ let both f g : _ * _ =
start_tasks ~run ());
};
get_exn st
let all_list fs : _ list =
let len = List.length fs in
let arr = Array.make len None in
let has_failed = A.make false in
let missing = A.make len in
let start_tasks ~run (suspension : Suspend_types_.suspension) =
let task_for i f =
try
let x = f () in
arr.(i) <- Some x;
if A.fetch_and_add missing (-1) = 1 then
(* all tasks done successfully *)
suspension (Ok ())
with exn ->
let bt = Printexc.get_raw_backtrace () in
if not (A.exchange has_failed true) then
(* first one to fail, and [missing] must be >= 2
because we're not decreasing it. *)
suspension (Error (exn, bt))
in
List.iteri (fun i f -> run ~with_handler:true (fun () -> task_for i f)) fs
in
Suspend_.suspend
{
Suspend_types_.handle =
(fun ~run suspension ->
(* nothing else is started, no race condition possible *)
start_tasks ~run suspension);
};
(* get all results *)
List.init len (fun i ->
match arr.(i) with
| None -> assert false
| Some x -> x)
let all_init n f = all_list @@ List.init n (fun i () -> f i)

View file

@ -13,3 +13,15 @@ val both : (unit -> 'a) -> (unit -> 'b) -> 'a * 'b
@since 0.3
{b NOTE} this is only available on OCaml 5. *)
val all_list : (unit -> 'a) list -> 'a list
(** [all_list fs] runs all functions in [fs] in tasks, and waits for
all the results.
@since 0.3
{b NOTE} this is only available on OCaml 5. *)
val all_init : int -> (int -> 'a) -> 'a list
(** [all_init n f] runs functions [f 0], [f 1], … [f (n-1)] in tasks, and waits for
all the results.
@since 0.3
{b NOTE} this is only available on OCaml 5. *)

View file

@ -1,6 +1,6 @@
(tests
(names t_fib1 t_futs1 t_many t_fork_join)
(names t_fib1 t_futs1 t_many t_fork_join t_fork_join_all)
(enabled_if (>= %{ocaml_version} 5.0))
(libraries moonpool trace ;tracy-client.trace
))

View file

@ -0,0 +1,42 @@
open Moonpool
let rec fib_direct x =
if x <= 1 then
1
else
fib_direct (x - 1) + fib_direct (x - 2)
let rec fib x : int =
if x <= 18 then
fib_direct x
else (
let n1, n2 =
Fork_join.both (fun () -> fib (x - 1)) (fun () -> fib (x - 2))
in
n1 + n2
)
let fib_40 : int =
let pool = Pool.create ~min:8 () in
Fut.spawn ~on:pool (fun () -> fib 40) |> Fut.wait_block_exn
let () = Printf.printf "fib 40 = %d\n%!" fib_40
let run_test () =
let pool = Pool.create ~min:8 () in
let fut =
Fut.spawn ~on:pool (fun () ->
let fibs = Fork_join.all_init 3 (fun _ -> fib 40) in
fibs)
in
let res = Fut.wait_block_exn fut in
Pool.shutdown pool;
assert (res = (Array.make 3 fib_40 |> Array.to_list))
let () =
(* now make sure we can do this with multiple pools in parallel *)
let jobs = Array.init 2 (fun _ -> Thread.create run_test ()) in
Array.iter Thread.join jobs