mirror of
https://github.com/c-cube/moonpool.git
synced 2025-12-09 04:35:33 -05:00
add Fork_join.all_{list,init}
primitives to fork-join over n tasks
This commit is contained in:
parent
45838d9607
commit
43487ebe49
4 changed files with 97 additions and 1 deletions
|
|
@ -86,3 +86,45 @@ let both f g : _ * _ =
|
|||
start_tasks ~run ());
|
||||
};
|
||||
get_exn st
|
||||
|
||||
let all_list fs : _ list =
|
||||
let len = List.length fs in
|
||||
let arr = Array.make len None in
|
||||
let has_failed = A.make false in
|
||||
let missing = A.make len in
|
||||
|
||||
let start_tasks ~run (suspension : Suspend_types_.suspension) =
|
||||
let task_for i f =
|
||||
try
|
||||
let x = f () in
|
||||
arr.(i) <- Some x;
|
||||
|
||||
if A.fetch_and_add missing (-1) = 1 then
|
||||
(* all tasks done successfully *)
|
||||
suspension (Ok ())
|
||||
with exn ->
|
||||
let bt = Printexc.get_raw_backtrace () in
|
||||
if not (A.exchange has_failed true) then
|
||||
(* first one to fail, and [missing] must be >= 2
|
||||
because we're not decreasing it. *)
|
||||
suspension (Error (exn, bt))
|
||||
in
|
||||
|
||||
List.iteri (fun i f -> run ~with_handler:true (fun () -> task_for i f)) fs
|
||||
in
|
||||
|
||||
Suspend_.suspend
|
||||
{
|
||||
Suspend_types_.handle =
|
||||
(fun ~run suspension ->
|
||||
(* nothing else is started, no race condition possible *)
|
||||
start_tasks ~run suspension);
|
||||
};
|
||||
|
||||
(* get all results *)
|
||||
List.init len (fun i ->
|
||||
match arr.(i) with
|
||||
| None -> assert false
|
||||
| Some x -> x)
|
||||
|
||||
let all_init n f = all_list @@ List.init n (fun i () -> f i)
|
||||
|
|
|
|||
|
|
@ -13,3 +13,15 @@ val both : (unit -> 'a) -> (unit -> 'b) -> 'a * 'b
|
|||
|
||||
@since 0.3
|
||||
{b NOTE} this is only available on OCaml 5. *)
|
||||
|
||||
val all_list : (unit -> 'a) list -> 'a list
|
||||
(** [all_list fs] runs all functions in [fs] in tasks, and waits for
|
||||
all the results.
|
||||
@since 0.3
|
||||
{b NOTE} this is only available on OCaml 5. *)
|
||||
|
||||
val all_init : int -> (int -> 'a) -> 'a list
|
||||
(** [all_init n f] runs functions [f 0], [f 1], … [f (n-1)] in tasks, and waits for
|
||||
all the results.
|
||||
@since 0.3
|
||||
{b NOTE} this is only available on OCaml 5. *)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
|
||||
(tests
|
||||
(names t_fib1 t_futs1 t_many t_fork_join)
|
||||
(names t_fib1 t_futs1 t_many t_fork_join t_fork_join_all)
|
||||
(enabled_if (>= %{ocaml_version} 5.0))
|
||||
(libraries moonpool trace ;tracy-client.trace
|
||||
))
|
||||
|
|
|
|||
42
test/await/t_fork_join_all.ml
Normal file
42
test/await/t_fork_join_all.ml
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
open Moonpool
|
||||
|
||||
let rec fib_direct x =
|
||||
if x <= 1 then
|
||||
1
|
||||
else
|
||||
fib_direct (x - 1) + fib_direct (x - 2)
|
||||
|
||||
let rec fib x : int =
|
||||
if x <= 18 then
|
||||
fib_direct x
|
||||
else (
|
||||
let n1, n2 =
|
||||
Fork_join.both (fun () -> fib (x - 1)) (fun () -> fib (x - 2))
|
||||
in
|
||||
n1 + n2
|
||||
)
|
||||
|
||||
let fib_40 : int =
|
||||
let pool = Pool.create ~min:8 () in
|
||||
Fut.spawn ~on:pool (fun () -> fib 40) |> Fut.wait_block_exn
|
||||
|
||||
let () = Printf.printf "fib 40 = %d\n%!" fib_40
|
||||
|
||||
let run_test () =
|
||||
let pool = Pool.create ~min:8 () in
|
||||
|
||||
let fut =
|
||||
Fut.spawn ~on:pool (fun () ->
|
||||
let fibs = Fork_join.all_init 3 (fun _ -> fib 40) in
|
||||
fibs)
|
||||
in
|
||||
|
||||
let res = Fut.wait_block_exn fut in
|
||||
Pool.shutdown pool;
|
||||
|
||||
assert (res = (Array.make 3 fib_40 |> Array.to_list))
|
||||
|
||||
let () =
|
||||
(* now make sure we can do this with multiple pools in parallel *)
|
||||
let jobs = Array.init 2 (fun _ -> Thread.create run_test ()) in
|
||||
Array.iter Thread.join jobs
|
||||
Loading…
Add table
Reference in a new issue