From a1676ff5b640aa867b4488a25cafe7ea7f33afea Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Sun, 29 Oct 2023 21:35:23 -0400 Subject: [PATCH] feat suspend: pass `run_batch` to caller; use that in rest of code fork join is now going to use run_batch. --- src/fifo_pool.ml | 5 ++++- src/fork_join.ml | 54 ++++++++++++++++++++++++++------------------- src/fut.ml | 2 +- src/suspend_.ml | 14 +++++++----- src/suspend_.mli | 15 +++++++++---- src/ws_pool.ml | 57 ++++++++++++++++++++++++++++++++++++++++++------ 6 files changed, 107 insertions(+), 40 deletions(-) diff --git a/src/fifo_pool.ml b/src/fifo_pool.ml index c4457e60..2bcfb364 100644 --- a/src/fifo_pool.ml +++ b/src/fifo_pool.ml @@ -30,7 +30,10 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit = let run_task task : unit = let _ctx = before_task runner in (* run the task now, catching errors *) - (try Suspend_.with_suspend task ~run:(fun task' -> schedule_ self task') + (try + Suspend_.with_suspend task + ~run:(fun task' -> schedule_ self task') + ~run_batch:(fun b -> schedule_batch_ self b) with e -> let bt = Printexc.get_raw_backtrace () in on_exn e bt); diff --git a/src/fork_join.ml b/src/fork_join.ml index ac5ba5d7..b5cae866 100644 --- a/src/fork_join.ml +++ b/src/fork_join.ml @@ -61,31 +61,37 @@ let both f g : _ * _ = let open State_ in let st = A.make { suspension = None; left = St_none; right = St_none } in - let start_tasks ~run () : unit = - run (fun () -> - try - let res = f () in - set_left_ st (St_some res) - with e -> - let bt = Printexc.get_raw_backtrace () in - set_left_ st (St_fail (e, bt))); + let start_tasks ~run:_ ~run_batch () : unit = + let t1 () = + try + let res = f () in + set_left_ st (St_some res) + with e -> + let bt = Printexc.get_raw_backtrace () in + set_left_ st (St_fail (e, bt)) + in - run (fun () -> - try - let res = g () in - set_right_ st (St_some res) - with e -> - let bt = Printexc.get_raw_backtrace () in - set_right_ st (St_fail (e, bt))) + let t2 () 
= + try + let res = g () in + set_right_ st (St_some res) + with e -> + let bt = Printexc.get_raw_backtrace () in + set_right_ st (St_fail (e, bt)) + in + + run_batch (fun yield -> + yield t1; + yield t2) in Suspend_.suspend { Suspend_.handle = - (fun ~run suspension -> + (fun ~run ~run_batch suspension -> (* nothing else is started, no race condition possible *) (A.get st).suspension <- Some suspension; - start_tasks ~run ()); + start_tasks ~run ~run_batch ()); }; get_exn st @@ -104,7 +110,7 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit = max 1 (1 + (n / D_pool_.n_domains ())) in - let start_tasks ~run (suspension : Suspend_.suspension) = + let start_tasks ~run:_ ~run_batch (suspension : Suspend_.suspension) = let task_for ~offset ~len_range = match f offset (offset + len_range - 1) with | () -> @@ -120,23 +126,27 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit = in let i = ref 0 in + let batch = ref [] in while !i < n do let offset = !i in let len_range = min chunk_size (n - offset) in assert (offset + len_range <= n); - run (fun () -> task_for ~offset ~len_range); + batch := (fun () -> task_for ~offset ~len_range) :: !batch; i := !i + len_range - done + done; + + (* schedule all tasks at once *) + run_batch (fun yield -> List.iter yield !batch) in Suspend_.suspend { Suspend_.handle = - (fun ~run suspension -> + (fun ~run ~run_batch suspension -> (* run tasks, then we'll resume [suspension] *) - start_tasks ~run suspension); + start_tasks ~run ~run_batch suspension); } ) diff --git a/src/fut.ml b/src/fut.ml index 639a503b..c22d6f15 100644 --- a/src/fut.ml +++ b/src/fut.ml @@ -379,7 +379,7 @@ let await (fut : 'a t) : 'a = Suspend_.suspend { Suspend_.handle = - (fun ~run k -> + (fun ~run ~run_batch:_ k -> on_result fut (function | Ok _ -> run (fun () -> k (Ok ())) | Error (exn, bt) -> diff --git a/src/suspend_.ml b/src/suspend_.ml index 6555b6bc..71bc48ec 100644 --- a/src/suspend_.ml +++ b/src/suspend_.ml @@ -1,7 +1,11 @@ type suspension = 
(unit, exn * Printexc.raw_backtrace) result -> unit type task = unit -> unit +type 'a iter = ('a -> unit) -> unit -type suspension_handler = { handle: run:(task -> unit) -> suspension -> unit } +type suspension_handler = { + handle: + run:(task -> unit) -> run_batch:(task iter -> unit) -> suspension -> unit; +} [@@unboxed] [@@@ifge 5.0] @@ -13,7 +17,7 @@ type _ Effect.t += Suspend : suspension_handler -> unit Effect.t let[@inline] suspend h = Effect.perform (Suspend h) -let with_suspend ~(run : task -> unit) (f : unit -> unit) : unit = +let with_suspend ~(run : task -> unit) ~run_batch (f : unit -> unit) : unit = let module E = Effect.Deep in (* effect handler *) let effc : type e. e Effect.t -> ((e, _) E.continuation -> _) option = @@ -25,7 +29,7 @@ let with_suspend ~(run : task -> unit) (f : unit -> unit) : unit = | Ok () -> E.continue k () | Error (exn, bt) -> E.discontinue_with_backtrace k exn bt in - h.handle ~run k') + h.handle ~run ~run_batch k') | _ -> None in @@ -41,7 +45,7 @@ let prepare_for_await () : Dla_.t = | None -> () | Some (run, k) -> run (fun () -> k (Ok ())) and await () : unit = - suspend { handle = (fun ~run k -> A.set st (Some (run, k))) } + suspend { handle = (fun ~run ~run_batch:_ k -> A.set st (Some (run, k))) } in let t = { Dla_.release; await } in @@ -50,7 +54,7 @@ let prepare_for_await () : Dla_.t = [@@@ocaml.alert "+unstable"] [@@@else_] -let[@inline] with_suspend ~run:_ f = f () +let[@inline] with_suspend ~run:_ ~run_batch:_ f = f () let[@inline] prepare_for_await () = { Dla_.release = ignore; await = ignore } [@@@endif] diff --git a/src/suspend_.mli b/src/suspend_.mli index 77cc06af..6cd6a386 100644 --- a/src/suspend_.mli +++ b/src/suspend_.mli @@ -7,8 +7,12 @@ type suspension = (unit, exn * Printexc.raw_backtrace) result -> unit (** A suspended computation *) type task = unit -> unit +type 'a iter = ('a -> unit) -> unit -type suspension_handler = { handle: run:(task -> unit) -> suspension -> unit } +type suspension_handler = { + 
handle: + run:(task -> unit) -> run_batch:(task iter -> unit) -> suspension -> unit; +} [@@unboxed] (** The handler that knows what to do with the suspended computation. @@ -18,6 +22,8 @@ type suspension_handler = { handle: run:(task -> unit) -> suspension -> unit } eventually); - a [run] function that can be used to start tasks to perform some computation. + - a [run_batch] function that can be used to start multiple background + tasks at once. This means that a fork-join primitive, for example, can use a single call to {!suspend} to: @@ -51,10 +57,11 @@ val suspend : suspension_handler -> unit val prepare_for_await : unit -> Dla_.t (** Our stub for DLA. Unstable. *) -val with_suspend : run:(task -> unit) -> (unit -> unit) -> unit -(** [with_suspend ~run f] runs [f()] in an environment where [suspend] +val with_suspend : + run:(task -> unit) -> run_batch:(task iter -> unit) -> (unit -> unit) -> unit +(** [with_suspend ~run ~run_batch f] runs [f()] in an environment where [suspend] will work. If [f()] suspends with suspension handler [h], - this calls [h ~run k] where [k] is the suspension. + this calls [h.handle ~run ~run_batch k] where [k] is the suspension. This will not do anything on OCaml 4.x. 
*) diff --git a/src/ws_pool.ml b/src/ws_pool.ml index a60bf270..4e61dea1 100644 --- a/src/ws_pool.ml +++ b/src/ws_pool.ml @@ -81,6 +81,52 @@ let schedule_task_ (self : state) (w : worker_state option) (task : task) : unit longer permitted *) raise Shutdown +let schedule_task_batch_ (self : state) (w : worker_state option) + (batch : task iter) : unit = + let local_q = Queue.create () in + batch (fun x -> Queue.push x local_q); + + let transfer_into_main_q () = + if not (A.get self.active) then raise Shutdown; + (* push into the main queue *) + Mutex.lock self.mutex; + Queue.transfer local_q self.main_q; + if self.n_waiting_nonzero then Condition.signal self.cond; + Mutex.unlock self.mutex + in + + let try_to_schedule_locally (w : worker_state) = + let continue = ref true in + while !continue do + match Queue.peek_opt local_q with + | Some task -> + let pushed = WSQ.push w.q task in + if pushed then + (* continue *) + ignore (Queue.pop local_q : task) + else + continue := false + | None -> continue := false + done + in + + if not (Queue.is_empty local_q) then ( + match w with + | Some w -> + try_to_schedule_locally w; + (* there might be overflow tasks *) + if not (Queue.is_empty local_q) then transfer_into_main_q () + | None -> transfer_into_main_q () + ) + +let[@inline] run_async_ (self : state) (task : task) : unit = + let w = find_current_worker_ () in + schedule_task_ self w task + +let[@inline] run_async_batch_ (self : state) (b : task iter) : unit = + let w = find_current_worker_ () in + schedule_task_batch_ self w b + (** Run this task, now. Must be called from a worker. 
*) let run_task_now_ (self : state) ~runner task : unit = let (AT_pair (before_task, after_task)) = self.around_task in @@ -88,18 +134,14 @@ let run_task_now_ (self : state) ~runner task : unit = (* run the task now, catching errors *) (try (* run [task()] and handle [suspend] in it *) - Suspend_.with_suspend task ~run:(fun task' -> - let w = find_current_worker_ () in - schedule_task_ self w task') + Suspend_.with_suspend task + ~run:(fun task' -> run_async_ self task') + ~run_batch:(fun b -> run_async_batch_ self b) with e -> let bt = Printexc.get_raw_backtrace () in self.on_exn e bt); after_task runner _ctx -let[@inline] run_async_ (self : state) (task : task) : unit = - let w = find_current_worker_ () in - schedule_task_ self w task - (* TODO: function to schedule many tasks from the outside. - build a queue - lock @@ -254,6 +296,7 @@ let create ?(on_init_thread = default_thread_init_exit_) Runner.For_runner_implementors.create ~shutdown:(fun ~wait () -> shutdown_ pool ~wait) ~run_async:(fun f -> run_async_ pool f) + ~run_async_batch:(fun f -> run_async_batch_ pool f) ~size:(fun () -> size_ pool) ~num_tasks:(fun () -> num_tasks_ pool) ()