From 2c3cc8892abb52f7864f57acb78a6a68f055742d Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Wed, 12 Nov 2025 00:24:05 -0500 Subject: [PATCH] consolidate thread-local-storage into single record --- src/core/fifo_pool.ml | 5 ----- src/core/hmap_ls_.real.ml | 12 ++++++------ src/core/main.ml | 13 ++++++++----- src/core/runner.ml | 7 ++++++- src/core/runner.mli | 8 +++++++- src/core/types_.ml | 22 +++++++++++++++------- src/core/worker_loop_.ml | 20 ++++++++++---------- src/core/ws_pool.ml | 7 ++++--- 8 files changed, 56 insertions(+), 38 deletions(-) diff --git a/src/core/fifo_pool.ml b/src/core/fifo_pool.ml index 7fcb1297..6b4b9cf0 100644 --- a/src/core/fifo_pool.ml +++ b/src/core/fifo_pool.ml @@ -27,11 +27,6 @@ type worker_state = { let[@inline] size_ (self : state) = Array.length self.threads let[@inline] num_tasks_ (self : state) : int = Bb_queue.size self.q - -(* -get_thread_state = TLS.get_opt k_worker_state - *) - let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = () let shutdown_ ~wait (self : state) : unit = diff --git a/src/core/hmap_ls_.real.ml b/src/core/hmap_ls_.real.ml index 8b7950a5..5b6b482e 100644 --- a/src/core/hmap_ls_.real.ml +++ b/src/core/hmap_ls_.real.ml @@ -9,19 +9,19 @@ let k_local_hmap : Hmap.t FLS.t = FLS.create () (** Access the local [hmap], or an empty one if not set *) let[@inline] get_local_hmap () : Hmap.t = - match TLS.get_exn k_cur_fiber with + match TLS.get_exn k_cur_st with | exception TLS.Not_set -> Hmap.empty - | fiber -> FLS.get fiber ~default:Hmap.empty k_local_hmap + | { cur_fiber = fiber; _ } -> FLS.get fiber ~default:Hmap.empty k_local_hmap let[@inline] set_local_hmap (h : Hmap.t) : unit = - match TLS.get_exn k_cur_fiber with + match TLS.get_exn k_cur_st with | exception TLS.Not_set -> () - | fiber -> FLS.set fiber k_local_hmap h + | { cur_fiber = fiber; _ } -> FLS.set fiber k_local_hmap h let[@inline] update_local_hmap (f : Hmap.t -> Hmap.t) : unit = - match TLS.get_exn k_cur_fiber with + match TLS.get_exn k_cur_st with | exception TLS.Not_set -> () - | fiber -> + | { cur_fiber = fiber; _ } -> let h = FLS.get fiber ~default:Hmap.empty k_local_hmap in let h = f h in FLS.set fiber k_local_hmap h diff --git a/src/core/main.ml b/src/core/main.ml index 9325fd3a..e10a09fc 100644 --- a/src/core/main.ml +++ b/src/core/main.ml @@ -1,6 +1,7 @@ exception Oh_no of Exn_bt.t let main' ?(block_signals = false) () (f : Runner.t -> 'a) : 'a = + let module WL = Worker_loop_ in let worker_st = Fifo_pool.Private_.create_single_threaded_state ~thread:(Thread.self ()) ~on_exn:(fun e bt -> raise (Oh_no (Exn_bt.make e bt))) @@ -8,15 +9,17 @@ let main' ?(block_signals = false) () (f : Runner.t -> 'a) : 'a = in let runner = Fifo_pool.Private_.runner_of_state worker_st in try - let fiber = Fut.spawn ~on:runner (fun () -> f runner) in - Fut.on_result fiber (fun _ -> Runner.shutdown_without_waiting runner); + let fut = Fut.spawn ~on:runner (fun () -> f runner) in + Fut.on_result fut (fun _ -> Runner.shutdown_without_waiting runner); + + Thread_local_storage.set Runner.For_runner_implementors.k_cur_st + { cur_fiber = Picos.Fiber.create ~forbid:true fut; runner }; (* run the main thread *) - Worker_loop_.worker_loop worker_st - ~block_signals (* do not disturb existing thread *) + WL.worker_loop worker_st ~block_signals (* do not disturb existing thread *) ~ops:Fifo_pool.Private_.worker_ops; - match Fut.peek fiber with + match Fut.peek fut with | Some (Ok x) -> x | Some (Error ebt) -> Exn_bt.raise ebt | None -> assert false diff --git a/src/core/runner.ml b/src/core/runner.ml index a95de289..01bb09ea 100644 --- a/src/core/runner.ml +++ b/src/core/runner.ml @@ -47,7 +47,12 @@ module For_runner_implementors = struct let create ~size ~num_tasks ~shutdown ~run_async () : t = { size; num_tasks; shutdown; run_async } - let k_cur_runner : t TLS.t = Types_.k_cur_runner + type nonrec thread_local_state = thread_local_state = { + mutable runner: t; + mutable cur_fiber: fiber; + } + + let k_cur_st : thread_local_state TLS.t = Types_.k_cur_st end let dummy : t = diff --git a/src/core/runner.mli b/src/core/runner.mli index 37db1ce4..b49199f6 100644 --- a/src/core/runner.mli +++ b/src/core/runner.mli @@ -72,7 +72,13 @@ module For_runner_implementors : sig {b NOTE}: the runner should support DLA and {!Suspend_} on OCaml 5.x, so that {!Fork_join} and other 5.x features work properly. *) - val k_cur_runner : t Thread_local_storage.t + type thread_local_state = { + mutable runner: t; + mutable cur_fiber: fiber; + } + (** State set in thread-local-storage for worker threads *) + + val k_cur_st : thread_local_state Thread_local_storage.t (** Key that should be used by each runner to store itself in TLS on every thread it controls, so that tasks running on these threads can access the runner. This is necessary for {!get_current_runner} to work. *) diff --git a/src/core/types_.ml b/src/core/types_.ml index 97209942..4c413609 100644 --- a/src/core/types_.ml +++ b/src/core/types_.ml @@ -11,8 +11,12 @@ type runner = { num_tasks: unit -> int; } -let k_cur_runner : runner TLS.t = TLS.create () -let k_cur_fiber : fiber TLS.t = TLS.create () +type thread_local_state = { + mutable runner: runner; + mutable cur_fiber: fiber; +} + +let k_cur_st : thread_local_state TLS.t = TLS.create () let _dummy_computation : Picos.Computation.packed = let c = Picos.Computation.create () in @@ -20,11 +24,15 @@ let _dummy_computation : Picos.Computation.packed = Picos.Computation.Packed c let _dummy_fiber = Picos.Fiber.create_packed ~forbid:true _dummy_computation -let[@inline] get_current_runner () : _ option = TLS.get_opt k_cur_runner + +let[@inline] get_current_runner () : _ option = + match TLS.get_exn k_cur_st with + | st -> Some st.runner + | exception TLS.Not_set -> None let[@inline] get_current_fiber () : fiber option = - match TLS.get_exn k_cur_fiber with - | f when f != _dummy_fiber -> Some f + match TLS.get_exn k_cur_st with + | { cur_fiber = f; _ } when f != _dummy_fiber -> Some f | _ -> None | exception TLS.Not_set -> None @@ -32,7 +40,7 @@ let error_get_current_fiber_ = "Moonpool: get_current_fiber was called outside of a fiber." let[@inline] get_current_fiber_exn () : fiber = - match TLS.get_exn k_cur_fiber with - | f when f != _dummy_fiber -> f + match TLS.get_exn k_cur_st with + | { cur_fiber = f; _ } when f != _dummy_fiber -> f | _ -> failwith error_get_current_fiber_ | exception TLS.Not_set -> failwith error_get_current_fiber_ diff --git a/src/core/worker_loop_.ml b/src/core/worker_loop_.ml index 7ba781a5..8f8664a7 100644 --- a/src/core/worker_loop_.ml +++ b/src/core/worker_loop_.ml @@ -102,7 +102,12 @@ end module Fine_grained (Args : FINE_GRAINED_ARGS) () = struct open Args - let cur_fiber : fiber ref = ref _dummy_fiber + let cur_st : Runner.For_runner_implementors.thread_local_state = + match TLS.get_exn Runner.For_runner_implementors.k_cur_st with + | st -> st + | exception TLS.Not_set -> + failwith "Moonpool: worker loop: no current state set" + let runner = ops.runner st type state = @@ -118,10 +123,7 @@ module Fine_grained (Args : FINE_GRAINED_ARGS) () = struct | T_start { fiber; _ } | T_resume { fiber; _ } -> fiber in - cur_fiber := fiber; - TLS.set k_cur_fiber fiber; - - (* let _ctx = before_task runner in *) + cur_st.cur_fiber <- fiber; (* run the task now, catching errors, handling effects *) assert (task != _dummy_task); @@ -136,9 +138,7 @@ module Fine_grained (Args : FINE_GRAINED_ARGS) () = struct let ebt = Exn_bt.make e bt in ops.on_exn st ebt); - (* after_task runner _ctx; *) - cur_fiber := _dummy_fiber; - TLS.set k_cur_fiber _dummy_fiber + cur_st.cur_fiber <- _dummy_fiber let setup ~block_signals () : unit = if !state <> New then invalid_arg "worker_loop.setup: not a new instance"; @@ -161,7 +161,7 @@ module Fine_grained (Args : FINE_GRAINED_ARGS) () = struct with _ -> () ); - TLS.set Runner.For_runner_implementors.k_cur_runner runner; + cur_st.runner <- runner; ops.before_start st @@ -181,7 +181,7 @@ module Fine_grained (Args : FINE_GRAINED_ARGS) () = struct let teardown () = if !state <> Torn_down then ( state := Torn_down; - cur_fiber := _dummy_fiber; + cur_st.cur_fiber <- _dummy_fiber; ops.cleanup st ) end diff --git a/src/core/ws_pool.ml b/src/core/ws_pool.ml index 7581636f..4e33c652 100644 --- a/src/core/ws_pool.ml +++ b/src/core/ws_pool.ml @@ -55,7 +55,8 @@ let num_tasks_ (self : state) : int = !n (** TLS, used by worker to store their specific state and be able to retrieve it - from tasks when we schedule new sub-tasks. *) + from tasks when we schedule new sub-tasks. This way we can schedule the new + task directly in the local work queue, where it might be stolen. *) let k_worker_state : worker_state TLS.t = TLS.create () let[@inline] get_current_worker_ () : worker_state option = @@ -179,8 +180,8 @@ and wait_on_main_queue (self : worker_state) : WL.task_full = let before_start (self : worker_state) : unit = let t_id = Thread.id @@ Thread.self () in self.st.on_init_thread ~dom_id:self.dom_id ~t_id (); - TLS.set k_cur_fiber _dummy_fiber; - TLS.set Runner.For_runner_implementors.k_cur_runner self.st.as_runner; + TLS.set Runner.For_runner_implementors.k_cur_st + { cur_fiber = _dummy_fiber; runner = self.st.as_runner }; TLS.set k_worker_state self; (* set thread name *)