diff --git a/src/pool.ml b/src/pool.ml index 7cd78820..4f2eef66 100644 --- a/src/pool.ml +++ b/src/pool.ml @@ -2,10 +2,13 @@ module A = Atomic_ +type task = unit -> unit (** A job for the pool: a thunk that [run] enqueues and that a worker thread later calls. *)+ type t = { active: bool A.t; threads: Thread.t array; - q: (unit -> unit) Bb_queue.t; + qs: task Bb_queue.t array; (** Job queues: [run] pushes tasks into them and worker threads pop from them. *)+ cur_q: int A.t; (** Selects queue into which to push *) } type thread_loop_wrapper = @@ -23,42 +26,89 @@ let add_global_thread_loop_wrapper f : unit = exception Shutdown -let[@inline] run self f : unit = - try Bb_queue.push self.q f with Bb_queue.Closed -> raise Shutdown +let run (self : t) (f : task) : unit = (* [run self f] enqueues [f] on one of the queues in [self.qs]; [Bb_queue.Closed] coming from a queue is re-raised as [Shutdown]. *)+ let n_qs = Array.length self.qs in + let offset = A.fetch_and_add self.cur_q 1 in (* starting queue index for this call, taken from the shared counter [cur_q]. NOTE(review): the counter only grows; if it could ever wrap to a negative int, the mod results below would be negative and the array accesses would fail - TODO confirm whether that is reachable *)+ + (* blocking push, last resort *) + let push_wait () = + let q_idx = offset mod Array.length self.qs in + let q = self.qs.(q_idx) in + Bb_queue.push q f + in + + try + (* try each queue with a round-robin initial offset *) + for _retry = 1 to 10 do (* up to 10 passes of [try_push] over every queue before falling back to [push_wait] *)+ for i = 0 to n_qs - 1 do + let q_idx = (i + offset) mod Array.length self.qs in + let q = self.qs.(q_idx) in + if Bb_queue.try_push q f then raise_notrace Exit (* pushed: leave both loops through [Exit], handled below *)+ done + done; + push_wait () + with + | Exit -> () + | Bb_queue.Closed -> raise Shutdown let size self = Array.length self.threads -let worker_thread_ ~on_exn (active : bool A.t) (q : _ Bb_queue.t) : unit = - while A.get active do - match Bb_queue.pop q with - | exception Bb_queue.Closed -> () - | task -> - (try task () - with e -> - let bt = Printexc.get_raw_backtrace () in - on_exn e bt) - done +exception Got_task of task (** Control-flow exception used inside [worker_thread_] to carry a task out of its scanning [for] loop. *)+ +let worker_thread_ ~on_exn (active : bool A.t) (qs : task Bb_queue.t array) + ~(offset : int) : unit = (* Worker loop. While [active] holds: scan all queues with [try_pop], starting at index [offset]; if none yields a task, block in [pop] on queue [offset mod num_qs]; then run the task, handing any exception it raises to [on_exn] together with the raw backtrace. [Bb_queue.Closed] raised while fetching a task ends the loop. *)+ let num_qs = Array.length qs in + + try + while A.get active do + (* last resort: block on my queue *) + let pop_blocking () = + let my_q = qs.(offset mod num_qs) in + Bb_queue.pop my_q + in + + let task = + try + for i = 0 to num_qs - 1 do + let q = qs.((offset + i) mod num_qs) in + match Bb_queue.try_pop ~force_lock:false q with + | Some f -> raise_notrace 
(Got_task f) + | None -> () + done; + pop_blocking () + with Got_task f -> f (* a queue yielded a task during the scan *)+ in + + try task () + with e -> + let bt = Printexc.get_raw_backtrace () in + on_exn e bt + done + with Bb_queue.Closed -> () let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = () let create ?(on_init_thread = default_thread_init_exit_) ?(on_exit_thread = default_thread_init_exit_) ?(thread_wrappers = []) - ?(on_exn = fun _ _ -> ()) ?(min = 1) ?(per_domain = 0) () : t = + ?(on_exn = fun _ _ -> ()) ?min:(min_threads = 1) ?(per_domain = 0) () : t = (* number of threads to run *) - let min = max 1 min in - let n_domains = D_pool_.n_domains () in - assert (n_domains >= 1); - let n = max min (n_domains * per_domain) in + let min_threads = max 1 min_threads in (* bound as [min_threads] rather than [min], so that the [min] function can be used when sizing [qs] below *)+ let num_domains = D_pool_.n_domains () in + assert (num_domains >= 1); + let num_threads = max min_threads (num_domains * per_domain) in (* make sure we don't bias towards the first domain(s) in {!D_pool_} *) - let offset = Random.int n_domains in + let offset = Random.int num_domains in let active = A.make true in - let q = Bb_queue.create () in + let qs = + let num_qs = min num_domains num_threads in (* number of queues: the smaller of the domain count and the thread count *)+ Array.init num_qs (fun _ -> Bb_queue.create ()) + in let pool = let dummy = Thread.self () in - { active; threads = Array.make n dummy; q } + { active; threads = Array.make num_threads dummy; qs; cur_q = A.make 0 } in (* temporary queue used to obtain thread handles from domains @@ -67,7 +117,7 @@ let create ?(on_init_thread = default_thread_init_exit_) (* start the thread with index [i] *) let start_thread_with_idx i = - let dom_idx = (offset + i) mod n_domains in + let dom_idx = (offset + i) mod num_domains in (* function run in the thread itself *) let main_thread_fun () = @@ -79,7 +129,7 @@ let create ?(on_init_thread = default_thread_init_exit_) List.rev_append thread_wrappers (A.get global_thread_wrappers_) in - let run () = worker_thread_ ~on_exn active q in + let run () = worker_thread_ ~on_exn active qs ~offset:i in (* the thread index [i] is passed as the worker scan offset *)(* the 
actual worker loop is [worker_thread_], with all wrappers for this pool and for all pools (global_thread_wrappers_) *) let run' = @@ -104,12 +154,12 @@ let create ?(on_init_thread = default_thread_init_exit_) (* start all threads, placing them on the domains according to their index and [offset] in a round-robin fashion. *) - for i = 0 to n - 1 do + for i = 0 to num_threads - 1 do start_thread_with_idx i done; (* receive the newly created threads back from domains *) - for _j = 1 to n do + for _j = 1 to num_threads do let i, th = Bb_queue.pop receive_threads in pool.threads.(i) <- th done; @@ -117,7 +167,7 @@ let create ?(on_init_thread = default_thread_init_exit_) let shutdown (self : t) : unit = let was_active = A.exchange self.active false in - (* close the job queue, which will fail future calls to [run], - and wake up the subset of [self.threads] that are waiting on it. *) - if was_active then Bb_queue.close self.q; + (* close the job queues, which will fail future calls to [run], + and wake up the subset of [self.threads] that are waiting on them. *) + if was_active then Array.iter Bb_queue.close self.qs; (* closing is guarded by [was_active], so presumably only the first [shutdown] call closes the queues; every call still joins the threads *)Array.iter Thread.join self.threads