mirror of
https://github.com/c-cube/moonpool.git
synced 2025-12-06 03:05:30 -05:00
move to thread-local-storage 0.2 with get/set API
This commit is contained in:
parent
3388098fcc
commit
265d4f73dd
10 changed files with 112 additions and 89 deletions
|
|
@ -30,6 +30,7 @@
|
||||||
(depopts
|
(depopts
|
||||||
(trace (>= 0.6))
|
(trace (>= 0.6))
|
||||||
thread-local-storage)
|
thread-local-storage)
|
||||||
|
(conflicts (thread-local-storage (< 0.2)))
|
||||||
(tags
|
(tags
|
||||||
(thread pool domain futures fork-join)))
|
(thread pool domain futures fork-join)))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,9 @@ depopts: [
|
||||||
"trace" {>= "0.6"}
|
"trace" {>= "0.6"}
|
||||||
"thread-local-storage"
|
"thread-local-storage"
|
||||||
]
|
]
|
||||||
|
conflicts: [
|
||||||
|
"thread-local-storage" {< "0.2"}
|
||||||
|
]
|
||||||
build: [
|
build: [
|
||||||
["dune" "subst"] {dev}
|
["dune" "subst"] {dev}
|
||||||
[
|
[
|
||||||
|
|
|
||||||
|
|
@ -31,18 +31,17 @@ let schedule_ (self : state) (task : task_full) : unit =
|
||||||
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
|
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
|
||||||
type worker_state = { mutable cur_ls: Task_local_storage.t option }
|
type worker_state = { mutable cur_ls: Task_local_storage.t option }
|
||||||
|
|
||||||
let k_worker_state : worker_state option ref TLS.key =
|
let k_worker_state : worker_state TLS.t = TLS.create ()
|
||||||
TLS.new_key (fun () -> ref None)
|
|
||||||
|
|
||||||
let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
|
let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
|
||||||
let w = { cur_ls = None } in
|
let w = { cur_ls = None } in
|
||||||
TLS.get k_worker_state := Some w;
|
TLS.set k_worker_state w;
|
||||||
TLS.get Runner.For_runner_implementors.k_cur_runner := Some runner;
|
TLS.set Runner.For_runner_implementors.k_cur_runner runner;
|
||||||
|
|
||||||
let (AT_pair (before_task, after_task)) = around_task in
|
let (AT_pair (before_task, after_task)) = around_task in
|
||||||
|
|
||||||
let on_suspend () =
|
let on_suspend () =
|
||||||
match !(TLS.get k_worker_state) with
|
match TLS.get_opt k_worker_state with
|
||||||
| Some { cur_ls = Some ls; _ } -> ls
|
| Some { cur_ls = Some ls; _ } -> ls
|
||||||
| _ -> assert false
|
| _ -> assert false
|
||||||
in
|
in
|
||||||
|
|
@ -55,7 +54,7 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
|
||||||
| T_start { ls; _ } | T_resume { ls; _ } -> ls
|
| T_start { ls; _ } | T_resume { ls; _ } -> ls
|
||||||
in
|
in
|
||||||
w.cur_ls <- Some ls;
|
w.cur_ls <- Some ls;
|
||||||
TLS.get k_cur_storage := Some ls;
|
TLS.set k_cur_storage ls;
|
||||||
let _ctx = before_task runner in
|
let _ctx = before_task runner in
|
||||||
|
|
||||||
(* run the task now, catching errors, handling effects *)
|
(* run the task now, catching errors, handling effects *)
|
||||||
|
|
@ -74,7 +73,7 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
|
||||||
on_exn e bt);
|
on_exn e bt);
|
||||||
after_task runner _ctx;
|
after_task runner _ctx;
|
||||||
w.cur_ls <- None;
|
w.cur_ls <- None;
|
||||||
TLS.get k_cur_storage := None
|
TLS.set k_cur_storage _dummy_ls
|
||||||
in
|
in
|
||||||
|
|
||||||
let main_loop () =
|
let main_loop () =
|
||||||
|
|
|
||||||
|
|
@ -40,7 +40,7 @@ module For_runner_implementors = struct
|
||||||
let create ~size ~num_tasks ~shutdown ~run_async () : t =
|
let create ~size ~num_tasks ~shutdown ~run_async () : t =
|
||||||
{ size; num_tasks; shutdown; run_async }
|
{ size; num_tasks; shutdown; run_async }
|
||||||
|
|
||||||
let k_cur_runner : t option ref TLS.key = Types_.k_cur_runner
|
let k_cur_runner : t TLS.t = Types_.k_cur_runner
|
||||||
end
|
end
|
||||||
|
|
||||||
let dummy : t =
|
let dummy : t =
|
||||||
|
|
|
||||||
|
|
@ -73,7 +73,7 @@ module For_runner_implementors : sig
|
||||||
{b NOTE}: the runner should support DLA and {!Suspend_} on OCaml 5.x,
|
{b NOTE}: the runner should support DLA and {!Suspend_} on OCaml 5.x,
|
||||||
so that {!Fork_join} and other 5.x features work properly. *)
|
so that {!Fork_join} and other 5.x features work properly. *)
|
||||||
|
|
||||||
val k_cur_runner : t option ref Thread_local_storage_.key
|
val k_cur_runner : t Thread_local_storage_.t
|
||||||
(** Key that should be used by each runner to store itself in TLS
|
(** Key that should be used by each runner to store itself in TLS
|
||||||
on every thread it controls, so that tasks running on these threads
|
on every thread it controls, so that tasks running on these threads
|
||||||
can access the runner. This is necessary for {!get_current_runner}
|
can access the runner. This is necessary for {!get_current_runner}
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ let key_count_ = A.make 0
|
||||||
type t = local_storage
|
type t = local_storage
|
||||||
type ls_value += Dummy
|
type ls_value += Dummy
|
||||||
|
|
||||||
let dummy : t = ref [||]
|
let dummy : t = _dummy_ls
|
||||||
|
|
||||||
(** Resize array of TLS values *)
|
(** Resize array of TLS values *)
|
||||||
let[@inline never] resize_ (cur : ls_value array ref) n =
|
let[@inline never] resize_ (cur : ls_value array ref) n =
|
||||||
|
|
@ -57,7 +57,9 @@ let new_key (type t) ~init () : t key =
|
||||||
|
|
||||||
let[@inline] get_cur_ () : ls_value array ref =
|
let[@inline] get_cur_ () : ls_value array ref =
|
||||||
match get_current_storage () with
|
match get_current_storage () with
|
||||||
| Some r -> r
|
| Some r ->
|
||||||
|
assert (r != dummy);
|
||||||
|
r
|
||||||
| None -> failwith "Task local storage must be accessed from within a runner."
|
| None -> failwith "Task local storage must be accessed from within a runner."
|
||||||
|
|
||||||
let[@inline] get (key : 'a key) : 'a =
|
let[@inline] get (key : 'a key) : 'a =
|
||||||
|
|
|
||||||
|
|
@ -27,11 +27,9 @@ type runner = {
|
||||||
num_tasks: unit -> int;
|
num_tasks: unit -> int;
|
||||||
}
|
}
|
||||||
|
|
||||||
let k_cur_runner : runner option ref TLS.key = TLS.new_key (fun () -> ref None)
|
let k_cur_runner : runner TLS.t = TLS.create ()
|
||||||
|
let k_cur_storage : local_storage TLS.t = TLS.create ()
|
||||||
let k_cur_storage : local_storage option ref TLS.key =
|
let _dummy_ls : local_storage = ref [||]
|
||||||
TLS.new_key (fun () -> ref None)
|
let[@inline] get_current_runner () : _ option = TLS.get_opt k_cur_runner
|
||||||
|
let[@inline] get_current_storage () : _ option = TLS.get_opt k_cur_storage
|
||||||
let[@inline] get_current_runner () : _ option = !(TLS.get k_cur_runner)
|
|
||||||
let[@inline] get_current_storage () : _ option = !(TLS.get k_cur_storage)
|
|
||||||
let[@inline] create_local_storage () = ref [||]
|
let[@inline] create_local_storage () = ref [||]
|
||||||
|
|
|
||||||
|
|
@ -65,11 +65,10 @@ let num_tasks_ (self : state) : int =
|
||||||
(** TLS, used by worker to store their specific state
|
(** TLS, used by worker to store their specific state
|
||||||
and be able to retrieve it from tasks when we schedule new
|
and be able to retrieve it from tasks when we schedule new
|
||||||
sub-tasks. *)
|
sub-tasks. *)
|
||||||
let k_worker_state : worker_state option ref TLS.key =
|
let k_worker_state : worker_state TLS.t = TLS.create ()
|
||||||
TLS.new_key (fun () -> ref None)
|
|
||||||
|
|
||||||
let[@inline] find_current_worker_ () : worker_state option =
|
let[@inline] find_current_worker_ () : worker_state option =
|
||||||
!(TLS.get k_worker_state)
|
TLS.get_opt k_worker_state
|
||||||
|
|
||||||
(** Try to wake up a waiter, if there's any. *)
|
(** Try to wake up a waiter, if there's any. *)
|
||||||
let[@inline] try_wake_someone_ (self : state) : unit =
|
let[@inline] try_wake_someone_ (self : state) : unit =
|
||||||
|
|
@ -121,7 +120,7 @@ let run_task_now_ (self : state) ~runner ~(w : worker_state) (task : task_full)
|
||||||
in
|
in
|
||||||
|
|
||||||
w.cur_ls <- Some ls;
|
w.cur_ls <- Some ls;
|
||||||
TLS.get k_cur_storage := Some ls;
|
TLS.set k_cur_storage ls;
|
||||||
let _ctx = before_task runner in
|
let _ctx = before_task runner in
|
||||||
|
|
||||||
let[@inline] on_suspend () : _ ref =
|
let[@inline] on_suspend () : _ ref =
|
||||||
|
|
@ -166,7 +165,7 @@ let run_task_now_ (self : state) ~runner ~(w : worker_state) (task : task_full)
|
||||||
|
|
||||||
after_task runner _ctx;
|
after_task runner _ctx;
|
||||||
w.cur_ls <- None;
|
w.cur_ls <- None;
|
||||||
TLS.get k_cur_storage := None
|
TLS.set k_cur_storage _dummy_ls
|
||||||
|
|
||||||
let run_async_ (self : state) ~ls (f : task) : unit =
|
let run_async_ (self : state) ~ls (f : task) : unit =
|
||||||
let w = find_current_worker_ () in
|
let w = find_current_worker_ () in
|
||||||
|
|
@ -222,8 +221,8 @@ let worker_run_self_tasks_ (self : state) ~runner w : unit =
|
||||||
|
|
||||||
(** Main loop for a worker thread. *)
|
(** Main loop for a worker thread. *)
|
||||||
let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit =
|
let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit =
|
||||||
TLS.get Runner.For_runner_implementors.k_cur_runner := Some runner;
|
TLS.set Runner.For_runner_implementors.k_cur_runner runner;
|
||||||
TLS.get k_worker_state := Some w;
|
TLS.set k_worker_state w;
|
||||||
|
|
||||||
let rec main () : unit =
|
let rec main () : unit =
|
||||||
worker_run_self_tasks_ self ~runner w;
|
worker_run_self_tasks_ self ~runner w;
|
||||||
|
|
@ -358,7 +357,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
|
||||||
let thread = Thread.self () in
|
let thread = Thread.self () in
|
||||||
let t_id = Thread.id thread in
|
let t_id = Thread.id thread in
|
||||||
on_init_thread ~dom_id:dom_idx ~t_id ();
|
on_init_thread ~dom_id:dom_idx ~t_id ();
|
||||||
TLS.get k_cur_storage := None;
|
TLS.set k_cur_storage _dummy_ls;
|
||||||
|
|
||||||
(* set thread name *)
|
(* set thread name *)
|
||||||
Option.iter
|
Option.iter
|
||||||
|
|
|
||||||
|
|
@ -1,21 +1,13 @@
|
||||||
(** Thread local storage *)
|
(** Thread local storage *)
|
||||||
|
|
||||||
(* TODO: alias this to the library if present *)
|
type 'a t
|
||||||
|
(** A TLS slot for values of type ['a]. This allows the storage of a
|
||||||
|
single value of type ['a] per thread. *)
|
||||||
|
|
||||||
type 'a key
|
val create : unit -> 'a t
|
||||||
(** A TLS key for values of type ['a]. This allows the
|
|
||||||
storage of a single value of type ['a] per thread. *)
|
|
||||||
|
|
||||||
val new_key : (unit -> 'a) -> 'a key
|
val get : 'a t -> 'a
|
||||||
(** Allocate a new, generative key.
|
(** @raise Failure if not present *)
|
||||||
When the key is used for the first time on a thread,
|
|
||||||
the function is called to produce it.
|
|
||||||
|
|
||||||
This should only ever be called at toplevel to produce
|
val get_opt : 'a t -> 'a option
|
||||||
constants, do not use it in a loop. *)
|
val set : 'a t -> 'a -> unit
|
||||||
|
|
||||||
val get : 'a key -> 'a
|
|
||||||
(** Get the value for the current thread. *)
|
|
||||||
|
|
||||||
val set : 'a key -> 'a -> unit
|
|
||||||
(** Set the value for the current thread. *)
|
|
||||||
|
|
|
||||||
|
|
@ -1,82 +1,111 @@
|
||||||
(* see: https://discuss.ocaml.org/t/a-hack-to-implement-efficient-tls-thread-local-storage/13264 *)
|
(* vendored from https://github.com/c-cube/thread-local-storage *)
|
||||||
|
|
||||||
module A = Atomic_
|
|
||||||
|
|
||||||
(* sanity check *)
|
(* sanity check *)
|
||||||
let () = assert (Obj.field (Obj.repr (Thread.self ())) 1 = Obj.repr ())
|
let () = assert (Obj.field (Obj.repr (Thread.self ())) 1 = Obj.repr ())
|
||||||
|
|
||||||
type 'a key = {
|
type 'a t = int
|
||||||
index: int; (** Unique index for this key. *)
|
(** Unique index for this TLS slot. *)
|
||||||
compute: unit -> 'a;
|
|
||||||
(** Initializer for values for this key. Called at most
|
let tls_length index =
|
||||||
once per thread. *)
|
let ceil_pow_2_minus_1 (n : int) : int =
|
||||||
}
|
let n = n lor (n lsr 1) in
|
||||||
|
let n = n lor (n lsr 2) in
|
||||||
|
let n = n lor (n lsr 4) in
|
||||||
|
let n = n lor (n lsr 8) in
|
||||||
|
let n = n lor (n lsr 16) in
|
||||||
|
if Sys.int_size > 32 then
|
||||||
|
n lor (n lsr 32)
|
||||||
|
else
|
||||||
|
n
|
||||||
|
in
|
||||||
|
let size = ceil_pow_2_minus_1 (index + 1) in
|
||||||
|
assert (size > index);
|
||||||
|
size
|
||||||
|
|
||||||
(** Counter used to allocate new keys *)
|
(** Counter used to allocate new keys *)
|
||||||
let counter = A.make 0
|
let counter = Atomic.make 0
|
||||||
|
|
||||||
(** Value used to detect a TLS slot that was not initialized yet *)
|
(** Value used to detect a TLS slot that was not initialized yet.
|
||||||
let[@inline] sentinel_value_for_uninit_tls_ () : Obj.t = Obj.repr counter
|
Because [counter] is private and lives forever, no other
|
||||||
|
object the user can see will have the same address. *)
|
||||||
|
let sentinel_value_for_uninit_tls : Obj.t = Obj.repr counter
|
||||||
|
|
||||||
let new_key compute : _ key =
|
external max_wosize : unit -> int = "caml_sys_const_max_wosize"
|
||||||
let index = A.fetch_and_add counter 1 in
|
|
||||||
{ index; compute }
|
let max_word_size = max_wosize ()
|
||||||
|
|
||||||
|
let create () : _ t =
|
||||||
|
let index = Atomic.fetch_and_add counter 1 in
|
||||||
|
if tls_length index <= max_word_size then
|
||||||
|
index
|
||||||
|
else (
|
||||||
|
(* Some platforms have a small max word size. *)
|
||||||
|
ignore (Atomic.fetch_and_add counter (-1));
|
||||||
|
failwith "Thread_local_storage.create: out of TLS slots"
|
||||||
|
)
|
||||||
|
|
||||||
type thread_internal_state = {
|
type thread_internal_state = {
|
||||||
_id: int; (** Thread ID (here for padding reasons) *)
|
_id: int; (** Thread ID (here for padding reasons) *)
|
||||||
mutable tls: Obj.t; (** Our data, stowed away in this unused field *)
|
mutable tls: Obj.t; (** Our data, stowed away in this unused field *)
|
||||||
|
_other: Obj.t;
|
||||||
|
(** Here to avoid lying to ocamlopt/flambda about the size of [Thread.t] *)
|
||||||
}
|
}
|
||||||
(** A partial representation of the internal type [Thread.t], allowing
|
(** A partial representation of the internal type [Thread.t], allowing
|
||||||
us to access the second field (unused after the thread
|
us to access the second field (unused after the thread
|
||||||
has started) and stash TLS data in it. *)
|
has started) and stash TLS data in it. *)
|
||||||
|
|
||||||
let ceil_pow_2_minus_1 (n : int) : int =
|
let[@inline] get_raw index : Obj.t =
|
||||||
let n = n lor (n lsr 1) in
|
let thread : thread_internal_state = Obj.magic (Thread.self ()) in
|
||||||
let n = n lor (n lsr 2) in
|
let tls = thread.tls in
|
||||||
let n = n lor (n lsr 4) in
|
if Obj.is_block tls && index < Array.length (Obj.obj tls : Obj.t array) then
|
||||||
let n = n lor (n lsr 8) in
|
Array.unsafe_get (Obj.obj tls : Obj.t array) index
|
||||||
let n = n lor (n lsr 16) in
|
|
||||||
if Sys.int_size > 32 then
|
|
||||||
n lor (n lsr 32)
|
|
||||||
else
|
else
|
||||||
n
|
sentinel_value_for_uninit_tls
|
||||||
|
|
||||||
|
let[@inline never] tls_error () =
|
||||||
|
failwith "Thread_local_storage.get: TLS entry not initialised"
|
||||||
|
|
||||||
|
let[@inline] get slot =
|
||||||
|
let v = get_raw slot in
|
||||||
|
if v != sentinel_value_for_uninit_tls then
|
||||||
|
Obj.obj v
|
||||||
|
else
|
||||||
|
tls_error ()
|
||||||
|
|
||||||
|
let[@inline] get_opt slot =
|
||||||
|
let v = get_raw slot in
|
||||||
|
if v != sentinel_value_for_uninit_tls then
|
||||||
|
Some (Obj.obj v)
|
||||||
|
else
|
||||||
|
None
|
||||||
|
|
||||||
|
(** Allocating and setting *)
|
||||||
|
|
||||||
(** Grow the array so that [index] is valid. *)
|
(** Grow the array so that [index] is valid. *)
|
||||||
let[@inline never] grow_tls (old : Obj.t array) (index : int) : Obj.t array =
|
let grow (old : Obj.t array) (index : int) : Obj.t array =
|
||||||
let new_length = ceil_pow_2_minus_1 (index + 1) in
|
let new_length = tls_length index in
|
||||||
let new_ = Array.make new_length (sentinel_value_for_uninit_tls_ ()) in
|
let new_ = Array.make new_length sentinel_value_for_uninit_tls in
|
||||||
Array.blit old 0 new_ 0 (Array.length old);
|
Array.blit old 0 new_ 0 (Array.length old);
|
||||||
new_
|
new_
|
||||||
|
|
||||||
let[@inline] get_tls_ (index : int) : Obj.t array =
|
let get_tls_with_capacity index : Obj.t array =
|
||||||
let thread : thread_internal_state = Obj.magic (Thread.self ()) in
|
let thread : thread_internal_state = Obj.magic (Thread.self ()) in
|
||||||
let tls = thread.tls in
|
let tls = thread.tls in
|
||||||
if Obj.is_int tls then (
|
if Obj.is_int tls then (
|
||||||
let new_tls = grow_tls [||] index in
|
let new_tls = grow [||] index in
|
||||||
thread.tls <- Obj.magic new_tls;
|
thread.tls <- Obj.repr new_tls;
|
||||||
new_tls
|
new_tls
|
||||||
) else (
|
) else (
|
||||||
let tls = (Obj.magic tls : Obj.t array) in
|
let tls = (Obj.obj tls : Obj.t array) in
|
||||||
if index < Array.length tls then
|
if index < Array.length tls then
|
||||||
tls
|
tls
|
||||||
else (
|
else (
|
||||||
let new_tls = grow_tls tls index in
|
let new_tls = grow tls index in
|
||||||
thread.tls <- Obj.magic new_tls;
|
thread.tls <- Obj.repr new_tls;
|
||||||
new_tls
|
new_tls
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
let get key =
|
let[@inline] set slot value : unit =
|
||||||
let tls = get_tls_ key.index in
|
let tls = get_tls_with_capacity slot in
|
||||||
let value = Array.unsafe_get tls key.index in
|
Array.unsafe_set tls slot (Obj.repr (Sys.opaque_identity value))
|
||||||
if value != sentinel_value_for_uninit_tls_ () then
|
|
||||||
Obj.magic value
|
|
||||||
else (
|
|
||||||
let value = key.compute () in
|
|
||||||
Array.unsafe_set tls key.index (Obj.repr (Sys.opaque_identity value));
|
|
||||||
value
|
|
||||||
)
|
|
||||||
|
|
||||||
let set key value =
|
|
||||||
let tls = get_tls_ key.index in
|
|
||||||
Array.unsafe_set tls key.index (Obj.repr (Sys.opaque_identity value))
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue