refactor ws_pool: do not nest effect handlers; fixes

- we differentiate between starting a task and resuming a task
- we dynamically determine whether we're on one of the pool's runners
  in `resume`/`run_another_task` in the main suspend handler
  (this way we can use the local work-stealing queue
  if we're in the same pool, even if we're not on the
  worker that ran the "suspend" call itself)
This commit is contained in:
Simon Cruanes 2024-02-27 21:23:37 -05:00
parent b9cf0616b8
commit 856dc85d41
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4

View file

@ -14,13 +14,20 @@ module Id = struct
let equal : t -> t -> bool = ( == )
end
type task_full = {
f: task;
ls: Task_local_storage.storage;
}
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
(* A unit of work handed to the pool's scheduler. Two cases are kept apart:
   starting a fresh task, and resuming one that was previously suspended. *)
type task_full =
(* Start a new task: [f] is the function to run and [ls] is the task-local
   storage to use while it runs. *)
| T_start of {
ls: Task_local_storage.storage;
f: task;
}
(* Resume a suspended task by calling [k x], with [ls] as its task-local
   storage. ['a] does not appear in the result type, so it is existential:
   the only thing a consumer can do with [k] and [x] is apply one to the
   other. *)
| T_resume : {
ls: Task_local_storage.storage;
k: 'a -> unit;
x: 'a;
}
-> task_full
type worker_state = {
pool_id_: Id.t; (** Unique per pool *)
mutable thread: Thread.t;
@ -73,10 +80,8 @@ let[@inline] try_wake_someone_ (self : state) : unit =
)
(** Run [task] as is, on the pool. *)
let schedule_task_ (self : state) ~ls (w : worker_state option) (f : task)
: unit =
let schedule_task_ (self : state) ~w (task : task_full) : unit =
(* Printf.printf "schedule task now (%d)\n%!" (Thread.id @@ Thread.self ()); *)
let task = { f; ls } in
match w with
| Some w when Id.equal self.id_ w.pool_id_ ->
(* we're on this same pool, schedule in the worker's state. Otherwise
@ -105,40 +110,59 @@ let schedule_task_ (self : state) ~ls (w : worker_state option) (f : task)
raise Shutdown
(** Run this task, now. Must be called from a worker. *)
let run_task_now_ (self : state) ~runner (w : worker_state) ~ls task :
unit =
let run_task_now_ (self : state) ~runner ~(w : worker_state) (task : task_full)
: unit =
(* Printf.printf "run task now (%d)\n%!" (Thread.id @@ Thread.self ()); *)
let (AT_pair (before_task, after_task)) = self.around_task in
let ls =
match task with
| T_start { ls; _ } | T_resume { ls; _ } -> ls
in
w.cur_ls := ls;
let _ctx = before_task runner in
let[@inline] on_suspend () =
let w' = find_current_worker_ () in
let ls= !(w.cur_ls) in
w', ls
let w =
match find_current_worker_ () with
| Some w -> w
| None -> assert false
in
let ls = !(w.cur_ls) in
ls
in
let run_another_task (w,ls) task' =
let run_another_task ls (task' : task) =
let w =
match find_current_worker_ () with
| Some w when Id.equal w.pool_id_ self.id_ -> Some w
| _ -> None
in
let ls' = Task_local_storage.Private_.Storage.copy ls in
schedule_task_ self w ~ls:ls' task'
schedule_task_ self ~w @@ T_start { ls = ls'; f = task' }
in
let resume (w,ls) k r =
schedule_task_ self w ~ls (fun () -> k r)
let resume ls k x =
let w =
match find_current_worker_ () with
| Some w when Id.equal w.pool_id_ self.id_ -> Some w
| _ -> None
in
schedule_task_ self ~w @@ T_resume { ls; k; x }
in
(* run the task now, catching errors *)
(try
(* run [task()] and handle [suspend] in it *)
[@@@ifge 5.0]
Suspend_.with_suspend (WSH {
on_suspend;
run=run_another_task;
resume;
}) task
[@@@else_]
task ()
[@@@endif]
match task with
| T_start { f = task; _ } ->
(* run [task()] and handle [suspend] in it *)
Suspend_.with_suspend
(WSH { on_suspend; run = run_another_task; resume })
task
| T_resume { k; x; _ } ->
(* this is already in an effect handler *)
k x
with e ->
let bt = Printexc.get_raw_backtrace () in
self.on_exn e bt);
@ -146,9 +170,9 @@ let run_task_now_ (self : state) ~runner (w : worker_state) ~ls task :
after_task runner _ctx;
w.cur_ls := Task_local_storage.Private_.Storage.dummy
let[@inline] run_async_ (self : state) ~ls (f : task) : unit =
let run_async_ (self : state) ~ls (f : task) : unit =
let w = find_current_worker_ () in
schedule_task_ self w ~ls f
schedule_task_ self ~w @@ T_start { f; ls }
(* TODO: function to schedule many tasks from the outside.
- build a queue
@ -194,7 +218,7 @@ let worker_run_self_tasks_ (self : state) ~runner w : unit =
match WSQ.pop w.q with
| Some task ->
try_wake_someone_ self;
run_task_now_ self ~runner w ~ls:task.ls task.f
run_task_now_ self ~runner ~w task
| None -> continue := false
done
@ -207,7 +231,7 @@ let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit =
worker_run_self_tasks_ self ~runner w;
try_steal ()
and run_task task : unit =
run_task_now_ self ~runner w ~ls:task.ls task.f;
run_task_now_ self ~runner ~w task;
main ()
and try_steal () =
match try_to_steal_work_once_ self w with
@ -266,7 +290,8 @@ type ('a, 'b) create_args =
'a
(** Arguments used in {!create}. See {!create} for explanations. *)
let dummy_task_ = { f = ignore; ls = Task_local_storage.Private_.Storage.dummy ; }
(* Placeholder [task_full]: a [T_start] whose function is [ignore] and whose
   storage is the dummy storage.
   NOTE(review): presumably used as a filler value where a task is required
   but none is available yet — confirm at the use sites. *)
let dummy_task_ : task_full =
T_start { f = ignore; ls = Task_local_storage.Private_.Storage.dummy }
let create ?(on_init_thread = default_thread_init_exit_)
?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ())