diff --git a/src/core/fifo_pool.ml b/src/core/fifo_pool.ml index 7646ca11..0023c1a8 100644 --- a/src/core/fifo_pool.ml +++ b/src/core/fifo_pool.ml @@ -2,15 +2,14 @@ open Types_ include Runner let ( let@ ) = ( @@ ) -let k_storage = Task_local_storage.Private_.Storage.k_storage type task_full = | T_start of { - ls: Task_local_storage.storage ref; + ls: Task_local_storage.t; f: task; } | T_resume : { - ls: Task_local_storage.storage ref; + ls: Task_local_storage.t; k: 'a -> unit; x: 'a; } @@ -30,7 +29,7 @@ let schedule_ (self : state) (task : task_full) : unit = try Bb_queue.push self.q task with Bb_queue.Closed -> raise Shutdown type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task -type worker_state = { mutable cur_ls: Task_local_storage.storage ref option } +type worker_state = { mutable cur_ls: Task_local_storage.t option } let k_worker_state : worker_state option ref TLS.key = TLS.new_key (fun () -> ref None) @@ -56,7 +55,7 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit = | T_start { ls; _ } | T_resume { ls; _ } -> ls in w.cur_ls <- Some ls; - TLS.set k_storage (Some ls); + TLS.get k_cur_storage := Some ls; let _ctx = before_task runner in (* run the task now, catching errors, handling effects *) @@ -75,7 +74,7 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit = on_exn e bt); after_task runner _ctx; w.cur_ls <- None; - TLS.set k_storage None + TLS.get k_cur_storage := None in let main_loop () = diff --git a/src/core/immediate_runner.ml b/src/core/immediate_runner.ml index 4e15c434..56a2cbee 100644 --- a/src/core/immediate_runner.ml +++ b/src/core/immediate_runner.ml @@ -1,18 +1,15 @@ open Types_ include Runner -(* convenient alias *) -let k_ls = Task_local_storage.Private_.Storage.k_storage - let run_async_ ~ls:cur_ls f = - TLS.set k_ls (Some cur_ls); + TLS.get k_cur_storage := Some cur_ls; try let x = f () in - TLS.set k_ls None; + TLS.get k_cur_storage := None; x with e -> let bt = Printexc.get_raw_backtrace () in - TLS.set k_ls None; + TLS.get k_cur_storage := None; Printexc.raise_with_backtrace e bt let runner : t = diff --git a/src/core/moonpool.mli b/src/core/moonpool.mli index df09d409..d78f12b6 100644 --- a/src/core/moonpool.mli +++ b/src/core/moonpool.mli @@ -26,15 +26,13 @@ val start_thread_on_some_domain : ('a -> unit) -> 'a -> Thread.t to run the thread. This ensures that we don't always pick the same domain to run all the various threads needed in an application (timers, event loops, etc.) *) -val run_async : - ?ls:Task_local_storage.storage ref -> Runner.t -> (unit -> unit) -> unit +val run_async : ?ls:Task_local_storage.t -> Runner.t -> (unit -> unit) -> unit (** [run_async runner task] schedules the task to run on the given runner. This means [task()] will be executed at some point in the future, possibly in another thread. @since 0.5 *) -val run_wait_block : - ?ls:Task_local_storage.storage ref -> Runner.t -> (unit -> 'a) -> 'a +val run_wait_block : ?ls:Task_local_storage.t -> Runner.t -> (unit -> 'a) -> 'a (** [run_wait_block runner f] schedules [f] for later execution on the runner, like {!run_async}. It then blocks the current thread until [f()] is done executing, diff --git a/src/core/runner.ml b/src/core/runner.ml index 391fdcd9..cf4deb9a 100644 --- a/src/core/runner.ml +++ b/src/core/runner.ml @@ -1,9 +1,10 @@ +open Types_ module TLS = Thread_local_storage_ type task = unit -> unit -type t = { - run_async: ls:Task_local_storage.storage ref -> task -> unit; +type t = runner = { + run_async: ls:local_storage -> task -> unit; shutdown: wait:bool -> unit -> unit; size: unit -> int; num_tasks: unit -> int; @@ -11,9 +12,7 @@ type t = { exception Shutdown -let[@inline] run_async - ?(ls = ref @@ Task_local_storage.Private_.Storage.create ()) (self : t) f : - unit = +let[@inline] run_async ?(ls = create_local_storage ()) (self : t) f : unit = self.run_async ~ls f let[@inline] shutdown (self : t) : unit = self.shutdown ~wait:true () @@ -41,8 +40,8 @@ module For_runner_implementors = struct let create ~size ~num_tasks ~shutdown ~run_async () : t = { size; num_tasks; shutdown; run_async } - let k_cur_runner : t option ref TLS.key = TLS.new_key (fun () -> ref None) + let k_cur_runner : t option ref TLS.key = Types_.k_cur_runner end -let[@inline] get_current_runner () : _ option = - !(TLS.get For_runner_implementors.k_cur_runner) +let get_current_runner = get_current_runner +let get_current_storage = get_current_storage diff --git a/src/core/runner.mli b/src/core/runner.mli index 577a4b39..4b43bb1c 100644 --- a/src/core/runner.mli +++ b/src/core/runner.mli @@ -33,15 +33,14 @@ val shutdown_without_waiting : t -> unit exception Shutdown -val run_async : ?ls:Task_local_storage.storage ref -> t -> task -> unit +val run_async : ?ls:Task_local_storage.t -> t -> task -> unit (** [run_async pool f] schedules [f] for later execution on the runner in one of the threads. [f()] will run on one of the runner's worker threads/domains. @param ls if provided, run the task with this initial local storage @raise Shutdown if the runner was shut down before [run_async] was called. *) -val run_wait_block : - ?ls:Task_local_storage.storage ref -> t -> (unit -> 'a) -> 'a +val run_wait_block : ?ls:Task_local_storage.t -> t -> (unit -> 'a) -> 'a (** [run_wait_block pool f] schedules [f] for later execution on the pool, like {!run_async}. It then blocks the current thread until [f()] is done executing, @@ -61,7 +60,7 @@ module For_runner_implementors : sig size:(unit -> int) -> num_tasks:(unit -> int) -> shutdown:(wait:bool -> unit -> unit) -> - run_async:(ls:Task_local_storage.storage ref -> task -> unit) -> + run_async:(ls:Task_local_storage.t -> task -> unit) -> unit -> t (** Create a new runner. @@ -80,3 +79,7 @@ val get_current_runner : unit -> t option (** Access the current runner. This returns [Some r] if the call happens on a thread that belongs in a runner. @since 0.5 *) + +val get_current_storage : unit -> Task_local_storage.t option +(** [get_current_storage runner] gets the local storage + for the currently running task. *) diff --git a/src/core/task_local_storage.ml b/src/core/task_local_storage.ml index 28cc3be7..1491b270 100644 --- a/src/core/task_local_storage.ml +++ b/src/core/task_local_storage.ml @@ -5,7 +5,44 @@ type 'a key = 'a ls_key let key_count_ = A.make 0 -type storage = task_ls +type t = local_storage +type ls_value += Dummy + +let dummy : t = ref [||] + +(** Resize array of TLS values *) +let[@inline never] resize_ (cur : ls_value array ref) n = + if n > Sys.max_array_length then failwith "too many task local storage keys"; + let len = Array.length !cur in + let new_ls = + Array.make (min Sys.max_array_length (max n ((len * 2) + 2))) Dummy + in + Array.blit !cur 0 new_ls 0 len; + cur := new_ls + +module Direct = struct + type nonrec t = t + + let create = create_local_storage + let[@inline] copy (self : t) = ref (Array.copy !self) + + let get (type a) (self : t) ((module K) : a key) : a = + if K.offset >= Array.length !self then resize_ self (K.offset + 1); + match !self.(K.offset) with + | K.V x -> (* common case first *) x + | Dummy -> + (* first time we access this *) + let v = K.init () in + !self.(K.offset) <- K.V v; + v + | _ -> assert false + + let set (type a) (self : t) ((module K) : a key) (v : a) : unit = + assert (self != dummy); + if K.offset >= Array.length !self then resize_ self (K.offset + 1); + !self.(K.offset) <- K.V v; + () +end let new_key (type t) ~init () : t key = let offset = A.fetch_and_add key_count_ 1 in @@ -18,68 +55,25 @@ let new_key (type t) ~init () : t key = end : LS_KEY with type t = t) -type ls_value += Dummy - -(** Resize array of TLS values *) -let[@inline never] resize_ (cur : ls_value array ref) n = - if n > Sys.max_array_length then failwith "too many task local storage keys"; - let len = Array.length !cur in - let new_ls = - Array.make (min Sys.max_array_length (max n ((len * 2) + 2))) Dummy - in - Array.blit !cur 0 new_ls 0 len; - cur := new_ls - let[@inline] get_cur_ () : ls_value array ref = - match TLS.get k_ls_values with + match get_current_storage () with | Some r -> r | None -> failwith "Task local storage must be accessed from within a runner." -let get_from_ (type a) cur ((module K) : a key) : a = - if K.offset >= Array.length !cur then resize_ cur (K.offset + 1); - match !cur.(K.offset) with - | K.V x -> (* common case first *) x - | Dummy -> - (* first time we access this *) - let v = K.init () in - !cur.(K.offset) <- K.V v; - v - | _ -> assert false - let[@inline] get (key : 'a key) : 'a = let cur = get_cur_ () in - get_from_ cur key + Direct.get cur key let[@inline] get_opt key = - match TLS.get k_ls_values with + match get_current_storage () with | None -> None - | Some cur -> Some (get_from_ cur key) - -let set_into_ (type a) cur ((module K) : a key) (v : a) : unit = - if K.offset >= Array.length !cur then resize_ cur (K.offset + 1); - !cur.(K.offset) <- K.V v; - () + | Some cur -> Some (Direct.get cur key) let[@inline] set key v : unit = let cur = get_cur_ () in - set_into_ cur key v + Direct.set cur key v let with_value key x f = let old = get key in set key x; Fun.protect ~finally:(fun () -> set key old) f - -module Private_ = struct - module Storage = struct - type t = storage - - let k_storage = k_ls_values - let[@inline] create () = [||] - let[@inline] get_cur_opt () = TLS.get k_storage - let copy = Array.copy - let get = get_from_ - let set = set_into_ - let[@inline] copy_of_current () = copy @@ !(get_cur_ ()) - let dummy = [||] - end -end diff --git a/src/core/task_local_storage.mli b/src/core/task_local_storage.mli index 6ca9557e..502661bb 100644 --- a/src/core/task_local_storage.mli +++ b/src/core/task_local_storage.mli @@ -8,8 +8,11 @@ @since NEXT_RELEASE *) -type storage -(** Underlying storage for a task *) +type t = Types_.local_storage +(** Underlying storage for a task. This is mutable and + not thread-safe. *) + +val dummy : t type 'a key (** A key used to access a particular (typed) storage slot on every task. *) @@ -49,22 +52,12 @@ val with_value : 'a key -> 'a -> (unit -> 'b) -> 'b to [f()]. When [f()] returns (or fails), [k] is restored to its old value. *) -(**/**) +(** Direct access to values from a storage handle *) +module Direct : sig + val get : t -> 'a key -> 'a + (** Access a key *) -(** Private API *) -module Private_ : sig - module Storage : sig - type t = storage - - val get : t ref -> 'a key -> 'a - val set : t ref -> 'a key -> 'a -> unit - val k_storage : t ref option Thread_local_storage_.key - val get_cur_opt : unit -> t ref option - val create : unit -> t - val copy : t -> t - val copy_of_current : unit -> t - val dummy : t - end + val set : t -> 'a key -> 'a -> unit + val create : unit -> t + val copy : t -> t end - -(**/**) diff --git a/src/core/types_.ml b/src/core/types_.ml index 00ffbe23..97079428 100644 --- a/src/core/types_.ml +++ b/src/core/types_.ml @@ -16,11 +16,21 @@ end type 'a ls_key = (module LS_KEY with type t = 'a) (** A LS key (task local storage) *) -type task_ls = ls_value array +type task = unit -> unit +type local_storage = ls_value array ref -(** Store the current LS values for the current thread. +type runner = { + run_async: ls:local_storage -> task -> unit; + shutdown: wait:bool -> unit -> unit; + size: unit -> int; + num_tasks: unit -> int; +} - A worker thread is going to cycle through many tasks, each of which - has its own storage. This key allows tasks running on the worker - to access their own storage *) -let k_ls_values : task_ls ref option TLS.key = TLS.new_key (fun () -> None) +let k_cur_runner : runner option ref TLS.key = TLS.new_key (fun () -> ref None) + +let k_cur_storage : local_storage option ref TLS.key = + TLS.new_key (fun () -> ref None) + +let[@inline] get_current_runner () : _ option = !(TLS.get k_cur_runner) +let[@inline] get_current_storage () : _ option = !(TLS.get k_cur_storage) +let[@inline] create_local_storage () = ref [||] diff --git a/src/core/ws_pool.ml b/src/core/ws_pool.ml index 367fbae2..12b98b09 100644 --- a/src/core/ws_pool.ml +++ b/src/core/ws_pool.ml @@ -1,10 +1,10 @@ +open Types_ module WSQ = Ws_deque_ module A = Atomic_ module TLS = Thread_local_storage_ include Runner let ( let@ ) = ( @@ ) -let k_storage = Task_local_storage.Private_.Storage.k_storage module Id = struct type t = unit ref @@ -18,11 +18,11 @@ type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task type task_full = | T_start of { - ls: Task_local_storage.storage ref; + ls: Task_local_storage.t; f: task; } | T_resume : { - ls: Task_local_storage.storage ref; + ls: Task_local_storage.t; k: 'a -> unit; x: 'a; } @@ -32,7 +32,7 @@ type worker_state = { pool_id_: Id.t; (** Unique per pool *) mutable thread: Thread.t; q: task_full WSQ.t; (** Work stealing queue *) - mutable cur_ls: Task_local_storage.storage ref option; (** Task storage *) + mutable cur_ls: Task_local_storage.t option; (** Task storage *) rng: Random.State.t; } (** State for a given worker. Only this worker is @@ -121,7 +121,7 @@ let run_task_now_ (self : state) ~runner ~(w : worker_state) (task : task_full) in w.cur_ls <- Some ls; - TLS.set k_storage (Some ls); + TLS.get k_cur_storage := Some ls; let _ctx = before_task runner in let[@inline] on_suspend () : _ ref = @@ -136,7 +136,7 @@ let run_task_now_ (self : state) ~runner ~(w : worker_state) (task : task_full) | Some w when Id.equal w.pool_id_ self.id_ -> Some w | _ -> None in - let ls' = ref @@ Task_local_storage.Private_.Storage.copy !ls in + let ls' = Task_local_storage.Direct.copy ls in schedule_task_ self ~w @@ T_start { ls = ls'; f = task' } in @@ -166,7 +166,7 @@ let run_task_now_ (self : state) ~runner ~(w : worker_state) (task : task_full) after_task runner _ctx; w.cur_ls <- None; - TLS.set k_storage None + TLS.get k_cur_storage := None let run_async_ (self : state) ~ls (f : task) : unit = let w = find_current_worker_ () in @@ -289,7 +289,7 @@ type ('a, 'b) create_args = (** Arguments used in {!create}. See {!create} for explanations. *) let dummy_task_ : task_full = - T_start { f = ignore; ls = ref Task_local_storage.Private_.Storage.dummy } + T_start { f = ignore; ls = Task_local_storage.dummy } let create ?(on_init_thread = default_thread_init_exit_) ?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ()) @@ -358,7 +358,7 @@ let create ?(on_init_thread = default_thread_init_exit_) let thread = Thread.self () in let t_id = Thread.id thread in on_init_thread ~dom_id:dom_idx ~t_id (); - TLS.set k_storage None; + TLS.get k_cur_storage := None; (* set thread name *) Option.iter diff --git a/src/fib/fiber.ml b/src/fib/fiber.ml index 73c82ee4..fe6fe22b 100644 --- a/src/fib/fiber.ml +++ b/src/fib/fiber.ml @@ -16,7 +16,7 @@ module Private_ = struct state: 'a state A.t; (** Current state in the lifetime of the fiber *) res: 'a Fut.t; runner: Runner.t; - ls: Task_local_storage.storage ref; + ls: Task_local_storage.t; } and 'a state = @@ -248,7 +248,7 @@ let spawn_ ~ls (Nursery n) (f : nursery -> 'a) : 'a t = let spawn (Nursery n) ?(protect = true) f : _ t = (* spawn [f()] with a copy of our local storage *) - let ls = ref (Task_local_storage.Private_.Storage.copy !(n.ls)) in + let ls = Task_local_storage.Direct.copy n.ls in let child = spawn_ ~ls (Nursery n) f in add_child_ ~protect n child; child @@ -259,6 +259,8 @@ let[@inline] spawn_ignore n ?protect f : unit = module Nursery = struct type t = nursery + let[@inline] runner (Nursery n) = n.runner + let[@inline] await (Nursery n) : unit = ignore (await n); () @@ -266,17 +268,13 @@ module Nursery = struct let cancel_with (Nursery n) ebt : unit = resolve_as_failed_ n ebt let with_create_top ~on () f = - let n = - create_ - ~ls:(ref @@ Task_local_storage.Private_.Storage.create ()) - ~runner:on () - in + let n = create_ ~ls:(Task_local_storage.Direct.create ()) ~runner:on () in Fun.protect ~finally:(fun () -> resolve_ok_ n ()) (fun () -> f (Nursery n)) let with_create_sub ~protect (Nursery parent : t) f = let n = create_ - ~ls:(ref @@ Task_local_storage.Private_.Storage.copy !(parent.ls)) + ~ls:(Task_local_storage.Direct.copy parent.ls) ~runner:parent.runner () in add_child_ ~protect parent n; diff --git a/src/fib/fiber.mli b/src/fib/fiber.mli index 095ab569..4a6366b8 100644 --- a/src/fib/fiber.mli +++ b/src/fib/fiber.mli @@ -27,6 +27,9 @@ type cancel_callback = Exn_bt.t -> unit module Nursery : sig type t + val runner : t -> Runner.t + (** Recover the runner this nursery uses to spawn fibers *) + val await : t -> unit (** Await for the nursery to exit. *) @@ -59,7 +62,7 @@ module Private_ : sig state: 'a state Atomic.t; (** Current state in the lifetime of the fiber *) res: 'a Fut.t; runner: Runner.t; - ls: Task_local_storage.storage ref; + ls: Task_local_storage.t; } (** Type definition, exposed so that {!any} can be unboxed. Please do not rely on that. *) diff --git a/src/fib/moonpool_fib.ml b/src/fib/moonpool_fib.ml new file mode 100644 index 00000000..e8063d1f --- /dev/null +++ b/src/fib/moonpool_fib.ml @@ -0,0 +1,6 @@ +(** Fiber for moonpool *) + +module Fiber = Fiber +module Fls = Fls +module Handle = Handle +include Fiber