diff --git a/src/fifo_pool.ml b/src/fifo_pool.ml index 1a95d715..c4cc59ac 100644 --- a/src/fifo_pool.ml +++ b/src/fifo_pool.ml @@ -1,3 +1,4 @@ +module TLS = Thread_local_storage_ include Runner let ( let@ ) = ( @@ ) @@ -18,6 +19,7 @@ let schedule_ (self : state) (task : task) : unit = type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit = + TLS.get Runner.For_runner_implementors.k_cur_runner := Some runner; let (AT_pair (before_task, after_task)) = around_task in let run_task task : unit = diff --git a/src/runner.ml b/src/runner.ml index 91cde5a2..0fcf2392 100644 --- a/src/runner.ml +++ b/src/runner.ml @@ -1,3 +1,5 @@ +module TLS = Thread_local_storage_ + type task = unit -> unit type t = { @@ -34,4 +36,9 @@ let run_wait_block self (f : unit -> 'a) : 'a = module For_runner_implementors = struct let create ~size ~num_tasks ~shutdown ~run_async () : t = { size; num_tasks; shutdown; run_async } + + let k_cur_runner : t option ref TLS.key = TLS.new_key (fun () -> ref None) end + +let[@inline] get_current_runner () : _ option = + !(TLS.get For_runner_implementors.k_cur_runner) diff --git a/src/runner.mli b/src/runner.mli index 3ac2f724..471d21af 100644 --- a/src/runner.mli +++ b/src/runner.mli @@ -63,4 +63,11 @@ module For_runner_implementors : sig {b NOTE}: the runner should support DLA and {!Suspend_} on OCaml 5.x, so that {!Fork_join} and other 5.x features work properly. *) + + val k_cur_runner : t option ref Thread_local_storage_.key end + +val get_current_runner : unit -> t option +(** Access the current runner. This returns [Some r] if the call + happens on a thread that belongs in a runner. + @since NEXT_RELEASE *) diff --git a/src/ws_pool.ml b/src/ws_pool.ml index 4623a3e3..874cbd5c 100644 --- a/src/ws_pool.ml +++ b/src/ws_pool.ml @@ -153,6 +153,7 @@ let worker_run_self_tasks_ (self : state) ~runner w : unit = (** Main loop for a worker thread. *) let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit = + TLS.get Runner.For_runner_implementors.k_cur_runner := Some runner; TLS.get k_worker_state := Some w; let rec main () : unit =