mirror of
https://github.com/c-cube/moonpool.git
synced 2025-12-05 19:00:33 -05:00
ws_pool: use TLS for quick worker storage access; reduce contention
This commit is contained in:
parent
b4ddd82ee8
commit
e67bffeca5
9 changed files with 130 additions and 18 deletions
|
|
@ -28,6 +28,7 @@
|
|||
(>= 1.9.0)
|
||||
:with-test)))
|
||||
(depopts
|
||||
thread-local-storage
|
||||
(domain-local-await (>= 0.2)))
|
||||
(tags
|
||||
(thread pool domain futures fork-join)))
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ depends: [
|
|||
"mdx" {>= "1.9.0" & with-test}
|
||||
]
|
||||
depopts: [
|
||||
"thread-local-storage"
|
||||
"domain-local-await" {>= "0.2"}
|
||||
]
|
||||
build: [
|
||||
|
|
|
|||
3
src/dune
3
src/dune
|
|
@ -6,6 +6,9 @@
|
|||
(action
|
||||
(run %{project_root}/src/cpp/cpp.exe %{input-file})))
|
||||
(libraries threads either
|
||||
(select thread_local_storage.ml from
|
||||
(thread-local-storage -> thread_local_storage.stub.ml)
|
||||
(-> thread_local_storage.real.ml))
|
||||
(select dla_.ml from
|
||||
(domain-local-await -> dla_.real.ml)
|
||||
( -> dla_.dummy.ml))))
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ module Pool = Fifo_pool
|
|||
module Ws_pool = Ws_pool
|
||||
module Runner = Runner
|
||||
module Fifo_pool = Fifo_pool
|
||||
module Thread_local_storage = Thread_local_storage
|
||||
|
||||
module Private = struct
|
||||
module Ws_deque_ = Ws_deque_
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ module Lock = Lock
|
|||
module Fut = Fut
|
||||
module Chan = Chan
|
||||
module Fork_join = Fork_join
|
||||
module Thread_local_storage = Thread_local_storage
|
||||
|
||||
(** A simple blocking queue.
|
||||
|
||||
|
|
|
|||
21
src/thread_local_storage.mli
Normal file
21
src/thread_local_storage.mli
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
(** Thread local storage *)
|
||||
|
||||
(* TODO: alias this to the library if present *)
|
||||
|
||||
type 'a key
|
||||
(** A TLS key for values of type ['a]. This allows the
|
||||
storage of a single value of type ['a] per thread. *)
|
||||
|
||||
val new_key : (unit -> 'a) -> 'a key
|
||||
(** Allocate a new, generative key.
|
||||
When the key is used for the first time on a thread,
|
||||
the function is called to produce the initial value stored for that thread.
|
||||
|
||||
This should only ever be called at toplevel to produce
|
||||
constants, do not use it in a loop. *)
|
||||
|
||||
val get : 'a key -> 'a
|
||||
(** Get the value for the current thread. *)
|
||||
|
||||
val set : 'a key -> 'a -> unit
|
||||
(** Set the value for the current thread. *)
|
||||
80
src/thread_local_storage.real.ml
Normal file
80
src/thread_local_storage.real.ml
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
(* see: https://discuss.ocaml.org/t/a-hack-to-implement-efficient-tls-thread-local-storage/13264 *)
|
||||
|
||||
(* sanity check *)
|
||||
let () = assert (Obj.field (Obj.repr (Thread.self ())) 1 = Obj.repr ())
|
||||
|
||||
type 'a key = {
|
||||
index: int; (** Unique index for this key. *)
|
||||
compute: unit -> 'a;
|
||||
(** Initializer for values for this key. Called at most
|
||||
once per thread. *)
|
||||
}
|
||||
|
||||
(** Counter used to allocate new keys *)
|
||||
let counter = Atomic.make 0
|
||||
|
||||
(** Value used to detect a TLS slot that was not initialized yet *)
|
||||
let sentinel_value_for_uninit_tls_ () : Obj.t = Obj.repr counter
|
||||
|
||||
(* Build a key whose index is the next value of the shared [counter];
   [compute] is kept in the key and used later to fill an empty slot. *)
let new_key compute : _ key =
  { index = Atomic.fetch_and_add counter 1; compute }
|
||||
|
||||
type thread_internal_state = {
|
||||
_id: int; (** Thread ID (here for padding reasons) *)
|
||||
mutable tls: Obj.t; (** Our data, stowed away in this unused field *)
|
||||
}
|
||||
(** A partial representation of the internal type [Thread.t], allowing
|
||||
us to access the second field (unused after the thread
|
||||
has started) and stash TLS data in it. *)
|
||||
|
||||
(* Propagate the highest set bit of [n] into every lower bit position, so
   the result has the form [2^k - 1] and is [>= n] for non-negative [n]
   (e.g. 4 -> 7, 3 -> 3, 0 -> 0). *)
let ceil_pow_2_minus_1 (n : int) : int =
  (* one propagation step: copy each set bit [shift] positions down *)
  let smear acc shift = acc lor (acc lsr shift) in
  let low = List.fold_left smear n [ 1; 2; 4; 8; 16 ] in
  (* the 32-bit step is only applied when [int] is wider than 32 bits *)
  if Sys.int_size > 32 then
    smear low 32
  else
    low
|
||||
|
||||
(** [grow_tls old index] allocates a larger array in which [index] is a
    valid position, copies the contents of [old] to its front, and leaves
    every remaining slot holding the "uninitialized" sentinel. [old] itself
    is not modified. *)
let[@inline never] grow_tls (old : Obj.t array) (index : int) : Obj.t array =
  let old_len = Array.length old in
  (* capacity is rounded up to the form [2^k - 1], which is >= [index + 1] *)
  let capacity = ceil_pow_2_minus_1 (index + 1) in
  let grown = Array.make capacity (sentinel_value_for_uninit_tls_ ()) in
  Array.blit old 0 grown 0 old_len;
  grown
|
||||
|
||||
(* Return the current thread's slot array, making sure [index] is a valid
   position in it. When the thread has no array yet (the stashed field is
   still an immediate value) or the array is too short, a grown copy is
   created with [grow_tls], stored back into the thread, and returned. *)
let[@inline] get_tls_ (index : int) : Obj.t array =
  let thread : thread_internal_state = Obj.magic (Thread.self ()) in
  let stored = thread.tls in
  (* an immediate value means no array has been installed on this thread *)
  let uninit = Obj.is_int stored in
  let current : Obj.t array =
    if uninit then
      [||]
    else
      Obj.magic stored
  in
  if (not uninit) && index < Array.length current then
    current
  else (
    let grown = grow_tls current index in
    thread.tls <- Obj.magic grown;
    grown
  )
|
||||
|
||||
(** [get key] returns the value bound to [key] on the current thread.

    If the slot for [key] still holds the "uninitialized" sentinel,
    [key.compute ()] is called, its result is stored in the slot, and that
    result is returned. Otherwise the stored value is returned as is. *)
let get key =
  let tls = get_tls_ key.index in
  let value = Array.unsafe_get tls key.index in
  (* physical comparison: anything other than the sentinel is a real value *)
  if value != sentinel_value_for_uninit_tls_ () then
    Obj.magic value
  else (
    let value = key.compute () in
    (* [compute] may itself use other keys on this thread, which can make
       [get_tls_] replace the thread's array with a larger copy. Fetch the
       array again so the store lands in the live array and not in the stale
       one captured above; otherwise the value would be lost and [compute]
       would run again on the next [get]. *)
    let tls = get_tls_ key.index in
    (* NOTE(review): [Sys.opaque_identity] presumably hides the static type
       of [value] from the optimizer before the untyped store; confirm. *)
    Array.unsafe_set tls key.index (Obj.repr (Sys.opaque_identity value));
    value
  )
|
||||
|
||||
(* [set key value] overwrites the current thread's slot for [key] with
   [value], growing the thread's slot array first if needed. *)
let set key value =
  let slots = get_tls_ key.index in
  let erased = Obj.repr (Sys.opaque_identity value) in
  Array.unsafe_set slots key.index erased
|
||||
3
src/thread_local_storage.stub.ml
Normal file
3
src/thread_local_storage.stub.ml
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
(* just defer to library *)
|
||||
include Thread_local_storage
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
module WSQ = Ws_deque_
|
||||
module A = Atomic_
|
||||
module TLS = Thread_local_storage
|
||||
include Runner
|
||||
|
||||
let ( let@ ) = ( @@ )
|
||||
|
|
@ -36,25 +37,20 @@ let num_tasks_ (self : state) : int =
|
|||
Array.iter (fun w -> n := !n + WSQ.size w.q) self.workers;
|
||||
!n
|
||||
|
||||
exception Got_worker of worker_state
|
||||
(** TLS, used by worker to store their specific state
|
||||
and be able to retrieve it from tasks when we schedule new
|
||||
sub-tasks. *)
|
||||
let k_worker_state : worker_state option ref TLS.key =
|
||||
TLS.new_key (fun () -> ref None)
|
||||
|
||||
(* FIXME: replace with TLS *)
|
||||
let[@inline] find_current_worker_ (self : state) : worker_state option =
|
||||
let self_id = Thread.id @@ Thread.self () in
|
||||
try
|
||||
(* see if we're in one of the worker threads *)
|
||||
for i = 0 to Array.length self.workers - 1 do
|
||||
let w = self.workers.(i) in
|
||||
if Thread.id w.thread = self_id then raise_notrace (Got_worker w)
|
||||
done;
|
||||
None
|
||||
with Got_worker w -> Some w
|
||||
let[@inline] find_current_worker_ () : worker_state option =
|
||||
!(TLS.get k_worker_state)
|
||||
|
||||
(** Try to wake up a waiter, if there's any. *)
|
||||
let[@inline] try_wake_someone_ (self : state) : unit =
|
||||
if self.n_waiting_nonzero then (
|
||||
Mutex.lock self.mutex;
|
||||
Condition.broadcast self.cond;
|
||||
Condition.signal self.cond;
|
||||
Mutex.unlock self.mutex
|
||||
)
|
||||
|
||||
|
|
@ -71,7 +67,7 @@ let schedule_task_ (self : state) (w : worker_state option) (task : task) : unit
|
|||
(* push into the main queue *)
|
||||
Mutex.lock self.mutex;
|
||||
Queue.push task self.main_q;
|
||||
if self.n_waiting_nonzero then Condition.broadcast self.cond;
|
||||
if self.n_waiting_nonzero then Condition.signal self.cond;
|
||||
Mutex.unlock self.mutex
|
||||
) else
|
||||
(* notify the caller that scheduling tasks is no
|
||||
|
|
@ -87,7 +83,7 @@ let run_task_now_ (self : state) ~runner task : unit =
|
|||
(try
|
||||
(* run [task()] and handle [suspend] in it *)
|
||||
Suspend_.with_suspend task ~run:(fun task' ->
|
||||
let w = find_current_worker_ self in
|
||||
let w = find_current_worker_ () in
|
||||
schedule_task_ self w task')
|
||||
with e ->
|
||||
let bt = Printexc.get_raw_backtrace () in
|
||||
|
|
@ -95,7 +91,7 @@ let run_task_now_ (self : state) ~runner task : unit =
|
|||
after_task runner _ctx
|
||||
|
||||
let[@inline] run_async_ (self : state) (task : task) : unit =
|
||||
let w = find_current_worker_ self in
|
||||
let w = find_current_worker_ () in
|
||||
schedule_task_ self w task
|
||||
|
||||
(* TODO: function to schedule many tasks from the outside.
|
||||
|
|
@ -140,6 +136,7 @@ let try_to_steal_work_loop (self : state) ~runner w : bool =
|
|||
while !n_retries_left > 0 do
|
||||
match try_to_steal_work_once_ self w with
|
||||
| Some task ->
|
||||
try_wake_someone_ self;
|
||||
run_task_now_ self ~runner task;
|
||||
has_stolen := true;
|
||||
n_retries_left := 0
|
||||
|
|
@ -153,12 +150,16 @@ let worker_run_self_tasks_ (self : state) ~runner w : unit =
|
|||
let continue = ref true in
|
||||
while !continue && A.get self.active do
|
||||
match WSQ.pop w.q with
|
||||
| Some task -> run_task_now_ self ~runner task
|
||||
| Some task ->
|
||||
try_wake_someone_ self;
|
||||
run_task_now_ self ~runner task
|
||||
| None -> continue := false
|
||||
done
|
||||
|
||||
(** Main loop for a worker thread. *)
|
||||
let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit =
|
||||
TLS.get k_worker_state := Some w;
|
||||
|
||||
let main_loop () : unit =
|
||||
let continue = ref true in
|
||||
while !continue && A.get self.active do
|
||||
|
|
@ -172,7 +173,7 @@ let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit =
|
|||
Mutex.unlock self.mutex;
|
||||
run_task_now_ self ~runner task
|
||||
| exception Queue.Empty ->
|
||||
wait_ self;
|
||||
if A.get self.active then wait_ self;
|
||||
Mutex.unlock self.mutex
|
||||
)
|
||||
done;
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue