diff --git a/dune-project b/dune-project index ddf01fef..f6ace0a6 100644 --- a/dune-project +++ b/dune-project @@ -28,6 +28,7 @@ (>= 1.9.0) :with-test))) (depopts + thread-local-storage (domain-local-await (>= 0.2))) (tags (thread pool domain futures fork-join))) diff --git a/moonpool.opam b/moonpool.opam index 547b18c1..62bdcf6e 100644 --- a/moonpool.opam +++ b/moonpool.opam @@ -19,6 +19,7 @@ depends: [ "mdx" {>= "1.9.0" & with-test} ] depopts: [ + "thread-local-storage" "domain-local-await" {>= "0.2"} ] build: [ diff --git a/src/dune b/src/dune index 313191a5..5275ab40 100644 --- a/src/dune +++ b/src/dune @@ -6,6 +6,9 @@ (action (run %{project_root}/src/cpp/cpp.exe %{input-file}))) (libraries threads either + (select thread_local_storage.ml from + (thread-local-storage -> thread_local_storage.stub.ml) + (-> thread_local_storage.real.ml)) (select dla_.ml from (domain-local-await -> dla_.real.ml) ( -> dla_.dummy.ml)))) diff --git a/src/moonpool.ml b/src/moonpool.ml index ed1af755..b4118536 100644 --- a/src/moonpool.ml +++ b/src/moonpool.ml @@ -13,6 +13,7 @@ module Pool = Fifo_pool module Ws_pool = Ws_pool module Runner = Runner module Fifo_pool = Fifo_pool +module Thread_local_storage = Thread_local_storage module Private = struct module Ws_deque_ = Ws_deque_ diff --git a/src/moonpool.mli b/src/moonpool.mli index 4028e858..b744dc51 100644 --- a/src/moonpool.mli +++ b/src/moonpool.mli @@ -26,6 +26,7 @@ module Lock = Lock module Fut = Fut module Chan = Chan module Fork_join = Fork_join +module Thread_local_storage = Thread_local_storage (** A simple blocking queue. diff --git a/src/thread_local_storage.mli b/src/thread_local_storage.mli new file mode 100644 index 00000000..b7b50706 --- /dev/null +++ b/src/thread_local_storage.mli @@ -0,0 +1,21 @@ +(** Thread local storage *) + +(* TODO: alias this to the library if present *) + +type 'a key +(** A TLS key for values of type ['a]. This allows the + storage of a single value of type ['a] per thread. 
*) + +val new_key : (unit -> 'a) -> 'a key +(** Allocate a new, generative key. + When the key is used for the first time on a thread, + the function is called to produce it. + + This should only ever be called at toplevel to produce + constants, do not use it in a loop. *) + +val get : 'a key -> 'a +(** Get the value for the current thread. If this thread has no value for the key yet, the function given to {!new_key} is called to produce one. *) + +val set : 'a key -> 'a -> unit +(** Set the value for the current thread, replacing any previous one. *) diff --git a/src/thread_local_storage.real.ml b/src/thread_local_storage.real.ml new file mode 100644 index 00000000..2d33f62c --- /dev/null +++ b/src/thread_local_storage.real.ml @@ -0,0 +1,80 @@ +(* see: https://discuss.ocaml.org/t/a-hack-to-implement-efficient-tls-thread-local-storage/13264 *) + +(* sanity check: field 1 of the current [Thread.t] block must hold [()], since that is the slot the code below reuses to stash the TLS data *) +let () = assert (Obj.field (Obj.repr (Thread.self ())) 1 = Obj.repr ()) + +type 'a key = { + index: int; (** Unique index for this key. *) + compute: unit -> 'a; + (** Initializer for values for this key. Called at most + once per thread. *) +} + +(** Counter used to allocate new keys *) +let counter = Atomic.make 0 + +(** Value used to detect a TLS slot that was not initialized yet. It is the [counter] atomic itself, viewed as [Obj.t], and is compared by physical equality. *) +let sentinel_value_for_uninit_tls_ () : Obj.t = Obj.repr counter + +(** Allocate a key whose index is the next value handed out by [counter]. *) let new_key compute : _ key = + let index = Atomic.fetch_and_add counter 1 in + { index; compute } + +type thread_internal_state = { + _id: int; (** Thread ID (here for padding reasons) *) + mutable tls: Obj.t; (** Our data, stowed away in this unused field *) +} +(** A partial representation of the internal type [Thread.t], allowing + us to access the second field (unused after the thread + has started) and stash TLS data in it. *) + +(** [ceil_pow_2_minus_1 n] copies the highest set bit of [n] into every lower bit; for [n >= 0] the result is the smallest [2^k - 1] that is [>= n]. The extra 32-bit shift is only applied when [Sys.int_size > 32]. *) let ceil_pow_2_minus_1 (n : int) : int = + let n = n lor (n lsr 1) in + let n = n lor (n lsr 2) in + let n = n lor (n lsr 4) in + let n = n lor (n lsr 8) in + let n = n lor (n lsr 16) in + if Sys.int_size > 32 then + n lor (n lsr 32) + else + n + +(** Grow the array so that [index] is valid. 
*) +let[@inline never] grow_tls (old : Obj.t array) (index : int) : Obj.t array = + let new_length = ceil_pow_2_minus_1 (index + 1) in + let new_ = Array.make new_length (sentinel_value_for_uninit_tls_ ()) in + Array.blit old 0 new_ 0 (Array.length old); + new_ + +(** TLS array of the current thread, created or grown on demand so that [index] is in bounds. A fresh array is stored back into the thread whenever one is allocated. *) let[@inline] get_tls_ (index : int) : Obj.t array = + let thread : thread_internal_state = Obj.magic (Thread.self ()) in + let tls = thread.tls in + if Obj.is_int tls then ( + let new_tls = grow_tls [||] index in + thread.tls <- Obj.magic new_tls; + new_tls + ) else ( + let tls = (Obj.magic tls : Obj.t array) in + if index < Array.length tls then + tls + else ( + let new_tls = grow_tls tls index in + thread.tls <- Obj.magic new_tls; + new_tls + ) + ) + +(** Value of [key] for the current thread. If the slot still holds the sentinel, [key.compute] is run and its result is cached in the slot. *) let get key = + let tls = get_tls_ key.index in + let value = Array.unsafe_get tls key.index in + if value != sentinel_value_for_uninit_tls_ () then + Obj.magic value + else ( + let value = key.compute () in (* [compute] may itself use other keys and make [get_tls_] replace the TLS array of this thread, so [tls] can be stale here: fetch the array again before caching *) + Array.unsafe_set (get_tls_ key.index) key.index (Obj.repr (Sys.opaque_identity value)); + value + ) + +(** Store [value] in the slot of [key] for the current thread. *) let set key value = + let tls = get_tls_ key.index in + Array.unsafe_set tls key.index (Obj.repr (Sys.opaque_identity value)) diff --git a/src/thread_local_storage.stub.ml b/src/thread_local_storage.stub.ml new file mode 100644 index 00000000..88712b6d --- /dev/null +++ b/src/thread_local_storage.stub.ml @@ -0,0 +1,3 @@ + +(* just defer to library. NOTE(review): dune selects this file as the [thread_local_storage.ml] of moonpool itself, so confirm that [Thread_local_storage] below resolves to the external library and not to this very module. *) +include Thread_local_storage diff --git a/src/ws_pool.ml b/src/ws_pool.ml index b0e055cd..4d1e0c70 100644 --- a/src/ws_pool.ml +++ b/src/ws_pool.ml @@ -1,5 +1,6 @@ module WSQ = Ws_deque_ module A = Atomic_ +module TLS = Thread_local_storage include Runner let ( let@ ) = ( @@ ) @@ -36,25 +37,20 @@ let num_tasks_ (self : state) : int = Array.iter (fun w -> n := !n + WSQ.size w.q) self.workers; !n -exception Got_worker of worker_state +(** TLS, used by workers to store their specific state + and be able to retrieve it from tasks when we schedule new + sub-tasks. 
*) +let k_worker_state : worker_state option ref TLS.key = + TLS.new_key (fun () -> ref None) -(* FIXME: replace with TLS *) -let[@inline] find_current_worker_ (self : state) : worker_state option = - let self_id = Thread.id @@ Thread.self () in - try - (* see if we're in one of the worker threads *) - for i = 0 to Array.length self.workers - 1 do - let w = self.workers.(i) in - if Thread.id w.thread = self_id then raise_notrace (Got_worker w) - done; - None - with Got_worker w -> Some w +(** Worker state stored in the TLS cell of the calling thread, or [None] if nothing was stored there. NOTE(review): the pool is no longer an argument, so a worker thread of any pool yields [Some _]; confirm callers never schedule onto a different pool from inside a worker. *) let[@inline] find_current_worker_ () : worker_state option = + !(TLS.get k_worker_state) (** Try to wake up a waiter, if there's any. *) let[@inline] try_wake_someone_ (self : state) : unit = if self.n_waiting_nonzero then ( Mutex.lock self.mutex; - Condition.broadcast self.cond; + Condition.signal self.cond; (* wake a single waiter, where the removed line woke all of them *) Mutex.unlock self.mutex ) @@ -71,7 +67,7 @@ let schedule_task_ (self : state) (w : worker_state option) (task : task) : unit (* push into the main queue *) Mutex.lock self.mutex; Queue.push task self.main_q; - if self.n_waiting_nonzero then Condition.broadcast self.cond; + if self.n_waiting_nonzero then Condition.signal self.cond; (* one task was pushed, so a single waiter is signalled *) Mutex.unlock self.mutex ) else (* notify the caller that scheduling tasks is no @@ -87,7 +83,7 @@ let run_task_now_ (self : state) ~runner task : unit = (try (* run [task()] and handle [suspend] in it *) Suspend_.with_suspend task ~run:(fun task' -> - let w = find_current_worker_ self in + let w = find_current_worker_ () in (* worker of the calling thread, read from TLS *) schedule_task_ self w task') with e -> let bt = Printexc.get_raw_backtrace () in @@ -95,7 +91,7 @@ let run_task_now_ (self : state) ~runner task : unit = after_task runner _ctx let[@inline] run_async_ (self : state) (task : task) : unit = - let w = find_current_worker_ self in + let w = find_current_worker_ () in (* [None] when the caller is not a registered worker thread *) schedule_task_ self w task (* TODO: function to schedule many tasks from the outside. 
@@ -140,6 +136,7 @@ let try_to_steal_work_loop (self : state) ~runner w : bool = while !n_retries_left > 0 do match try_to_steal_work_once_ self w with | Some task -> + try_wake_someone_ self; (* NOTE(review): presumably lets another idle worker look for work while this one is busy running [task]; confirm *) run_task_now_ self ~runner task; has_stolen := true; n_retries_left := 0 @@ -153,12 +150,16 @@ let worker_run_self_tasks_ (self : state) ~runner w : unit = let continue = ref true in while !continue && A.get self.active do match WSQ.pop w.q with - | Some task -> run_task_now_ self ~runner task + | Some task -> + try_wake_someone_ self; (* same wake-up as in the stealing loop, done before running the task *) + run_task_now_ self ~runner task | None -> continue := false done (** Main loop for a worker thread. *) let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit = + TLS.get k_worker_state := Some w; (* record [w] in the TLS cell of this thread, which [find_current_worker_] reads *) + let main_loop () : unit = let continue = ref true in while !continue && A.get self.active do @@ -172,7 +173,7 @@ let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit = Mutex.unlock self.mutex; run_task_now_ self ~runner task | exception Queue.Empty -> - wait_ self; + if A.get self.active then wait_ self; (* do not start waiting once the pool is no longer active *) Mutex.unlock self.mutex ) done;