fix ws_pool: make sure we capture the current worker before suspend

This commit is contained in:
Simon Cruanes 2024-02-23 20:54:51 -05:00
parent 4cdec87aea
commit ed171c1171
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
3 changed files with 17 additions and 10 deletions

View file

@ -113,17 +113,17 @@ let run_task_now_ (self : state) ~runner (w : worker_state) ~ls task :
let _ctx = before_task runner in
let[@inline] on_suspend () =
!(w.cur_ls)
let w' = find_current_worker_ () in
let ls= !(w.cur_ls) in
w', ls
in
let run_another_task ls task' =
let w = find_current_worker_ () in
let run_another_task (w,ls) task' =
let ls' = Task_local_storage.Private_.Storage.copy ls in
schedule_task_ self w ~ls:ls' task'
in
let resume ls k r =
let w = find_current_worker_ () in
let resume (w,ls) k r =
schedule_task_ self w ~ls (fun () -> k r)
in

View file

@ -26,6 +26,12 @@ module Private_ = struct
and children = any FM.t
and any = Any : _ t -> any [@@unboxed]
(** Key to access the current fiber. *)
let k_current_fiber : any option Task_local_storage.key =
Task_local_storage.new_key ~init:(fun () -> None) ()
let[@inline] get_cur () : any option = Task_local_storage.get k_current_fiber
end
include Private_
@ -148,10 +154,6 @@ let add_child_ ~protect (self : _ t) (child : _ t) =
()
done
(** Key to access the current fiber. *)
let k_current_fiber : any option Task_local_storage.key =
Task_local_storage.new_key ~init:(fun () -> None) ()
let spawn_ ~on (f : _ -> 'a) : 'a t =
let id = Handle.generate_fresh () in
let res, _promise = Fut.make () in
@ -167,6 +169,7 @@ let spawn_ ~on (f : _ -> 'a) : 'a t =
let run () =
(* make sure the fiber is accessible from inside itself *)
Task_local_storage.set k_current_fiber (Some (Any fib));
assert (Task_local_storage.get k_current_fiber |> Option.is_some);
try
let res = f () in
resolve_ok_ fib res

View file

@ -22,6 +22,10 @@ module Private_ : sig
}
(** Type definition, exposed so that {!any} can be unboxed.
Please do not rely on that. *)
type any = Any : _ t -> any [@@unboxed]
val get_cur : unit -> any option
end
(**/**)
@ -38,7 +42,7 @@ type 'a callback = 'a Exn_bt.result -> unit
type cancel_callback = Exn_bt.t -> unit
(** Type erased fiber *)
type any = Any : _ t -> any [@@unboxed]
type any = Private_.any = Any : _ t -> any [@@unboxed]
val self : unit -> any
(** [self ()] is the current fiber.