mirror of
https://github.com/c-cube/moonpool.git
synced 2025-12-06 03:05:30 -05:00
move to thread-local-storage 0.2 with get/set API
This commit is contained in:
parent
3388098fcc
commit
265d4f73dd
10 changed files with 112 additions and 89 deletions
|
|
@ -30,6 +30,7 @@
|
|||
(depopts
|
||||
(trace (>= 0.6))
|
||||
thread-local-storage)
|
||||
(conflicts (thread-local-storage (< 0.2)))
|
||||
(tags
|
||||
(thread pool domain futures fork-join)))
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,9 @@ depopts: [
|
|||
"trace" {>= "0.6"}
|
||||
"thread-local-storage"
|
||||
]
|
||||
conflicts: [
|
||||
"thread-local-storage" {< "0.2"}
|
||||
]
|
||||
build: [
|
||||
["dune" "subst"] {dev}
|
||||
[
|
||||
|
|
|
|||
|
|
@ -31,18 +31,17 @@ let schedule_ (self : state) (task : task_full) : unit =
|
|||
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
|
||||
type worker_state = { mutable cur_ls: Task_local_storage.t option }
|
||||
|
||||
let k_worker_state : worker_state option ref TLS.key =
|
||||
TLS.new_key (fun () -> ref None)
|
||||
let k_worker_state : worker_state TLS.t = TLS.create ()
|
||||
|
||||
let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
|
||||
let w = { cur_ls = None } in
|
||||
TLS.get k_worker_state := Some w;
|
||||
TLS.get Runner.For_runner_implementors.k_cur_runner := Some runner;
|
||||
TLS.set k_worker_state w;
|
||||
TLS.set Runner.For_runner_implementors.k_cur_runner runner;
|
||||
|
||||
let (AT_pair (before_task, after_task)) = around_task in
|
||||
|
||||
let on_suspend () =
|
||||
match !(TLS.get k_worker_state) with
|
||||
match TLS.get_opt k_worker_state with
|
||||
| Some { cur_ls = Some ls; _ } -> ls
|
||||
| _ -> assert false
|
||||
in
|
||||
|
|
@ -55,7 +54,7 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
|
|||
| T_start { ls; _ } | T_resume { ls; _ } -> ls
|
||||
in
|
||||
w.cur_ls <- Some ls;
|
||||
TLS.get k_cur_storage := Some ls;
|
||||
TLS.set k_cur_storage ls;
|
||||
let _ctx = before_task runner in
|
||||
|
||||
(* run the task now, catching errors, handling effects *)
|
||||
|
|
@ -74,7 +73,7 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
|
|||
on_exn e bt);
|
||||
after_task runner _ctx;
|
||||
w.cur_ls <- None;
|
||||
TLS.get k_cur_storage := None
|
||||
TLS.set k_cur_storage _dummy_ls
|
||||
in
|
||||
|
||||
let main_loop () =
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ module For_runner_implementors = struct
|
|||
let create ~size ~num_tasks ~shutdown ~run_async () : t =
|
||||
{ size; num_tasks; shutdown; run_async }
|
||||
|
||||
let k_cur_runner : t option ref TLS.key = Types_.k_cur_runner
|
||||
let k_cur_runner : t TLS.t = Types_.k_cur_runner
|
||||
end
|
||||
|
||||
let dummy : t =
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ module For_runner_implementors : sig
|
|||
{b NOTE}: the runner should support DLA and {!Suspend_} on OCaml 5.x,
|
||||
so that {!Fork_join} and other 5.x features work properly. *)
|
||||
|
||||
val k_cur_runner : t option ref Thread_local_storage_.key
|
||||
val k_cur_runner : t Thread_local_storage_.t
|
||||
(** Key that should be used by each runner to store itself in TLS
|
||||
on every thread it controls, so that tasks running on these threads
|
||||
can access the runner. This is necessary for {!get_current_runner}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ let key_count_ = A.make 0
|
|||
type t = local_storage
|
||||
type ls_value += Dummy
|
||||
|
||||
let dummy : t = ref [||]
|
||||
let dummy : t = _dummy_ls
|
||||
|
||||
(** Resize array of TLS values *)
|
||||
let[@inline never] resize_ (cur : ls_value array ref) n =
|
||||
|
|
@ -57,7 +57,9 @@ let new_key (type t) ~init () : t key =
|
|||
|
||||
let[@inline] get_cur_ () : ls_value array ref =
|
||||
match get_current_storage () with
|
||||
| Some r -> r
|
||||
| Some r ->
|
||||
assert (r != dummy);
|
||||
r
|
||||
| None -> failwith "Task local storage must be accessed from within a runner."
|
||||
|
||||
let[@inline] get (key : 'a key) : 'a =
|
||||
|
|
|
|||
|
|
@ -27,11 +27,9 @@ type runner = {
|
|||
num_tasks: unit -> int;
|
||||
}
|
||||
|
||||
let k_cur_runner : runner option ref TLS.key = TLS.new_key (fun () -> ref None)
|
||||
|
||||
let k_cur_storage : local_storage option ref TLS.key =
|
||||
TLS.new_key (fun () -> ref None)
|
||||
|
||||
let[@inline] get_current_runner () : _ option = !(TLS.get k_cur_runner)
|
||||
let[@inline] get_current_storage () : _ option = !(TLS.get k_cur_storage)
|
||||
let k_cur_runner : runner TLS.t = TLS.create ()
|
||||
let k_cur_storage : local_storage TLS.t = TLS.create ()
|
||||
let _dummy_ls : local_storage = ref [||]
|
||||
let[@inline] get_current_runner () : _ option = TLS.get_opt k_cur_runner
|
||||
let[@inline] get_current_storage () : _ option = TLS.get_opt k_cur_storage
|
||||
let[@inline] create_local_storage () = ref [||]
|
||||
|
|
|
|||
|
|
@ -65,11 +65,10 @@ let num_tasks_ (self : state) : int =
|
|||
(** TLS, used by worker to store their specific state
|
||||
and be able to retrieve it from tasks when we schedule new
|
||||
sub-tasks. *)
|
||||
let k_worker_state : worker_state option ref TLS.key =
|
||||
TLS.new_key (fun () -> ref None)
|
||||
let k_worker_state : worker_state TLS.t = TLS.create ()
|
||||
|
||||
let[@inline] find_current_worker_ () : worker_state option =
|
||||
!(TLS.get k_worker_state)
|
||||
TLS.get_opt k_worker_state
|
||||
|
||||
(** Try to wake up a waiter, if there's any. *)
|
||||
let[@inline] try_wake_someone_ (self : state) : unit =
|
||||
|
|
@ -121,7 +120,7 @@ let run_task_now_ (self : state) ~runner ~(w : worker_state) (task : task_full)
|
|||
in
|
||||
|
||||
w.cur_ls <- Some ls;
|
||||
TLS.get k_cur_storage := Some ls;
|
||||
TLS.set k_cur_storage ls;
|
||||
let _ctx = before_task runner in
|
||||
|
||||
let[@inline] on_suspend () : _ ref =
|
||||
|
|
@ -166,7 +165,7 @@ let run_task_now_ (self : state) ~runner ~(w : worker_state) (task : task_full)
|
|||
|
||||
after_task runner _ctx;
|
||||
w.cur_ls <- None;
|
||||
TLS.get k_cur_storage := None
|
||||
TLS.set k_cur_storage _dummy_ls
|
||||
|
||||
let run_async_ (self : state) ~ls (f : task) : unit =
|
||||
let w = find_current_worker_ () in
|
||||
|
|
@ -222,8 +221,8 @@ let worker_run_self_tasks_ (self : state) ~runner w : unit =
|
|||
|
||||
(** Main loop for a worker thread. *)
|
||||
let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit =
|
||||
TLS.get Runner.For_runner_implementors.k_cur_runner := Some runner;
|
||||
TLS.get k_worker_state := Some w;
|
||||
TLS.set Runner.For_runner_implementors.k_cur_runner runner;
|
||||
TLS.set k_worker_state w;
|
||||
|
||||
let rec main () : unit =
|
||||
worker_run_self_tasks_ self ~runner w;
|
||||
|
|
@ -358,7 +357,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
|
|||
let thread = Thread.self () in
|
||||
let t_id = Thread.id thread in
|
||||
on_init_thread ~dom_id:dom_idx ~t_id ();
|
||||
TLS.get k_cur_storage := None;
|
||||
TLS.set k_cur_storage _dummy_ls;
|
||||
|
||||
(* set thread name *)
|
||||
Option.iter
|
||||
|
|
|
|||
|
|
@ -1,21 +1,13 @@
|
|||
(** Thread local storage *)
|
||||
|
||||
(* TODO: alias this to the library if present *)
|
||||
type 'a t
|
||||
(** A TLS slot for values of type ['a]. This allows the storage of a
|
||||
single value of type ['a] per thread. *)
|
||||
|
||||
type 'a key
|
||||
(** A TLS key for values of type ['a]. This allows the
|
||||
storage of a single value of type ['a] per thread. *)
|
||||
val create : unit -> 'a t
|
||||
|
||||
val new_key : (unit -> 'a) -> 'a key
|
||||
(** Allocate a new, generative key.
|
||||
When the key is used for the first time on a thread,
|
||||
the function is called to produce it.
|
||||
val get : 'a t -> 'a
|
||||
(** @raise Failure if not present *)
|
||||
|
||||
This should only ever be called at toplevel to produce
|
||||
constants, do not use it in a loop. *)
|
||||
|
||||
val get : 'a key -> 'a
|
||||
(** Get the value for the current thread. *)
|
||||
|
||||
val set : 'a key -> 'a -> unit
|
||||
(** Set the value for the current thread. *)
|
||||
val get_opt : 'a t -> 'a option
|
||||
val set : 'a t -> 'a -> unit
|
||||
|
|
|
|||
|
|
@ -1,82 +1,111 @@
|
|||
(* see: https://discuss.ocaml.org/t/a-hack-to-implement-efficient-tls-thread-local-storage/13264 *)
|
||||
|
||||
module A = Atomic_
|
||||
(* vendored from https://github.com/c-cube/thread-local-storage *)
|
||||
|
||||
(* sanity check *)
|
||||
let () = assert (Obj.field (Obj.repr (Thread.self ())) 1 = Obj.repr ())
|
||||
|
||||
type 'a key = {
|
||||
index: int; (** Unique index for this key. *)
|
||||
compute: unit -> 'a;
|
||||
(** Initializer for values for this key. Called at most
|
||||
once per thread. *)
|
||||
}
|
||||
type 'a t = int
|
||||
(** Unique index for this TLS slot. *)
|
||||
|
||||
let tls_length index =
|
||||
let ceil_pow_2_minus_1 (n : int) : int =
|
||||
let n = n lor (n lsr 1) in
|
||||
let n = n lor (n lsr 2) in
|
||||
let n = n lor (n lsr 4) in
|
||||
let n = n lor (n lsr 8) in
|
||||
let n = n lor (n lsr 16) in
|
||||
if Sys.int_size > 32 then
|
||||
n lor (n lsr 32)
|
||||
else
|
||||
n
|
||||
in
|
||||
let size = ceil_pow_2_minus_1 (index + 1) in
|
||||
assert (size > index);
|
||||
size
|
||||
|
||||
(** Counter used to allocate new keys *)
|
||||
let counter = A.make 0
|
||||
let counter = Atomic.make 0
|
||||
|
||||
(** Value used to detect a TLS slot that was not initialized yet *)
|
||||
let[@inline] sentinel_value_for_uninit_tls_ () : Obj.t = Obj.repr counter
|
||||
(** Value used to detect a TLS slot that was not initialized yet.
|
||||
Because [counter] is private and lives forever, no other
|
||||
object the user can see will have the same address. *)
|
||||
let sentinel_value_for_uninit_tls : Obj.t = Obj.repr counter
|
||||
|
||||
let new_key compute : _ key =
|
||||
let index = A.fetch_and_add counter 1 in
|
||||
{ index; compute }
|
||||
external max_wosize : unit -> int = "caml_sys_const_max_wosize"
|
||||
|
||||
let max_word_size = max_wosize ()
|
||||
|
||||
let create () : _ t =
|
||||
let index = Atomic.fetch_and_add counter 1 in
|
||||
if tls_length index <= max_word_size then
|
||||
index
|
||||
else (
|
||||
(* Some platforms have a small max word size. *)
|
||||
ignore (Atomic.fetch_and_add counter (-1));
|
||||
failwith "Thread_local_storage.create: out of TLS slots"
|
||||
)
|
||||
|
||||
type thread_internal_state = {
|
||||
_id: int; (** Thread ID (here for padding reasons) *)
|
||||
mutable tls: Obj.t; (** Our data, stowed away in this unused field *)
|
||||
_other: Obj.t;
|
||||
(** Here to avoid lying to ocamlopt/flambda about the size of [Thread.t] *)
|
||||
}
|
||||
(** A partial representation of the internal type [Thread.t], allowing
|
||||
us to access the second field (unused after the thread
|
||||
has started) and stash TLS data in it. *)
|
||||
|
||||
let ceil_pow_2_minus_1 (n : int) : int =
|
||||
let n = n lor (n lsr 1) in
|
||||
let n = n lor (n lsr 2) in
|
||||
let n = n lor (n lsr 4) in
|
||||
let n = n lor (n lsr 8) in
|
||||
let n = n lor (n lsr 16) in
|
||||
if Sys.int_size > 32 then
|
||||
n lor (n lsr 32)
|
||||
let[@inline] get_raw index : Obj.t =
|
||||
let thread : thread_internal_state = Obj.magic (Thread.self ()) in
|
||||
let tls = thread.tls in
|
||||
if Obj.is_block tls && index < Array.length (Obj.obj tls : Obj.t array) then
|
||||
Array.unsafe_get (Obj.obj tls : Obj.t array) index
|
||||
else
|
||||
n
|
||||
sentinel_value_for_uninit_tls
|
||||
|
||||
let[@inline never] tls_error () =
|
||||
failwith "Thread_local_storage.get: TLS entry not initialised"
|
||||
|
||||
let[@inline] get slot =
|
||||
let v = get_raw slot in
|
||||
if v != sentinel_value_for_uninit_tls then
|
||||
Obj.obj v
|
||||
else
|
||||
tls_error ()
|
||||
|
||||
let[@inline] get_opt slot =
|
||||
let v = get_raw slot in
|
||||
if v != sentinel_value_for_uninit_tls then
|
||||
Some (Obj.obj v)
|
||||
else
|
||||
None
|
||||
|
||||
(** Allocating and setting *)
|
||||
|
||||
(** Grow the array so that [index] is valid. *)
|
||||
let[@inline never] grow_tls (old : Obj.t array) (index : int) : Obj.t array =
|
||||
let new_length = ceil_pow_2_minus_1 (index + 1) in
|
||||
let new_ = Array.make new_length (sentinel_value_for_uninit_tls_ ()) in
|
||||
let grow (old : Obj.t array) (index : int) : Obj.t array =
|
||||
let new_length = tls_length index in
|
||||
let new_ = Array.make new_length sentinel_value_for_uninit_tls in
|
||||
Array.blit old 0 new_ 0 (Array.length old);
|
||||
new_
|
||||
|
||||
let[@inline] get_tls_ (index : int) : Obj.t array =
|
||||
let get_tls_with_capacity index : Obj.t array =
|
||||
let thread : thread_internal_state = Obj.magic (Thread.self ()) in
|
||||
let tls = thread.tls in
|
||||
if Obj.is_int tls then (
|
||||
let new_tls = grow_tls [||] index in
|
||||
thread.tls <- Obj.magic new_tls;
|
||||
let new_tls = grow [||] index in
|
||||
thread.tls <- Obj.repr new_tls;
|
||||
new_tls
|
||||
) else (
|
||||
let tls = (Obj.magic tls : Obj.t array) in
|
||||
let tls = (Obj.obj tls : Obj.t array) in
|
||||
if index < Array.length tls then
|
||||
tls
|
||||
else (
|
||||
let new_tls = grow_tls tls index in
|
||||
thread.tls <- Obj.magic new_tls;
|
||||
let new_tls = grow tls index in
|
||||
thread.tls <- Obj.repr new_tls;
|
||||
new_tls
|
||||
)
|
||||
)
|
||||
|
||||
let get key =
|
||||
let tls = get_tls_ key.index in
|
||||
let value = Array.unsafe_get tls key.index in
|
||||
if value != sentinel_value_for_uninit_tls_ () then
|
||||
Obj.magic value
|
||||
else (
|
||||
let value = key.compute () in
|
||||
Array.unsafe_set tls key.index (Obj.repr (Sys.opaque_identity value));
|
||||
value
|
||||
)
|
||||
|
||||
let set key value =
|
||||
let tls = get_tls_ key.index in
|
||||
Array.unsafe_set tls key.index (Obj.repr (Sys.opaque_identity value))
|
||||
let[@inline] set slot value : unit =
|
||||
let tls = get_tls_with_capacity slot in
|
||||
Array.unsafe_set tls slot (Obj.repr (Sys.opaque_identity value))
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue