diff --git a/src/moonpool.ml b/src/moonpool.ml index 96b93a24..12dc5fad 100644 --- a/src/moonpool.ml +++ b/src/moonpool.ml @@ -91,6 +91,19 @@ module Pool = struct q: (unit -> unit) S_queue.t; } + type thread_loop_wrapper = + thread:Thread.t -> pool:t -> (unit -> unit) -> unit -> unit + + let global_thread_wrappers_ : thread_loop_wrapper list A.t = A.make [] + + let add_global_thread_loop_wrapper f : unit = + while + let l = A.get global_thread_wrappers_ in + not (A.compare_and_set global_thread_wrappers_ l (f :: l)) + do + () + done + let[@inline] run self f : unit = S_queue.push self.q f let worker_thread_ ~on_exn (active : bool A.t) (q : _ S_queue.t) : unit = @@ -105,9 +118,8 @@ module Pool = struct let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = () let create ?(on_init_thread = default_thread_init_exit_) - ?(on_exit_thread = default_thread_init_exit_) - ?(wrap_thread = fun f () -> f ()) ?(on_exn = fun _ _ -> ()) ?(min = 1) - ?(per_domain = 0) () : t = + ?(on_exit_thread = default_thread_init_exit_) ?(thread_wrappers = []) + ?(on_exn = fun _ _ -> ()) ?(min = 1) ?(per_domain = 0) () : t = (* number of threads to run *) let min = max 1 min in let n_domains = D_pool_.n_domains () in @@ -120,6 +132,13 @@ module Pool = struct let active = A.make true in let q = S_queue.create () in + let pool = + let dummy = Thread.self () in + { active; threads = Array.make n dummy; q } + in + + (* temporary queue used to obtain thread handles from domains + on which the thread are started. *) let receive_threads = S_queue.create () in (* start the thread with index [i] *) @@ -128,10 +147,22 @@ module Pool = struct (* function run in the thread itself *) let main_thread_fun () = - let t_id = Thread.id @@ Thread.self () in + let thread = Thread.self () in + let t_id = Thread.id thread in on_init_thread ~dom_id:dom_idx ~t_id (); + + let all_wrappers = + List.rev_append thread_wrappers (A.get global_thread_wrappers_) + in + let run () = worker_thread_ ~on_exn active q in - let run' = wrap_thread run in + (* the actual worker loop is [worker_thread_], with all + wrappers for this pool and for all pools (global_thread_wrappers_) *) + let run' = + List.fold_left (fun run f -> f ~thread ~pool run) run all_wrappers + in + + (* now run the main loop *) run' (); on_exit_thread ~dom_id:dom_idx ~t_id () in @@ -149,20 +180,16 @@ module Pool = struct (* start all threads, placing them on the domains according to their index and [offset] in a round-robin fashion. *) - let threads = - let dummy = Thread.self () in - Array.init n (fun i -> - start_thread_with_idx i; - dummy) - in + for i = 0 to n - 1 do + start_thread_with_idx i + done; (* receive the newly created threads back from domains *) for _j = 1 to n do let i, th = S_queue.pop receive_threads in - threads.(i) <- th + pool.threads.(i) <- th done; - - { active; threads; q } + pool let shutdown (self : t) : unit = let was_active = A.exchange self.active false in diff --git a/src/moonpool.mli b/src/moonpool.mli index 062ba8c8..9d53593f 100644 --- a/src/moonpool.mli +++ b/src/moonpool.mli @@ -10,10 +10,23 @@ type 'a or_error = ('a, exn * Printexc.raw_backtrace) result module Pool : sig type t + type thread_loop_wrapper = + thread:Thread.t -> pool:t -> (unit -> unit) -> unit -> unit + (** a thread wrapper [f] takes the current thread, the current pool, + and the worker function [loop : unit -> unit] which is + the worker's main loop, and returns a new loop function. + By default it just returns the same loop function but it can be used + to install tracing, effect handlers, etc. *) + + val add_global_thread_loop_wrapper : thread_loop_wrapper -> unit + (** [add_global_thread_loop_wrapper f] installs [f] to be installed in every new pool worker + thread, for all existing pools, and all new pools created with [create]. + These wrappers accumulate: they all apply, but their order is not specified. *) + val create : ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) -> ?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) -> - ?wrap_thread:((unit -> unit) -> unit -> unit) -> + ?thread_wrappers:thread_loop_wrapper list -> ?on_exn:(exn -> Printexc.raw_backtrace -> unit) -> ?min:int -> ?per_domain:int -> @@ -23,10 +36,8 @@ module Pool : sig @param on_init_thread called at the beginning of each new thread in the pool. @param on_exit_thread called at the end of each thread in the pool - @param wrap_thread takes the worker function [loop : unit -> unit] which is - the worker's main loop, and returns a new loop function. - By default it just returns the same loop function but it can be used - to install tracing, effect handlers, etc. + @param thread_wrappers a list of {!thread_loop_wrapper} functions + to use for this pool's workers. *) val shutdown : t -> unit