feat suspend: pass run_batch to caller; use that in rest of code

fork join is now going to use run_batch.
This commit is contained in:
Simon Cruanes 2023-10-29 21:35:23 -04:00
parent 0f6bd6288d
commit a1676ff5b6
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
6 changed files with 107 additions and 40 deletions

View file

@ -30,7 +30,10 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
let run_task task : unit = let run_task task : unit =
let _ctx = before_task runner in let _ctx = before_task runner in
(* run the task now, catching errors *) (* run the task now, catching errors *)
(try Suspend_.with_suspend task ~run:(fun task' -> schedule_ self task') (try
Suspend_.with_suspend task
~run:(fun task' -> schedule_ self task')
~run_batch:(fun b -> schedule_batch_ self b)
with e -> with e ->
let bt = Printexc.get_raw_backtrace () in let bt = Printexc.get_raw_backtrace () in
on_exn e bt); on_exn e bt);

View file

@ -61,31 +61,37 @@ let both f g : _ * _ =
let open State_ in let open State_ in
let st = A.make { suspension = None; left = St_none; right = St_none } in let st = A.make { suspension = None; left = St_none; right = St_none } in
let start_tasks ~run () : unit = let start_tasks ~run:_ ~run_batch () : unit =
run (fun () -> let t1 () =
try try
let res = f () in let res = f () in
set_left_ st (St_some res) set_left_ st (St_some res)
with e -> with e ->
let bt = Printexc.get_raw_backtrace () in let bt = Printexc.get_raw_backtrace () in
set_left_ st (St_fail (e, bt))); set_left_ st (St_fail (e, bt))
in
run (fun () -> let t2 () =
try try
let res = g () in let res = g () in
set_right_ st (St_some res) set_right_ st (St_some res)
with e -> with e ->
let bt = Printexc.get_raw_backtrace () in let bt = Printexc.get_raw_backtrace () in
set_right_ st (St_fail (e, bt))) set_right_ st (St_fail (e, bt))
in
run_batch (fun yield ->
yield t1;
yield t2)
in in
Suspend_.suspend Suspend_.suspend
{ {
Suspend_.handle = Suspend_.handle =
(fun ~run suspension -> (fun ~run ~run_batch suspension ->
(* nothing else is started, no race condition possible *) (* nothing else is started, no race condition possible *)
(A.get st).suspension <- Some suspension; (A.get st).suspension <- Some suspension;
start_tasks ~run ()); start_tasks ~run ~run_batch ());
}; };
get_exn st get_exn st
@ -104,7 +110,7 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit =
max 1 (1 + (n / D_pool_.n_domains ())) max 1 (1 + (n / D_pool_.n_domains ()))
in in
let start_tasks ~run (suspension : Suspend_.suspension) = let start_tasks ~run:_ ~run_batch (suspension : Suspend_.suspension) =
let task_for ~offset ~len_range = let task_for ~offset ~len_range =
match f offset (offset + len_range - 1) with match f offset (offset + len_range - 1) with
| () -> | () ->
@ -120,23 +126,27 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit =
in in
let i = ref 0 in let i = ref 0 in
let batch = ref [] in
while !i < n do while !i < n do
let offset = !i in let offset = !i in
let len_range = min chunk_size (n - offset) in let len_range = min chunk_size (n - offset) in
assert (offset + len_range <= n); assert (offset + len_range <= n);
run (fun () -> task_for ~offset ~len_range); batch := (fun () -> task_for ~offset ~len_range) :: !batch;
i := !i + len_range i := !i + len_range
done done;
(* schedule all tasks at once *)
run_batch (fun yield -> List.iter yield !batch)
in in
Suspend_.suspend Suspend_.suspend
{ {
Suspend_.handle = Suspend_.handle =
(fun ~run suspension -> (fun ~run ~run_batch suspension ->
(* run tasks, then we'll resume [suspension] *) (* run tasks, then we'll resume [suspension] *)
start_tasks ~run suspension); start_tasks ~run ~run_batch suspension);
} }
) )

View file

@ -379,7 +379,7 @@ let await (fut : 'a t) : 'a =
Suspend_.suspend Suspend_.suspend
{ {
Suspend_.handle = Suspend_.handle =
(fun ~run k -> (fun ~run ~run_batch:_ k ->
on_result fut (function on_result fut (function
| Ok _ -> run (fun () -> k (Ok ())) | Ok _ -> run (fun () -> k (Ok ()))
| Error (exn, bt) -> | Error (exn, bt) ->

View file

@ -1,7 +1,11 @@
type suspension = (unit, exn * Printexc.raw_backtrace) result -> unit type suspension = (unit, exn * Printexc.raw_backtrace) result -> unit
type task = unit -> unit type task = unit -> unit
type 'a iter = ('a -> unit) -> unit
type suspension_handler = { handle: run:(task -> unit) -> suspension -> unit } type suspension_handler = {
handle:
run:(task -> unit) -> run_batch:(task iter -> unit) -> suspension -> unit;
}
[@@unboxed] [@@unboxed]
[@@@ifge 5.0] [@@@ifge 5.0]
@ -13,7 +17,7 @@ type _ Effect.t += Suspend : suspension_handler -> unit Effect.t
let[@inline] suspend h = Effect.perform (Suspend h) let[@inline] suspend h = Effect.perform (Suspend h)
let with_suspend ~(run : task -> unit) (f : unit -> unit) : unit = let with_suspend ~(run : task -> unit) ~run_batch (f : unit -> unit) : unit =
let module E = Effect.Deep in let module E = Effect.Deep in
(* effect handler *) (* effect handler *)
let effc : type e. e Effect.t -> ((e, _) E.continuation -> _) option = let effc : type e. e Effect.t -> ((e, _) E.continuation -> _) option =
@ -25,7 +29,7 @@ let with_suspend ~(run : task -> unit) (f : unit -> unit) : unit =
| Ok () -> E.continue k () | Ok () -> E.continue k ()
| Error (exn, bt) -> E.discontinue_with_backtrace k exn bt | Error (exn, bt) -> E.discontinue_with_backtrace k exn bt
in in
h.handle ~run k') h.handle ~run ~run_batch k')
| _ -> None | _ -> None
in in
@ -41,7 +45,7 @@ let prepare_for_await () : Dla_.t =
| None -> () | None -> ()
| Some (run, k) -> run (fun () -> k (Ok ())) | Some (run, k) -> run (fun () -> k (Ok ()))
and await () : unit = and await () : unit =
suspend { handle = (fun ~run k -> A.set st (Some (run, k))) } suspend { handle = (fun ~run ~run_batch:_ k -> A.set st (Some (run, k))) }
in in
let t = { Dla_.release; await } in let t = { Dla_.release; await } in
@ -50,7 +54,7 @@ let prepare_for_await () : Dla_.t =
[@@@ocaml.alert "+unstable"] [@@@ocaml.alert "+unstable"]
[@@@else_] [@@@else_]
let[@inline] with_suspend ~run:_ f = f () let[@inline] with_suspend ~run:_ ~run_batch:_ f = f ()
let[@inline] prepare_for_await () = { Dla_.release = ignore; await = ignore } let[@inline] prepare_for_await () = { Dla_.release = ignore; await = ignore }
[@@@endif] [@@@endif]

View file

@ -7,8 +7,12 @@ type suspension = (unit, exn * Printexc.raw_backtrace) result -> unit
(** A suspended computation *) (** A suspended computation *)
type task = unit -> unit type task = unit -> unit
type 'a iter = ('a -> unit) -> unit
type suspension_handler = { handle: run:(task -> unit) -> suspension -> unit } type suspension_handler = {
handle:
run:(task -> unit) -> run_batch:(task iter -> unit) -> suspension -> unit;
}
[@@unboxed] [@@unboxed]
(** The handler that knows what to do with the suspended computation. (** The handler that knows what to do with the suspended computation.
@ -18,6 +22,8 @@ type suspension_handler = { handle: run:(task -> unit) -> suspension -> unit }
eventually); eventually);
- a [run] function that can be used to start tasks to perform some - a [run] function that can be used to start tasks to perform some
computation. computation.
- a [run_batch] function that can be used to start multiple background
tasks at once
This means that a fork-join primitive, for example, can use a single call This means that a fork-join primitive, for example, can use a single call
to {!suspend} to: to {!suspend} to:
@ -51,10 +57,11 @@ val suspend : suspension_handler -> unit
val prepare_for_await : unit -> Dla_.t val prepare_for_await : unit -> Dla_.t
(** Our stub for DLA. Unstable. *) (** Our stub for DLA. Unstable. *)
val with_suspend : run:(task -> unit) -> (unit -> unit) -> unit val with_suspend :
(** [with_suspend ~run f] runs [f()] in an environment where [suspend] run:(task -> unit) -> run_batch:(task iter -> unit) -> (unit -> unit) -> unit
(** [with_suspend ~run ~run_batch f] runs [f()] in an environment where [suspend]
will work. If [f()] suspends with suspension handler [h], will work. If [f()] suspends with suspension handler [h],
this calls [h ~run k] where [k] is the suspension. this calls [h ~run ~run_batch k] where [k] is the suspension.
This will not do anything on OCaml 4.x. This will not do anything on OCaml 4.x.
*) *)

View file

@ -81,6 +81,52 @@ let schedule_task_ (self : state) (w : worker_state option) (task : task) : unit
longer permitted *) longer permitted *)
raise Shutdown raise Shutdown
let schedule_task_batch_ (self : state) (w : worker_state option)
(batch : task iter) : unit =
let local_q = Queue.create () in
batch (fun x -> Queue.push x local_q);
let transfer_into_main_q () =
if not (A.get self.active) then raise Shutdown;
(* push into the main queue *)
Mutex.lock self.mutex;
Queue.transfer local_q self.main_q;
if self.n_waiting_nonzero then Condition.signal self.cond;
Mutex.unlock self.mutex
in
let try_to_schedule_locally (w : worker_state) =
let continue = ref true in
while !continue do
match Queue.peek_opt local_q with
| Some task ->
let pushed = WSQ.push w.q task in
if pushed then
(* continue *)
ignore (Queue.pop local_q : task)
else
continue := false
| None -> continue := false
done
in
if not (Queue.is_empty local_q) then (
match w with
| Some w ->
try_to_schedule_locally w;
(* there might be overflow tasks *)
if not (Queue.is_empty local_q) then transfer_into_main_q ()
| None -> transfer_into_main_q ()
)
let[@inline] run_async_ (self : state) (task : task) : unit =
let w = find_current_worker_ () in
schedule_task_ self w task
let[@inline] run_async_batch_ (self : state) (b : task iter) : unit =
let w = find_current_worker_ () in
schedule_task_batch_ self w b
(** Run this task, now. Must be called from a worker. *) (** Run this task, now. Must be called from a worker. *)
let run_task_now_ (self : state) ~runner task : unit = let run_task_now_ (self : state) ~runner task : unit =
let (AT_pair (before_task, after_task)) = self.around_task in let (AT_pair (before_task, after_task)) = self.around_task in
@ -88,18 +134,14 @@ let run_task_now_ (self : state) ~runner task : unit =
(* run the task now, catching errors *) (* run the task now, catching errors *)
(try (try
(* run [task()] and handle [suspend] in it *) (* run [task()] and handle [suspend] in it *)
Suspend_.with_suspend task ~run:(fun task' -> Suspend_.with_suspend task
let w = find_current_worker_ () in ~run:(fun task' -> run_async_ self task')
schedule_task_ self w task') ~run_batch:(fun b -> run_async_batch_ self b)
with e -> with e ->
let bt = Printexc.get_raw_backtrace () in let bt = Printexc.get_raw_backtrace () in
self.on_exn e bt); self.on_exn e bt);
after_task runner _ctx after_task runner _ctx
let[@inline] run_async_ (self : state) (task : task) : unit =
let w = find_current_worker_ () in
schedule_task_ self w task
(* TODO: function to schedule many tasks from the outside. (* TODO: function to schedule many tasks from the outside.
- build a queue - build a queue
- lock - lock
@ -254,6 +296,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
Runner.For_runner_implementors.create Runner.For_runner_implementors.create
~shutdown:(fun ~wait () -> shutdown_ pool ~wait) ~shutdown:(fun ~wait () -> shutdown_ pool ~wait)
~run_async:(fun f -> run_async_ pool f) ~run_async:(fun f -> run_async_ pool f)
~run_async_batch:(fun f -> run_async_batch_ pool f)
~size:(fun () -> size_ pool) ~size:(fun () -> size_ pool)
~num_tasks:(fun () -> num_tasks_ pool) ~num_tasks:(fun () -> num_tasks_ pool)
() ()