diff --git a/src/moonpool.ml b/src/moonpool.ml index 75af6d18..96b93a24 100644 --- a/src/moonpool.ml +++ b/src/moonpool.ml @@ -105,8 +105,9 @@ module Pool = struct let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = () let create ?(on_init_thread = default_thread_init_exit_) - ?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ()) - ?(min = 1) ?(per_domain = 0) () : t = + ?(on_exit_thread = default_thread_init_exit_) + ?(wrap_thread = fun f () -> f ()) ?(on_exn = fun _ _ -> ()) ?(min = 1) + ?(per_domain = 0) () : t = (* number of threads to run *) let min = max 1 min in let n_domains = D_pool_.n_domains () in @@ -125,21 +126,25 @@ module Pool = struct let start_thread_with_idx i = let dom_idx = (offset + i) mod n_domains in - let create () = - let thread = - Thread.create - (fun () -> - let t_id = Thread.id @@ Thread.self () in - on_init_thread ~dom_id:dom_idx ~t_id (); - worker_thread_ ~on_exn active q; - on_exit_thread ~dom_id:dom_idx ~t_id ()) - () - in + (* function run in the thread itself *) + let main_thread_fun () = + let t_id = Thread.id @@ Thread.self () in + on_init_thread ~dom_id:dom_idx ~t_id (); + let run () = worker_thread_ ~on_exn active q in + let run' = wrap_thread run in + run' (); + on_exit_thread ~dom_id:dom_idx ~t_id () + in + + (* function called in domain with index [i], to + create the thread and push it into [receive_threads] *) + let create_thread_in_domain () = + let thread = Thread.create main_thread_fun () in (* send the thread from the domain back to us *) S_queue.push receive_threads (i, thread) in - D_pool_.run_on dom_idx create + D_pool_.run_on dom_idx create_thread_in_domain in (* start all threads, placing them on the domains diff --git a/src/moonpool.mli b/src/moonpool.mli index bd0cdc27..062ba8c8 100644 --- a/src/moonpool.mli +++ b/src/moonpool.mli @@ -13,6 +13,7 @@ module Pool : sig val create : ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) -> ?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) -> + ?wrap_thread:((unit -> unit) -> unit -> unit) -> ?on_exn:(exn -> Printexc.raw_backtrace -> unit) -> ?min:int -> ?per_domain:int -> @@ -21,6 +22,11 @@ module Pool : sig (** [create ()] makes a new thread pool. @param on_init_thread called at the beginning of each new thread in the pool. + @param on_exit_thread called at the end of each thread in the pool + @param wrap_thread takes the worker function [loop : unit -> unit] which is + the worker's main loop, and returns a new loop function. + By default it just returns the same loop function but it can be used + to install tracing, effect handlers, etc. *) val shutdown : t -> unit