diff --git a/src/core/moonpool.ml b/src/core/moonpool.ml index 24da3674..44e4dffc 100644 --- a/src/core/moonpool.ml +++ b/src/core/moonpool.ml @@ -29,6 +29,7 @@ module Ws_pool = Ws_pool module Private = struct module Ws_deque_ = Ws_deque_ module Suspend_ = Suspend_ + module Domain_ = Domain_ let num_domains = Domain_pool_.n_domains end diff --git a/src/core/moonpool.mli b/src/core/moonpool.mli index f7b82df5..9b591fc1 100644 --- a/src/core/moonpool.mli +++ b/src/core/moonpool.mli @@ -204,7 +204,8 @@ module Private : sig {b NOTE}: this is not stable for now. *) - module Domain = Domain_ + module Domain_ = Domain_ + (** Utils for domains *) val num_domains : unit -> int (** Number of domains in the backing domain pool *) diff --git a/src/forkjoin/dune b/src/forkjoin/dune new file mode 100644 index 00000000..e64b8f22 --- /dev/null +++ b/src/forkjoin/dune @@ -0,0 +1,12 @@ + + +(library + (name moonpool_forkjoin) + (public_name moonpool.forkjoin) + (synopsis "Fork-join parallelism for moonpool") + (flags :standard -open Moonpool) + (preprocess + (action + (run %{project_root}/src/cpp/cpp.exe %{input-file}))) + (libraries + moonpool moonpool.private)) diff --git a/src/forkjoin/moonpool_forkjoin.ml b/src/forkjoin/moonpool_forkjoin.ml new file mode 100644 index 00000000..4b8be02d --- /dev/null +++ b/src/forkjoin/moonpool_forkjoin.ml @@ -0,0 +1,222 @@ +[@@@ifge 5.0] + +module A = Moonpool.Atomic +module Suspend_ = Moonpool.Private.Suspend_ +module Domain_ = Moonpool_private.Domain_ + +module State_ = struct + type error = exn * Printexc.raw_backtrace + type 'a or_error = ('a, error) result + + type ('a, 'b) t = + | Init + | Left_solved of 'a or_error + | Right_solved of 'b or_error * Suspend_.suspension + | Both_solved of 'a or_error * 'b or_error + + let get_exn_ (self : _ t A.t) = + match A.get self with + | Both_solved (Ok a, Ok b) -> a, b + | Both_solved (Error (exn, bt), _) | Both_solved (_, Error (exn, bt)) -> + Printexc.raise_with_backtrace exn bt + | _ -> assert false + + let rec set_left_ (self : _ t A.t) (left : _ or_error) = + let old_st = A.get self in + match old_st with + | Init -> + let new_st = Left_solved left in + if not (A.compare_and_set self old_st new_st) then ( + Domain_.relax (); + set_left_ self left + ) + | Right_solved (right, cont) -> + let new_st = Both_solved (left, right) in + if not (A.compare_and_set self old_st new_st) then ( + Domain_.relax (); + set_left_ self left + ) else + cont (Ok ()) + | Left_solved _ | Both_solved _ -> assert false + + let rec set_right_ (self : _ t A.t) (right : _ or_error) : unit = + let old_st = A.get self in + match old_st with + | Left_solved left -> + let new_st = Both_solved (left, right) in + if not (A.compare_and_set self old_st new_st) then set_right_ self right + | Init -> + (* we are first arrived, we suspend until the left computation is done *) + Suspend_.suspend + { + Suspend_.handle = + (fun ~name:_ ~run:_ suspension -> + while + let old_st = A.get self in + match old_st with + | Init -> + not + (A.compare_and_set self old_st + (Right_solved (right, suspension))) + | Left_solved left -> + (* other thread is done, no risk of race condition *) + A.set self (Both_solved (left, right)); + suspension (Ok ()); + false + | Right_solved _ | Both_solved _ -> assert false + do + () + done); + } + | Right_solved _ | Both_solved _ -> assert false +end + +let both f g : _ * _ = + let module ST = State_ in + let st = A.make ST.Init in + + let runner = + match Runner.get_current_runner () with + | None -> invalid_arg "Fork_join.both must be run from within a runner" + | Some r -> r + in + + (* start computing [f] in the background *) + Runner.run_async runner (fun () -> + try + let res = f () in + ST.set_left_ st (Ok res) + with exn -> + let bt = Printexc.get_raw_backtrace () in + ST.set_left_ st (Error (exn, bt))); + + let res_right = + try Ok (g ()) + with exn -> + let bt = Printexc.get_raw_backtrace () in + Error (exn, bt) + in + + ST.set_right_ st res_right; + ST.get_exn_ st + +let both_ignore f g = ignore (both f g : _ * _) + +let for_ ?chunk_size n (f : int -> int -> unit) : unit = + if n > 0 then ( + let has_failed = A.make false in + let missing = A.make n in + + let chunk_size = + match chunk_size with + | Some cs -> max 1 (min n cs) + | None -> + (* guess: try to have roughly one task per core *) + max 1 (1 + (n / Moonpool.Private.num_domains ())) + in + + let start_tasks ~name ~run (suspension : Suspend_.suspension) = + let task_for ~offset ~len_range = + match f offset (offset + len_range - 1) with + | () -> + if A.fetch_and_add missing (-len_range) = len_range then + (* all tasks done successfully *) + run ~name (fun () -> suspension (Ok ())) + | exception exn -> + let bt = Printexc.get_raw_backtrace () in + if not (A.exchange has_failed true) then + (* first one to fail, and [missing] must be >= 2 + because we're not decreasing it. *) + run ~name (fun () -> suspension (Error (exn, bt))) + in + + let i = ref 0 in + while !i < n do + let offset = !i in + + let len_range = min chunk_size (n - offset) in + assert (offset + len_range <= n); + + run ~name (fun () -> task_for ~offset ~len_range); + i := !i + len_range + done + in + + Suspend_.suspend + { + Suspend_.handle = + (fun ~name ~run suspension -> + (* run tasks, then we'll resume [suspension] *) + start_tasks ~run ~name suspension); + } + ) + +let all_array ?chunk_size (fs : _ array) : _ array = + let len = Array.length fs in + let arr = Array.make len None in + + (* parallel for *) + for_ ?chunk_size len (fun low high -> + for i = low to high do + let x = fs.(i) () in + arr.(i) <- Some x + done); + + (* get all results *) + Array.map + (function + | None -> assert false + | Some x -> x) + arr + +let all_list ?chunk_size fs : _ list = + Array.to_list @@ all_array ?chunk_size @@ Array.of_list fs + +let all_init ?chunk_size n f : _ list = + let arr = Array.make n None in + + for_ ?chunk_size n (fun low high -> + for i = low to high do + let x = f i in + arr.(i) <- Some x + done); + + (* get all results *) + List.init n (fun i -> + match arr.(i) with + | None -> assert false + | Some x -> x) + +let map_array ?chunk_size f arr : _ array = + let n = Array.length arr in + let res = Array.make n None in + + for_ ?chunk_size n (fun low high -> + for i = low to high do + res.(i) <- Some (f arr.(i)) + done); + + (* get all results *) + Array.map + (function + | None -> assert false + | Some x -> x) + res + +let map_list ?chunk_size f (l : _ list) : _ list = + let arr = Array.of_list l in + let n = Array.length arr in + let res = Array.make n None in + + for_ ?chunk_size n (fun low high -> + for i = low to high do + res.(i) <- Some (f arr.(i)) + done); + + (* get all results *) + List.init n (fun i -> + match res.(i) with + | None -> assert false + | Some x -> x) + +[@@@endif] diff --git a/src/forkjoin/moonpool_forkjoin.mli b/src/forkjoin/moonpool_forkjoin.mli new file mode 100644 index 00000000..3ffa537d --- /dev/null +++ b/src/forkjoin/moonpool_forkjoin.mli @@ -0,0 +1,109 @@ +(** Fork-join primitives. + + {b NOTE} These are only available on OCaml 5.0 and above. + + @since 0.3 *) + +[@@@ifge 5.0] + +val both : (unit -> 'a) -> (unit -> 'b) -> 'a * 'b +(** [both f g] runs [f()] and [g()], potentially in parallel, + and returns their result when both are done. + If any of [f()] and [g()] fails, then the whole computation fails. + + This must be run from within the pool: for example, inside {!Pool.run} + or inside a {!Fut.spawn} computation. + This is because it relies on an effect handler to be installed. + + @since 0.3 + {b NOTE} this is only available on OCaml 5. *) + +val both_ignore : (unit -> _) -> (unit -> _) -> unit +(** Same as [both f g |> ignore]. + @since 0.3 + {b NOTE} this is only available on OCaml 5. *) + +val for_ : ?chunk_size:int -> int -> (int -> int -> unit) -> unit +(** [for_ n f] is the parallel version of [for i=0 to n-1 do f i done]. + + [f] is called with parameters [low] and [high] and must use them like so: + {[ for j = low to high do (* … actual work *) done ]}. + If [chunk_size=1] then [low=high] and the loop is not actually needed. + + @param chunk_size controls the granularity of parallelism. + The default chunk size is not specified. + See {!all_array} or {!all_list} for more details. + + Example: + {[ + let total_sum = Atomic.make 0 + + let() = for_ ~chunk_size:5 100 + (fun low high -> + (* iterate on the range sequentially. The range should have 5 items or less. *) + let local_sum = ref 0 in + for j=low to high do + local_sum := !local_sum + j + done; + ignore (Atomic.fetch_and_add total_sum !local_sum : int))) + + let() = assert (Atomic.get total_sum = 4950) + ]} + + Note how we still compute a local sum sequentially in [(fun low high -> …)], + before combining it wholesale into [total_sum]. When the chunk size is large, + this can have a dramatic impact on the synchronization overhead. + + When [chunk_size] is not provided, the library will attempt to guess a value + that keeps all cores busy but runs as few tasks as possible to reduce + the synchronization overhead. + + Use [~chunk_size:1] if you explicitly want to + run each iteration of the loop in its own task. + + @since 0.3 + {b NOTE} this is only available on OCaml 5. *) + +val all_array : ?chunk_size:int -> (unit -> 'a) array -> 'a array +(** [all_array fs] runs all functions in [fs] in tasks, and waits for + all the results. + + @param chunk_size if equal to [n], groups items by [n] to be run in + a single task. Default is [1]. + + @since 0.3 + {b NOTE} this is only available on OCaml 5. *) + +val all_list : ?chunk_size:int -> (unit -> 'a) list -> 'a list +(** [all_list fs] runs all functions in [fs] in tasks, and waits for + all the results. + + @param chunk_size if equal to [n], groups items by [n] to be run in + a single task. Default is not specified. + This parameter is available since 0.3. + + @since 0.3 + {b NOTE} this is only available on OCaml 5. *) + +val all_init : ?chunk_size:int -> int -> (int -> 'a) -> 'a list +(** [all_init n f] runs functions [f 0], [f 1], … [f (n-1)] in tasks, and waits for + all the results. + + @param chunk_size if equal to [n], groups items by [n] to be run in + a single task. Default is not specified. + This parameter is available since 0.3. + + @since 0.3 + {b NOTE} this is only available on OCaml 5. *) + +val map_array : ?chunk_size:int -> ('a -> 'b) -> 'a array -> 'b array +(** [map_array f arr] is like [Array.map f arr], but runs in parallel. + @since 0.3 + {b NOTE} this is only available on OCaml 5. *) + +val map_list : ?chunk_size:int -> ('a -> 'b) -> 'a list -> 'b list +(** [map_list f l] is like [List.map f l], but runs in parallel. + @since 0.3 + {b NOTE} this is only available on OCaml 5. *) + +[@@@endif] diff --git a/test/effect-based/dune b/test/effect-based/dune index 125ed267..4b654519 100644 --- a/test/effect-based/dune +++ b/test/effect-based/dune @@ -15,6 +15,7 @@ (>= %{ocaml_version} 5.0)) (libraries moonpool + moonpool.forkjoin trace trace-tef qcheck-core diff --git a/test/effect-based/t_fib_fork_join.ml b/test/effect-based/t_fib_fork_join.ml index 4e6639b2..25e7d49d 100644 --- a/test/effect-based/t_fib_fork_join.ml +++ b/test/effect-based/t_fib_fork_join.ml @@ -1,6 +1,7 @@ [@@@ifge 5.0] open Moonpool +module FJ = Moonpool_forkjoin let rec fib_direct x = if x <= 1 then @@ -14,7 +15,7 @@ let fib ~on x : int Fut.t = fib_direct x else ( let n1, n2 = - Fork_join.both (fun () -> fib_rec (x - 1)) (fun () -> fib_rec (x - 2)) + FJ.both (fun () -> fib_rec (x - 1)) (fun () -> fib_rec (x - 2)) in n1 + n2 ) diff --git a/test/effect-based/t_fib_fork_join_all.ml b/test/effect-based/t_fib_fork_join_all.ml index 3caee9b9..f80670ca 100644 --- a/test/effect-based/t_fib_fork_join_all.ml +++ b/test/effect-based/t_fib_fork_join_all.ml @@ -3,6 +3,7 @@ let ( let@ ) = ( @@ ) open Moonpool +module FJ = Moonpool_forkjoin let rec fib_direct x = if x <= 1 then @@ -15,9 +16,7 @@ let rec fib x : int = if x <= 18 then fib_direct x else ( - let n1, n2 = - Fork_join.both (fun () -> fib (x - 1)) (fun () -> fib (x - 2)) - in + let n1, n2 = FJ.both (fun () -> fib (x - 1)) (fun () -> fib (x - 2)) in n1 + n2 ) @@ -32,7 +31,7 @@ let run_test () = let fut = Fut.spawn ~on:pool (fun () -> - let fibs = Fork_join.all_init 3 (fun _ -> fib 40) in + let fibs = FJ.all_init 3 (fun _ -> fib 40) in fibs) in diff --git a/test/effect-based/t_fork_join.ml b/test/effect-based/t_fork_join.ml index 5c7134ca..83c291ab 100644 --- a/test/effect-based/t_fork_join.ml +++ b/test/effect-based/t_fork_join.ml @@ -4,6 +4,7 @@ let spf = Printf.sprintf let ( let@ ) = ( @@ ) open! Moonpool +module FJ = Moonpool_forkjoin let pool = Ws_pool.create ~num_threads:4 () @@ -11,7 +12,7 @@ let () = let x = Ws_pool.run_wait_block pool (fun () -> let x, y = - Fork_join.both + FJ.both (fun () -> Thread.delay 0.005; 1) @@ -26,7 +27,7 @@ let () = let () = try Ws_pool.run_wait_block pool (fun () -> - Fork_join.both_ignore + FJ.both_ignore (fun () -> Thread.delay 0.005) (fun () -> Thread.delay 0.02; @@ -37,21 +38,20 @@ let () = let () = let par_sum = Ws_pool.run_wait_block pool (fun () -> - Fork_join.all_init 42 (fun i -> i * i) |> List.fold_left ( + ) 0) + FJ.all_init 42 (fun i -> i * i) |> List.fold_left ( + ) 0) in let exp_sum = List.init 42 (fun x -> x * x) |> List.fold_left ( + ) 0 in assert (par_sum = exp_sum) let () = - Ws_pool.run_wait_block pool (fun () -> - Fork_join.for_ 0 (fun _ _ -> assert false)); + Ws_pool.run_wait_block pool (fun () -> FJ.for_ 0 (fun _ _ -> assert false)); () let () = let total_sum = Atomic.make 0 in Ws_pool.run_wait_block pool (fun () -> - Fork_join.for_ ~chunk_size:5 100 (fun low high -> + FJ.for_ ~chunk_size:5 100 (fun low high -> (* iterate on the range sequentially. The range should have 5 items or less. *) let local_sum = ref 0 in for i = low to high do @@ -64,7 +64,7 @@ let () = let total_sum = Atomic.make 0 in Ws_pool.run_wait_block pool (fun () -> - Fork_join.for_ ~chunk_size:1 100 (fun low high -> + FJ.for_ ~chunk_size:1 100 (fun low high -> assert (low = high); ignore (Atomic.fetch_and_add total_sum low : int))); assert (Atomic.get total_sum = 4950) @@ -82,7 +82,7 @@ let rec fib_fork_join n = fib_direct n else ( let a, b = - Fork_join.both + FJ.both (fun () -> fib_fork_join (n - 1)) (fun () -> fib_fork_join (n - 2)) in @@ -254,13 +254,13 @@ module Evaluator = struct | Ret x -> x | Comp_fib n -> fib_fork_join n | Add (a, b) -> - let a, b = Fork_join.both (fun () -> eval a) (fun () -> eval b) in + let a, b = FJ.both (fun () -> eval a) (fun () -> eval b) in a + b | Pipe (a, f) -> eval a |> apply_fun_seq f | Map_arr (chunk_size, f, a, r) -> let tasks = List.map (fun x () -> eval x) a in - Fork_join.all_list ~chunk_size tasks - |> Fork_join.map_list ~chunk_size (apply_fun_seq f) + FJ.all_list ~chunk_size tasks + |> FJ.map_list ~chunk_size (apply_fun_seq f) |> eval_reducer r in @@ -290,12 +290,8 @@ let t_for_nested ~min ~chunk_size () = let l1, l2 = let@ pool = Ws_pool.with_ ~num_threads:min () in let@ () = Ws_pool.run_wait_block pool in - let l1 = - Fork_join.map_list ~chunk_size (Fork_join.map_list ~chunk_size neg) l - in - let l2 = - Fork_join.map_list ~chunk_size (Fork_join.map_list ~chunk_size neg) l1 - in + let l1 = FJ.map_list ~chunk_size (FJ.map_list ~chunk_size neg) l in + let l2 = FJ.map_list ~chunk_size (FJ.map_list ~chunk_size neg) l1 in l1, l2 in @@ -313,12 +309,8 @@ let t_map ~chunk_size () = let@ pool = Ws_pool.with_ ~num_threads:4 () in let@ () = Ws_pool.run_wait_block pool in - let a1 = - Fork_join.map_list ~chunk_size string_of_int l |> Array.of_list - in - let a2 = - Fork_join.map_array ~chunk_size string_of_int @@ Array.of_list l - in + let a1 = FJ.map_list ~chunk_size string_of_int l |> Array.of_list in + let a2 = FJ.map_array ~chunk_size string_of_int @@ Array.of_list l in if a1 <> a2 then Q.Test.fail_reportf "a1=%s, a2=%s" (ppa a1) (ppa a2); true) diff --git a/test/effect-based/t_fork_join_heavy.ml b/test/effect-based/t_fork_join_heavy.ml index bacb1d18..7fac119c 100644 --- a/test/effect-based/t_fork_join_heavy.ml +++ b/test/effect-based/t_fork_join_heavy.ml @@ -7,6 +7,7 @@ let ( let@ ) = ( @@ ) let ppl = Q.Print.(list @@ list int) open! Moonpool +module FJ = Moonpool_forkjoin let run ~min () = let@ _sp = @@ -31,17 +32,13 @@ let run ~min () = let@ () = Ws_pool.run_wait_block pool in let l1, l2 = - Fork_join.both + FJ.both (fun () -> let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "fj.left" in - Fork_join.map_list ~chunk_size - (Fork_join.map_list ~chunk_size neg) - l) + FJ.map_list ~chunk_size (FJ.map_list ~chunk_size neg) l) (fun () -> let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "fj.right" in - Fork_join.map_list ~chunk_size - (Fork_join.map_list ~chunk_size neg) - ref_l1) + FJ.map_list ~chunk_size (FJ.map_list ~chunk_size neg) ref_l1) in l1, l2 in diff --git a/test/effect-based/t_sort.ml b/test/effect-based/t_sort.ml index 8ccc372f..f0da71b8 100644 --- a/test/effect-based/t_sort.ml +++ b/test/effect-based/t_sort.ml @@ -1,6 +1,7 @@ [@@@ifge 5.0] open Moonpool +module FJ = Moonpool_forkjoin let rec select_sort arr i len = if len >= 2 then ( @@ -54,7 +55,7 @@ let rec quicksort arr i len : unit = ) done; - Fork_join.both_ignore + FJ.both_ignore (fun () -> quicksort arr i (!low - i)) (fun () -> quicksort arr !low (len - (!low - i))) )