diff --git a/src/fork_join.ml b/src/fork_join.ml index ccf9a753..f1733514 100644 --- a/src/fork_join.ml +++ b/src/fork_join.ml @@ -92,52 +92,53 @@ let both f g : _ * _ = let both_ignore f g = ignore (both f g : _ * _) let for_ ?chunk_size n (f : int -> int -> unit) : unit = - let has_failed = A.make false in - let missing = A.make n in + if n > 0 then ( + let has_failed = A.make false in + let missing = A.make n in - let chunk_size = - match chunk_size with - | Some cs -> max 1 (min n cs) - | None -> - (* guess: try to have roughly one task per core *) - max 1 (1 + (n / D_pool_.n_domains ())) - in - - let start_tasks ~run (suspension : Suspend_.suspension) = - let task_for ~offset ~len_range = - match f offset (offset + len_range - 1) with - | () -> - if A.fetch_and_add missing (-len_range) = len_range then - (* all tasks done successfully *) - suspension (Ok ()) - | exception exn -> - let bt = Printexc.get_raw_backtrace () in - if not (A.exchange has_failed true) then - (* first one to fail, and [missing] must be >= 2 - because we're not decreasing it. *) - suspension (Error (exn, bt)) + let chunk_size = + match chunk_size with + | Some cs -> max 1 (min n cs) + | None -> + (* guess: try to have roughly one task per core *) + max 1 (1 + (n / D_pool_.n_domains ())) in - let i = ref 0 in - while !i < n do - let offset = !i in + let start_tasks ~run (suspension : Suspend_.suspension) = + let task_for ~offset ~len_range = + match f offset (offset + len_range - 1) with + | () -> + if A.fetch_and_add missing (-len_range) = len_range then + (* all tasks done successfully *) + suspension (Ok ()) + | exception exn -> + let bt = Printexc.get_raw_backtrace () in + if not (A.exchange has_failed true) then + (* first one to fail, and [missing] must be >= 2 + because we're not decreasing it. *) + suspension (Error (exn, bt)) + in - let len_range = min chunk_size (n - offset) in - assert (offset + len_range <= n); + let i = ref 0 in + while !i < n do + let offset = !i in - run ~with_handler:true (fun () -> task_for ~offset ~len_range); - i := !i + len_range - done - in + let len_range = min chunk_size (n - offset) in + assert (offset + len_range <= n); - Suspend_.suspend - { - Suspend_.handle = - (fun ~run suspension -> - (* run tasks, then we'll resume [suspension] *) - start_tasks ~run suspension); - }; - () + run ~with_handler:true (fun () -> task_for ~offset ~len_range); + i := !i + len_range + done + in + + Suspend_.suspend + { + Suspend_.handle = + (fun ~run suspension -> + (* run tasks, then we'll resume [suspension] *) + start_tasks ~run suspension); + } + ) let all_array ?chunk_size (fs : _ array) : _ array = let len = Array.length fs in diff --git a/test/effect-based/t_fork_join.ml b/test/effect-based/t_fork_join.ml index f54883fb..007baa55 100644 --- a/test/effect-based/t_fork_join.ml +++ b/test/effect-based/t_fork_join.ml @@ -42,6 +42,11 @@ let () = let exp_sum = List.init 42 (fun x -> x * x) |> List.fold_left ( + ) 0 in assert (par_sum = exp_sum) +let () = + Pool.run_wait_block pool (fun () -> + Fork_join.for_ 0 (fun _ _ -> assert false)); + () + let () = let total_sum = Atomic.make 0 in