From 43487ebe4940b857e6cce6acf5d5eef1cf96b9b0 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Fri, 23 Jun 2023 23:18:52 -0400 Subject: [PATCH] add `Fork_join.all_{list,init}` primitives to fork-join over n tasks --- src/fork_join.ml | 42 +++++++++++++++++++++++++++++++++++ src/fork_join.mli | 12 ++++++++++ test/await/dune | 2 +- test/await/t_fork_join_all.ml | 42 +++++++++++++++++++++++++++++++++++ 4 files changed, 97 insertions(+), 1 deletion(-) create mode 100644 test/await/t_fork_join_all.ml diff --git a/src/fork_join.ml b/src/fork_join.ml index ad7b9ad0..a90afe72 100644 --- a/src/fork_join.ml +++ b/src/fork_join.ml @@ -86,3 +86,45 @@ let both f g : _ * _ = start_tasks ~run ()); }; get_exn st + +let all_list fs : _ list = + let len = List.length fs in + let arr = Array.make len None in + let has_failed = A.make false in + let missing = A.make len in + + let start_tasks ~run (suspension : Suspend_types_.suspension) = + let task_for i f = + try + let x = f () in + arr.(i) <- Some x; + + if A.fetch_and_add missing (-1) = 1 then + (* all tasks done successfully *) + suspension (Ok ()) + with exn -> + let bt = Printexc.get_raw_backtrace () in + if not (A.exchange has_failed true) then + (* first one to fail, and [missing] must be >= 2 + because we're not decreasing it. *) + suspension (Error (exn, bt)) + in + + List.iteri (fun i f -> run ~with_handler:true (fun () -> task_for i f)) fs + in + + Suspend_.suspend + { + Suspend_types_.handle = + (fun ~run suspension -> + (* nothing else is started, no race condition possible *) + start_tasks ~run suspension); + }; + + (* get all results *) + List.init len (fun i -> + match arr.(i) with + | None -> assert false + | Some x -> x) + +let all_init n f = all_list @@ List.init n (fun i () -> f i) diff --git a/src/fork_join.mli b/src/fork_join.mli index 8cf36306..3cb4e75e 100644 --- a/src/fork_join.mli +++ b/src/fork_join.mli @@ -13,3 +13,15 @@ val both : (unit -> 'a) -> (unit -> 'b) -> 'a * 'b @since 0.3 {b NOTE} this is only available on OCaml 5. *) + +val all_list : (unit -> 'a) list -> 'a list +(** [all_list fs] runs all functions in [fs] in tasks, and waits for + all the results. + @since 0.3 + {b NOTE} this is only available on OCaml 5. *) + +val all_init : int -> (int -> 'a) -> 'a list +(** [all_init n f] runs functions [f 0], [f 1], … [f (n-1)] in tasks, and waits for + all the results. + @since 0.3 + {b NOTE} this is only available on OCaml 5. *) diff --git a/test/await/dune b/test/await/dune index 7ed8b854..cc84813d 100644 --- a/test/await/dune +++ b/test/await/dune @@ -1,6 +1,6 @@ (tests - (names t_fib1 t_futs1 t_many t_fork_join) + (names t_fib1 t_futs1 t_many t_fork_join t_fork_join_all) (enabled_if (>= %{ocaml_version} 5.0)) (libraries moonpool trace ;tracy-client.trace )) diff --git a/test/await/t_fork_join_all.ml b/test/await/t_fork_join_all.ml new file mode 100644 index 00000000..4539159d --- /dev/null +++ b/test/await/t_fork_join_all.ml @@ -0,0 +1,42 @@ +open Moonpool + +let rec fib_direct x = + if x <= 1 then + 1 + else + fib_direct (x - 1) + fib_direct (x - 2) + +let rec fib x : int = + if x <= 18 then + fib_direct x + else ( + let n1, n2 = + Fork_join.both (fun () -> fib (x - 1)) (fun () -> fib (x - 2)) + in + n1 + n2 + ) + +let fib_40 : int = + let pool = Pool.create ~min:8 () in + Fut.spawn ~on:pool (fun () -> fib 40) |> Fut.wait_block_exn + +let () = Printf.printf "fib 40 = %d\n%!" fib_40 + +let run_test () = + let pool = Pool.create ~min:8 () in + + let fut = + Fut.spawn ~on:pool (fun () -> + let fibs = Fork_join.all_init 3 (fun _ -> fib 40) in + fibs) + in + + let res = Fut.wait_block_exn fut in + Pool.shutdown pool; + + assert (res = (Array.make 3 fib_40 |> Array.to_list)) + +let () = + (* now make sure we can do this with multiple pools in parallel *) + let jobs = Array.init 2 (fun _ -> Thread.create run_test ()) in + Array.iter Thread.join jobs