ws_pool: use TLS for quick worker storage access; reduce contention

This commit is contained in:
Simon Cruanes 2023-10-27 15:18:50 -04:00
parent b4ddd82ee8
commit e67bffeca5
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
9 changed files with 130 additions and 18 deletions

View file

@ -28,6 +28,7 @@
(>= 1.9.0)
:with-test)))
(depopts
thread-local-storage
(domain-local-await (>= 0.2)))
(tags
(thread pool domain futures fork-join)))

View file

@ -19,6 +19,7 @@ depends: [
"mdx" {>= "1.9.0" & with-test}
]
depopts: [
"thread-local-storage"
"domain-local-await" {>= "0.2"}
]
build: [

View file

@ -6,6 +6,9 @@
(action
(run %{project_root}/src/cpp/cpp.exe %{input-file})))
(libraries threads either
(select thread_local_storage.ml from
(thread-local-storage -> thread_local_storage.stub.ml)
(-> thread_local_storage.real.ml))
(select dla_.ml from
(domain-local-await -> dla_.real.ml)
( -> dla_.dummy.ml))))

View file

@ -13,6 +13,7 @@ module Pool = Fifo_pool
module Ws_pool = Ws_pool
module Runner = Runner
module Fifo_pool = Fifo_pool
module Thread_local_storage = Thread_local_storage
module Private = struct
module Ws_deque_ = Ws_deque_

View file

@ -26,6 +26,7 @@ module Lock = Lock
module Fut = Fut
module Chan = Chan
module Fork_join = Fork_join
module Thread_local_storage = Thread_local_storage
(** A simple blocking queue.

View file

@ -0,0 +1,21 @@
(** Thread local storage *)
(* TODO: alias this to the library if present *)
type 'a key
(** A TLS key for values of type ['a]. This allows the
storage of a single value of type ['a] per thread. *)
val new_key : (unit -> 'a) -> 'a key
(** [new_key init] allocates a new, generative key.
When the key is used for the first time on a thread,
[init ()] is called to produce the initial value for that thread.
This should only ever be called at toplevel to produce
constants; do not use it in a loop. *)
val get : 'a key -> 'a
(** [get k] returns the value associated with [k] for the current thread.
If the current thread has no value for [k] yet, the initializer given
to {!new_key} is called and its result is stored and returned. *)
val set : 'a key -> 'a -> unit
(** [set k v] sets the value associated with [k] for the current thread
to [v], replacing any previous value. *)

View file

@ -0,0 +1,80 @@
(* see: https://discuss.ocaml.org/t/a-hack-to-implement-efficient-tls-thread-local-storage/13264 *)
(* Sanity check, run once when this module is initialised: the second
   field of the current thread's handle must hold [()], otherwise the
   slot this module relies on for TLS data is not in the expected state. *)
let () =
  let self_repr = Obj.repr (Thread.self ()) in
  assert (Obj.field self_repr 1 = Obj.repr ())
(* A TLS key: a slot index shared by all threads, plus the function used
   to lazily produce a thread's value for that slot. *)
type 'a key = {
index: int; (** Unique index for this key; it is also the position of the
key's value inside each thread's TLS array. *)
compute: unit -> 'a;
(** Initializer for values for this key. Called at most
once per thread. *)
}
(** Next free key index. Every new key atomically takes the current value
    and bumps it, so indices are unique across threads. *)
let counter : int Atomic.t = Atomic.make 0
(** Marker stored in TLS slots that have not been initialized yet.
    It is the [counter] cell itself, viewed as an [Obj.t], and is meant to
    be compared by physical equality. *)
let sentinel_value_for_uninit_tls_ : unit -> Obj.t = fun () -> Obj.repr counter
(* [new_key compute] reserves the next free slot index (atomically) and
   pairs it with [compute], the per-thread initializer for that slot. *)
let new_key compute : _ key =
  { index = Atomic.fetch_and_add counter 1; compute }
(* NOTE: values of this type are never constructed; they are obtained by
   casting a [Thread.t] with [Obj.magic], so the order and number of
   fields below must not be changed. *)
type thread_internal_state = {
_id: int; (** Thread ID (here for padding reasons: it only occupies the
first field so that [tls] lines up with the second one) *)
mutable tls: Obj.t; (** Our data, stowed away in this unused field.
Holds either an immediate value (no TLS array installed yet) or an
[Obj.t array] indexed by key index. *)
}
(** A partial representation of the internal type [Thread.t], allowing
us to access the second field (unused after the thread
has started) and stash TLS data in it. *)
(** [ceil_pow_2_minus_1 n] sets every bit below the highest set bit of [n],
    i.e. for [n > 0] it returns the smallest number of the form [2^k - 1]
    that is [>= n]. [0] maps to [0]. *)
let ceil_pow_2_minus_1 (n : int) : int =
  (* propagate set bits [shift] positions to the right *)
  let smear acc shift = acc lor (acc lsr shift) in
  let low = List.fold_left smear n [ 1; 2; 4; 8; 16 ] in
  (* the last step is only needed when ints are wider than 32 bits *)
  if Sys.int_size > 32 then smear low 32 else low
(** [grow_tls old index] returns a fresh array, large enough for [index]
    to be a valid position, that starts with the contents of [old].
    All the additional slots hold the "uninitialized" sentinel.
    [old] itself is left untouched. *)
let[@inline never] grow_tls (old : Obj.t array) (index : int) : Obj.t array =
  let capacity = ceil_pow_2_minus_1 (index + 1) in
  let grown = Array.make capacity (sentinel_value_for_uninit_tls_ ()) in
  let () = Array.blit old 0 grown 0 (Array.length old) in
  grown
(* [get_tls_ index] returns the current thread's TLS array, guaranteed to
   be long enough for [index]. If the thread has no array yet (its slot
   still holds an immediate value), or if the array is too short, a larger
   array is allocated, installed in the thread's slot, and returned. *)
let[@inline] get_tls_ (index : int) : Obj.t array =
  let thread : thread_internal_state = Obj.magic (Thread.self ()) in
  let slot = thread.tls in
  (* an immediate value means no array was ever installed on this thread *)
  let uninit = Obj.is_int slot in
  let current : Obj.t array =
    if uninit then
      [||]
    else
      Obj.magic slot
  in
  if (not uninit) && index < Array.length current then
    current
  else (
    let grown = grow_tls current index in
    thread.tls <- Obj.repr grown;
    grown
  )
(* [get key] returns the current thread's value for [key].

   If the slot still holds the "uninitialized" sentinel, [key.compute ()]
   is called to produce the value, which is stored in the slot and
   returned. Any exception raised by [key.compute] propagates and leaves
   the slot uninitialized. *)
let get key =
  let tls = get_tls_ key.index in
  let value = Array.unsafe_get tls key.index in
  if value != sentinel_value_for_uninit_tls_ () then
    Obj.magic value
  else (
    let value = key.compute () in
    (* [compute] may itself have used other TLS keys on this thread, which
       can replace the thread's TLS array with a larger one. Fetch the
       array again so the value is stored in the array that is actually
       installed, rather than in a stale copy. *)
    let tls = get_tls_ key.index in
    Array.unsafe_set tls key.index (Obj.repr (Sys.opaque_identity value));
    value
  )
(* [set key value] stores [value] in the current thread's slot for [key],
   growing the thread's TLS array first if needed. Any previous value in
   the slot is overwritten. *)
let set key value =
  let boxed : Obj.t = Obj.repr (Sys.opaque_identity value) in
  let slots = get_tls_ key.index in
  Array.unsafe_set slots key.index boxed

View file

@ -0,0 +1,3 @@
(* Just defer to the library: re-export everything from the external
   [Thread_local_storage] module.
   NOTE(review): presumably this is the variant picked by the dune
   [select] stanza when the [thread-local-storage] package is installed;
   confirm against the build rules. *)
include Thread_local_storage

View file

@ -1,5 +1,6 @@
module WSQ = Ws_deque_
module A = Atomic_
module TLS = Thread_local_storage
include Runner
let ( let@ ) = ( @@ )
@ -36,25 +37,20 @@ let num_tasks_ (self : state) : int =
Array.iter (fun w -> n := !n + WSQ.size w.q) self.workers;
!n
exception Got_worker of worker_state
(** TLS key used by workers to store their own state, so that it can be
retrieved from within tasks when new sub-tasks are scheduled.
Each thread gets its own reference, initially holding [None]. *)
let k_worker_state : worker_state option ref TLS.key =
TLS.new_key (fun () -> ref None)
(* FIXME: replace with TLS *)
let[@inline] find_current_worker_ (self : state) : worker_state option =
let self_id = Thread.id @@ Thread.self () in
try
(* see if we're in one of the worker threads *)
for i = 0 to Array.length self.workers - 1 do
let w = self.workers.(i) in
if Thread.id w.thread = self_id then raise_notrace (Got_worker w)
done;
None
with Got_worker w -> Some w
(* Read the calling thread's entry for [k_worker_state]: the worker state
   stored there, or [None] if nothing was stored on this thread. *)
let[@inline] find_current_worker_ () : worker_state option =
!(TLS.get k_worker_state)
(** Try to wake up a waiter, if there's any. *)
let[@inline] try_wake_someone_ (self : state) : unit =
if self.n_waiting_nonzero then (
Mutex.lock self.mutex;
Condition.broadcast self.cond;
Condition.signal self.cond;
Mutex.unlock self.mutex
)
@ -71,7 +67,7 @@ let schedule_task_ (self : state) (w : worker_state option) (task : task) : unit
(* push into the main queue *)
Mutex.lock self.mutex;
Queue.push task self.main_q;
if self.n_waiting_nonzero then Condition.broadcast self.cond;
if self.n_waiting_nonzero then Condition.signal self.cond;
Mutex.unlock self.mutex
) else
(* notify the caller that scheduling tasks is no
@ -87,7 +83,7 @@ let run_task_now_ (self : state) ~runner task : unit =
(try
(* run [task()] and handle [suspend] in it *)
Suspend_.with_suspend task ~run:(fun task' ->
let w = find_current_worker_ self in
let w = find_current_worker_ () in
schedule_task_ self w task')
with e ->
let bt = Printexc.get_raw_backtrace () in
@ -95,7 +91,7 @@ let run_task_now_ (self : state) ~runner task : unit =
after_task runner _ctx
let[@inline] run_async_ (self : state) (task : task) : unit =
let w = find_current_worker_ self in
let w = find_current_worker_ () in
schedule_task_ self w task
(* TODO: function to schedule many tasks from the outside.
@ -140,6 +136,7 @@ let try_to_steal_work_loop (self : state) ~runner w : bool =
while !n_retries_left > 0 do
match try_to_steal_work_once_ self w with
| Some task ->
try_wake_someone_ self;
run_task_now_ self ~runner task;
has_stolen := true;
n_retries_left := 0
@ -153,12 +150,16 @@ let worker_run_self_tasks_ (self : state) ~runner w : unit =
let continue = ref true in
while !continue && A.get self.active do
match WSQ.pop w.q with
| Some task -> run_task_now_ self ~runner task
| Some task ->
try_wake_someone_ self;
run_task_now_ self ~runner task
| None -> continue := false
done
(** Main loop for a worker thread. *)
let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit =
TLS.get k_worker_state := Some w;
let main_loop () : unit =
let continue = ref true in
while !continue && A.get self.active do
@ -172,7 +173,7 @@ let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit =
Mutex.unlock self.mutex;
run_task_now_ self ~runner task
| exception Queue.Empty ->
wait_ self;
if A.get self.active then wait_ self;
Mutex.unlock self.mutex
)
done;