refactor ws_pool: do not nest effect handlers; fixes

- we differentiate between starting a task and resuming a task
- we dynamically find if we're on one of the pool's runner
  in `resume`/`run_another_task` in the main suspend handler
  (this way we can use the local work stealing queue
  if we're in the same pool, even if we're not on the
  worker that ran the "suspend" call itself)
This commit is contained in:
Simon Cruanes 2024-02-27 21:23:37 -05:00
parent b9cf0616b8
commit 856dc85d41
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4

View file

@ -14,13 +14,20 @@ module Id = struct
let equal : t -> t -> bool = ( == ) let equal : t -> t -> bool = ( == )
end end
type task_full = {
f: task;
ls: Task_local_storage.storage;
}
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
type task_full =
| T_start of {
ls: Task_local_storage.storage;
f: task;
}
| T_resume : {
ls: Task_local_storage.storage;
k: 'a -> unit;
x: 'a;
}
-> task_full
type worker_state = { type worker_state = {
pool_id_: Id.t; (** Unique per pool *) pool_id_: Id.t; (** Unique per pool *)
mutable thread: Thread.t; mutable thread: Thread.t;
@ -73,10 +80,8 @@ let[@inline] try_wake_someone_ (self : state) : unit =
) )
(** Run [task] as is, on the pool. *) (** Run [task] as is, on the pool. *)
let schedule_task_ (self : state) ~ls (w : worker_state option) (f : task) let schedule_task_ (self : state) ~w (task : task_full) : unit =
: unit =
(* Printf.printf "schedule task now (%d)\n%!" (Thread.id @@ Thread.self ()); *) (* Printf.printf "schedule task now (%d)\n%!" (Thread.id @@ Thread.self ()); *)
let task = { f; ls } in
match w with match w with
| Some w when Id.equal self.id_ w.pool_id_ -> | Some w when Id.equal self.id_ w.pool_id_ ->
(* we're on this same pool, schedule in the worker's state. Otherwise (* we're on this same pool, schedule in the worker's state. Otherwise
@ -105,40 +110,59 @@ let schedule_task_ (self : state) ~ls (w : worker_state option) (f : task)
raise Shutdown raise Shutdown
(** Run this task, now. Must be called from a worker. *) (** Run this task, now. Must be called from a worker. *)
let run_task_now_ (self : state) ~runner (w : worker_state) ~ls task : let run_task_now_ (self : state) ~runner ~(w : worker_state) (task : task_full)
unit = : unit =
(* Printf.printf "run task now (%d)\n%!" (Thread.id @@ Thread.self ()); *) (* Printf.printf "run task now (%d)\n%!" (Thread.id @@ Thread.self ()); *)
let (AT_pair (before_task, after_task)) = self.around_task in let (AT_pair (before_task, after_task)) = self.around_task in
let ls =
match task with
| T_start { ls; _ } | T_resume { ls; _ } -> ls
in
w.cur_ls := ls; w.cur_ls := ls;
let _ctx = before_task runner in let _ctx = before_task runner in
let[@inline] on_suspend () = let[@inline] on_suspend () =
let w' = find_current_worker_ () in let w =
let ls= !(w.cur_ls) in match find_current_worker_ () with
w', ls | Some w -> w
| None -> assert false
in
let ls = !(w.cur_ls) in
ls
in in
let run_another_task (w,ls) task' = let run_another_task ls (task' : task) =
let w =
match find_current_worker_ () with
| Some w when Id.equal w.pool_id_ self.id_ -> Some w
| _ -> None
in
let ls' = Task_local_storage.Private_.Storage.copy ls in let ls' = Task_local_storage.Private_.Storage.copy ls in
schedule_task_ self w ~ls:ls' task' schedule_task_ self ~w @@ T_start { ls = ls'; f = task' }
in in
let resume (w,ls) k r = let resume ls k x =
schedule_task_ self w ~ls (fun () -> k r) let w =
match find_current_worker_ () with
| Some w when Id.equal w.pool_id_ self.id_ -> Some w
| _ -> None
in
schedule_task_ self ~w @@ T_resume { ls; k; x }
in in
(* run the task now, catching errors *) (* run the task now, catching errors *)
(try (try
(* run [task()] and handle [suspend] in it *) match task with
[@@@ifge 5.0] | T_start { f = task; _ } ->
Suspend_.with_suspend (WSH { (* run [task()] and handle [suspend] in it *)
on_suspend; Suspend_.with_suspend
run=run_another_task; (WSH { on_suspend; run = run_another_task; resume })
resume; task
}) task | T_resume { k; x; _ } ->
[@@@else_] (* this is already in an effect handler *)
task () k x
[@@@endif]
with e -> with e ->
let bt = Printexc.get_raw_backtrace () in let bt = Printexc.get_raw_backtrace () in
self.on_exn e bt); self.on_exn e bt);
@ -146,9 +170,9 @@ let run_task_now_ (self : state) ~runner (w : worker_state) ~ls task :
after_task runner _ctx; after_task runner _ctx;
w.cur_ls := Task_local_storage.Private_.Storage.dummy w.cur_ls := Task_local_storage.Private_.Storage.dummy
let[@inline] run_async_ (self : state) ~ls (f : task) : unit = let run_async_ (self : state) ~ls (f : task) : unit =
let w = find_current_worker_ () in let w = find_current_worker_ () in
schedule_task_ self w ~ls f schedule_task_ self ~w @@ T_start { f; ls }
(* TODO: function to schedule many tasks from the outside. (* TODO: function to schedule many tasks from the outside.
- build a queue - build a queue
@ -194,7 +218,7 @@ let worker_run_self_tasks_ (self : state) ~runner w : unit =
match WSQ.pop w.q with match WSQ.pop w.q with
| Some task -> | Some task ->
try_wake_someone_ self; try_wake_someone_ self;
run_task_now_ self ~runner w ~ls:task.ls task.f run_task_now_ self ~runner ~w task
| None -> continue := false | None -> continue := false
done done
@ -207,7 +231,7 @@ let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit =
worker_run_self_tasks_ self ~runner w; worker_run_self_tasks_ self ~runner w;
try_steal () try_steal ()
and run_task task : unit = and run_task task : unit =
run_task_now_ self ~runner w ~ls:task.ls task.f; run_task_now_ self ~runner ~w task;
main () main ()
and try_steal () = and try_steal () =
match try_to_steal_work_once_ self w with match try_to_steal_work_once_ self w with
@ -266,7 +290,8 @@ type ('a, 'b) create_args =
'a 'a
(** Arguments used in {!create}. See {!create} for explanations. *) (** Arguments used in {!create}. See {!create} for explanations. *)
let dummy_task_ = { f = ignore; ls = Task_local_storage.Private_.Storage.dummy ; } let dummy_task_ : task_full =
T_start { f = ignore; ls = Task_local_storage.Private_.Storage.dummy }
let create ?(on_init_thread = default_thread_init_exit_) let create ?(on_init_thread = default_thread_init_exit_)
?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ()) ?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ())