diff --git a/src/pool.ml b/src/pool.ml index 4f4abac3..eb2e9366 100644 --- a/src/pool.ml +++ b/src/pool.ml @@ -60,9 +60,12 @@ let num_tasks (self : t) : int = exception Got_task of task -let worker_thread_ ~on_exn (active : bool A.t) (qs : task Bb_queue.t array) - ~(offset : int) : unit = +type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task + +let worker_thread_ pool ~on_exn ~around_task (active : bool A.t) + (qs : task Bb_queue.t array) ~(offset : int) : unit = let num_qs = Array.length qs in + let (AT_pair (before_task, after_task)) = around_task in try while A.get active do @@ -84,10 +87,12 @@ let worker_thread_ ~on_exn (active : bool A.t) (qs : task Bb_queue.t array) with Got_task f -> f in - try task () - with e -> - let bt = Printexc.get_raw_backtrace () in - on_exn e bt + let _ctx = before_task pool in + (try task () + with e -> + let bt = Printexc.get_raw_backtrace () in + on_exn e bt); + after_task pool _ctx done with Bb_queue.Closed -> () @@ -103,7 +108,15 @@ let max_queues = 32 let create ?(on_init_thread = default_thread_init_exit_) ?(on_exit_thread = default_thread_init_exit_) ?(thread_wrappers = []) - ?(on_exn = fun _ _ -> ()) ?min:(min_threads = 1) ?(per_domain = 0) () : t = + ?(on_exn = fun _ _ -> ()) ?around_task ?min:(min_threads = 1) + ?(per_domain = 0) () : t = + (* wrapper *) + let around_task = + match around_task with + | Some (f, g) -> AT_pair (f, g) + | None -> AT_pair (ignore, fun _ _ -> ()) + in + (* number of threads to run *) let min_threads = max 1 min_threads in let num_domains = D_pool_.n_domains () in @@ -133,7 +146,7 @@ let create ?(on_init_thread = default_thread_init_exit_) let dom_idx = (offset + i) mod num_domains in (* function run in the thread itself *) - let main_thread_fun () = + let main_thread_fun () : unit = let thread = Thread.self () in let t_id = Thread.id thread in on_init_thread ~dom_id:dom_idx ~t_id (); @@ -142,7 +155,9 @@ let create ?(on_init_thread = default_thread_init_exit_) List.rev_append thread_wrappers (A.get global_thread_wrappers_) in - let run () = worker_thread_ ~on_exn active qs ~offset:i in + let run () = + worker_thread_ pool ~on_exn ~around_task active qs ~offset:i + in (* the actual worker loop is [worker_thread_], with all wrappers for this pool and for all pools (global_thread_wrappers_) *) let run' = diff --git a/src/pool.mli b/src/pool.mli index b11cf496..114f11ed 100644 --- a/src/pool.mli +++ b/src/pool.mli @@ -13,7 +13,7 @@ type t type thread_loop_wrapper = thread:Thread.t -> pool:t -> (unit -> unit) -> unit -> unit -(** a thread wrapper [f] takes the current thread, the current pool, +(** A thread wrapper [f] takes the current thread, the current pool, and the worker function [loop : unit -> unit] which is the worker's main loop, and returns a new loop function. By default it just returns the same loop function but it can be used @@ -29,6 +29,7 @@ val create : ?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) -> ?thread_wrappers:thread_loop_wrapper list -> ?on_exn:(exn -> Printexc.raw_backtrace -> unit) -> + ?around_task:(t -> 'a) * (t -> 'a -> unit) -> ?min:int -> ?per_domain:int -> unit -> @@ -47,6 +48,10 @@ val create : @param on_exit_thread called at the end of each thread in the pool @param thread_wrappers a list of {!thread_loop_wrapper} functions to use for this pool's workers. + @param around_task a pair of [before, after], where [before pool] is called + before a task is processed, + on the worker thread about to run it, and returns [x]; and [after pool x] is called by + the same thread after the task is over. (since 0.2) *) val size : t -> int