global thread loop wrappers

This commit is contained in:
Simon Cruanes 2023-06-01 21:48:11 -04:00
parent feb3b39912
commit 835eaf84c4
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
2 changed files with 57 additions and 19 deletions

View file

@ -91,6 +91,19 @@ module Pool = struct
q: (unit -> unit) S_queue.t;
}
type thread_loop_wrapper =
thread:Thread.t -> pool:t -> (unit -> unit) -> unit -> unit
let global_thread_wrappers_ : thread_loop_wrapper list A.t = A.make []
let add_global_thread_loop_wrapper f : unit =
while
let l = A.get global_thread_wrappers_ in
not (A.compare_and_set global_thread_wrappers_ l (f :: l))
do
()
done
let[@inline] run self f : unit = S_queue.push self.q f
let worker_thread_ ~on_exn (active : bool A.t) (q : _ S_queue.t) : unit =
@ -105,9 +118,8 @@ module Pool = struct
let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = ()
let create ?(on_init_thread = default_thread_init_exit_)
?(on_exit_thread = default_thread_init_exit_)
?(wrap_thread = fun f () -> f ()) ?(on_exn = fun _ _ -> ()) ?(min = 1)
?(per_domain = 0) () : t =
?(on_exit_thread = default_thread_init_exit_) ?(thread_wrappers = [])
?(on_exn = fun _ _ -> ()) ?(min = 1) ?(per_domain = 0) () : t =
(* number of threads to run *)
let min = max 1 min in
let n_domains = D_pool_.n_domains () in
@ -120,6 +132,13 @@ module Pool = struct
let active = A.make true in
let q = S_queue.create () in
let pool =
let dummy = Thread.self () in
{ active; threads = Array.make n dummy; q }
in
(* temporary queue used to obtain thread handles from domains
on which the thread are started. *)
let receive_threads = S_queue.create () in
(* start the thread with index [i] *)
@ -128,10 +147,22 @@ module Pool = struct
(* function run in the thread itself *)
let main_thread_fun () =
let t_id = Thread.id @@ Thread.self () in
let thread = Thread.self () in
let t_id = Thread.id thread in
on_init_thread ~dom_id:dom_idx ~t_id ();
let all_wrappers =
List.rev_append thread_wrappers (A.get global_thread_wrappers_)
in
let run () = worker_thread_ ~on_exn active q in
let run' = wrap_thread run in
(* the actual worker loop is [worker_thread_], with all
wrappers for this pool and for all pools (global_thread_wrappers_) *)
let run' =
List.fold_left (fun run f -> f ~thread ~pool run) run all_wrappers
in
(* now run the main loop *)
run' ();
on_exit_thread ~dom_id:dom_idx ~t_id ()
in
@ -149,20 +180,16 @@ module Pool = struct
(* start all threads, placing them on the domains
according to their index and [offset] in a round-robin fashion. *)
let threads =
let dummy = Thread.self () in
Array.init n (fun i ->
start_thread_with_idx i;
dummy)
in
for i = 0 to n - 1 do
start_thread_with_idx i
done;
(* receive the newly created threads back from domains *)
for _j = 1 to n do
let i, th = S_queue.pop receive_threads in
threads.(i) <- th
pool.threads.(i) <- th
done;
{ active; threads; q }
pool
let shutdown (self : t) : unit =
let was_active = A.exchange self.active false in

View file

@ -10,10 +10,23 @@ type 'a or_error = ('a, exn * Printexc.raw_backtrace) result
module Pool : sig
type t
type thread_loop_wrapper =
thread:Thread.t -> pool:t -> (unit -> unit) -> unit -> unit
(** a thread wrapper [f] takes the current thread, the current pool,
and the worker function [loop : unit -> unit] which is
the worker's main loop, and returns a new loop function.
By default it just returns the same loop function but it can be used
to install tracing, effect handlers, etc. *)
val add_global_thread_loop_wrapper : thread_loop_wrapper -> unit
(** [add_global_thread_loop_wrapper f] installs [f] to be installed in every new pool worker
thread, for all existing pools, and all new pools created with [create].
These wrappers accumulate: they all apply, but their order is not specified. *)
val create :
?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
?wrap_thread:((unit -> unit) -> unit -> unit) ->
?thread_wrappers:thread_loop_wrapper list ->
?on_exn:(exn -> Printexc.raw_backtrace -> unit) ->
?min:int ->
?per_domain:int ->
@ -23,10 +36,8 @@ module Pool : sig
@param on_init_thread called at the beginning of each new thread
in the pool.
@param on_exit_thread called at the end of each thread in the pool
@param wrap_thread takes the worker function [loop : unit -> unit] which is
the worker's main loop, and returns a new loop function.
By default it just returns the same loop function but it can be used
to install tracing, effect handlers, etc.
@param thread_wrappers a list of {!thread_loop_wrapper} functions
to use for this pool's workers.
*)
val shutdown : t -> unit