add around_task to Pool.create

This commit is contained in:
Simon Cruanes 2023-06-15 11:19:50 -04:00
parent b451fde853
commit fc3d2d2645
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
2 changed files with 30 additions and 10 deletions

View file

@ -60,9 +60,12 @@ let num_tasks (self : t) : int =
exception Got_task of task exception Got_task of task
let worker_thread_ ~on_exn (active : bool A.t) (qs : task Bb_queue.t array) type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
~(offset : int) : unit =
let worker_thread_ pool ~on_exn ~around_task (active : bool A.t)
(qs : task Bb_queue.t array) ~(offset : int) : unit =
let num_qs = Array.length qs in let num_qs = Array.length qs in
let (AT_pair (before_task, after_task)) = around_task in
try try
while A.get active do while A.get active do
@ -84,10 +87,12 @@ let worker_thread_ ~on_exn (active : bool A.t) (qs : task Bb_queue.t array)
with Got_task f -> f with Got_task f -> f
in in
try task () let _ctx = before_task pool in
with e -> (try task ()
let bt = Printexc.get_raw_backtrace () in with e ->
on_exn e bt let bt = Printexc.get_raw_backtrace () in
on_exn e bt);
after_task pool _ctx
done done
with Bb_queue.Closed -> () with Bb_queue.Closed -> ()
@ -103,7 +108,15 @@ let max_queues = 32
let create ?(on_init_thread = default_thread_init_exit_) let create ?(on_init_thread = default_thread_init_exit_)
?(on_exit_thread = default_thread_init_exit_) ?(thread_wrappers = []) ?(on_exit_thread = default_thread_init_exit_) ?(thread_wrappers = [])
?(on_exn = fun _ _ -> ()) ?min:(min_threads = 1) ?(per_domain = 0) () : t = ?(on_exn = fun _ _ -> ()) ?around_task ?min:(min_threads = 1)
?(per_domain = 0) () : t =
(* wrapper *)
let around_task =
match around_task with
| Some (f, g) -> AT_pair (f, g)
| None -> AT_pair (ignore, fun _ _ -> ())
in
(* number of threads to run *) (* number of threads to run *)
let min_threads = max 1 min_threads in let min_threads = max 1 min_threads in
let num_domains = D_pool_.n_domains () in let num_domains = D_pool_.n_domains () in
@ -133,7 +146,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
let dom_idx = (offset + i) mod num_domains in let dom_idx = (offset + i) mod num_domains in
(* function run in the thread itself *) (* function run in the thread itself *)
let main_thread_fun () = let main_thread_fun () : unit =
let thread = Thread.self () in let thread = Thread.self () in
let t_id = Thread.id thread in let t_id = Thread.id thread in
on_init_thread ~dom_id:dom_idx ~t_id (); on_init_thread ~dom_id:dom_idx ~t_id ();
@ -142,7 +155,9 @@ let create ?(on_init_thread = default_thread_init_exit_)
List.rev_append thread_wrappers (A.get global_thread_wrappers_) List.rev_append thread_wrappers (A.get global_thread_wrappers_)
in in
let run () = worker_thread_ ~on_exn active qs ~offset:i in let run () =
worker_thread_ pool ~on_exn ~around_task active qs ~offset:i
in
(* the actual worker loop is [worker_thread_], with all (* the actual worker loop is [worker_thread_], with all
wrappers for this pool and for all pools (global_thread_wrappers_) *) wrappers for this pool and for all pools (global_thread_wrappers_) *)
let run' = let run' =

View file

@ -13,7 +13,7 @@ type t
type thread_loop_wrapper = type thread_loop_wrapper =
thread:Thread.t -> pool:t -> (unit -> unit) -> unit -> unit thread:Thread.t -> pool:t -> (unit -> unit) -> unit -> unit
(** a thread wrapper [f] takes the current thread, the current pool, (** A thread wrapper [f] takes the current thread, the current pool,
and the worker function [loop : unit -> unit] which is and the worker function [loop : unit -> unit] which is
the worker's main loop, and returns a new loop function. the worker's main loop, and returns a new loop function.
By default it just returns the same loop function but it can be used By default it just returns the same loop function but it can be used
@ -29,6 +29,7 @@ val create :
?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) -> ?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
?thread_wrappers:thread_loop_wrapper list -> ?thread_wrappers:thread_loop_wrapper list ->
?on_exn:(exn -> Printexc.raw_backtrace -> unit) -> ?on_exn:(exn -> Printexc.raw_backtrace -> unit) ->
?around_task:(t -> 'a) * (t -> 'a -> unit) ->
?min:int -> ?min:int ->
?per_domain:int -> ?per_domain:int ->
unit -> unit ->
@ -47,6 +48,10 @@ val create :
@param on_exit_thread called at the end of each thread in the pool @param on_exit_thread called at the end of each thread in the pool
@param thread_wrappers a list of {!thread_loop_wrapper} functions @param thread_wrappers a list of {!thread_loop_wrapper} functions
to use for this pool's workers. to use for this pool's workers.
@param around_task a pair of [before, after], where [before pool] is called
before a task is processed,
on the worker thread about to run it, and returns [x]; and [after pool x] is called by
the same thread after the task is over. (since 0.2)
*) *)
val size : t -> int val size : t -> int