diff --git a/src/pool.ml b/src/pool.ml index 8d74ab65..8d0a8a20 100644 --- a/src/pool.ml +++ b/src/pool.ml @@ -1,5 +1,13 @@ module WSQ = Ws_deque_ module A = Atomic_ + +module Int_tbl = Hashtbl.Make (struct + type t = int + + let equal : t -> t -> bool = ( = ) + let hash : t -> int = Hashtbl.hash +end) + include Runner let ( let@ ) = ( @@ ) @@ -11,6 +19,9 @@ type worker_state = { mutable thread: Thread.t; q: task WSQ.t; (** Work stealing queue *) } +(** State for a given worker. Only this worker is + allowed to push into the queue, but other workers + can come and steal from it if they're idle. *) type mut_cond = { mutex: Mutex.t; @@ -18,35 +29,27 @@ type mut_cond = { } type state = { - active: bool Atomic.t; - workers: worker_state array; - main_q: task Queue.t; (** Main queue to block on *) - mc: mut_cond; + active: bool Atomic.t; (** Becomes [false] when the pool is shutdown. *) + workers: worker_state array; (** Fixed set of workers. *) + worker_by_id: worker_state Int_tbl.t; + main_q: task Queue.t; (** Main queue for tasks coming from the outside *) + mc: mut_cond; (** Used to block on [main_q] *) } (** internal state *) let[@inline] size_ (self : state) = Array.length self.workers let num_tasks_ (self : state) : int = + let n = ref 0 in Mutex.lock self.mc.mutex; - let n = ref (Queue.length self.main_q) in + n := Queue.length self.main_q; Mutex.unlock self.mc.mutex; Array.iter (fun w -> n := !n + WSQ.size w.q) self.workers; !n -exception Got_worker of worker_state -exception Closed = Bb_queue.Closed - -let find_current_worker_ (self : state) : worker_state option = +let[@inline] find_current_worker_ (self : state) : worker_state option = let self_id = Thread.id @@ Thread.self () in - try - (* see if we're in one of the worker threads *) - for i = 0 to Array.length self.workers - 1 do - let w = self.workers.(i) in - if Thread.id w.thread = self_id then raise_notrace (Got_worker w) - done; - None - with Got_worker w -> Some w + Int_tbl.find_opt self.worker_by_id self_id (** Run [task] as is, on the pool. *) let run_direct_ (self : state) (w : worker_state option) (task : task) : unit = @@ -133,14 +136,20 @@ let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn let work_steal_offset = ref 0 in (* get a task from another worker *) - let rec try_to_steal_work () : task option = - let i = !work_steal_offset in - work_steal_offset := (i + 1) mod Array.length self.workers; - let w' = self.workers.(i) in - if w == w' then - try_to_steal_work () - else - WSQ.steal w'.q + let try_to_steal_work () : task option = + assert (size_ self > 1); + + work_steal_offset := (!work_steal_offset + 1) mod Array.length self.workers; + + (* if we're pointing to [w], skip to the next worker as + it's useless to steal from oneself *) + if self.workers.(!work_steal_offset) == w then + work_steal_offset := + (!work_steal_offset + 1) mod Array.length self.workers; + + let w' = self.workers.(!work_steal_offset) in + assert (w != w'); + WSQ.steal w'.q in (* @@ -161,48 +170,65 @@ let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn (* try to steal work multiple times *) let try_to_steal_work_loop () : bool = - try - let unsuccessful_steal_attempts = ref 0 in - while !unsuccessful_steal_attempts < steal_attempt_max_retry do - match try_to_steal_work () with - | Some task -> - run_task task; - raise_notrace Exit - | None -> - incr unsuccessful_steal_attempts; - Domain_.relax () - done; + if size_ self = 1 then + (* no stealing for single thread pool *) false - with Exit -> true + else ( + try + let unsuccessful_steal_attempts = ref 0 in + while !unsuccessful_steal_attempts < steal_attempt_max_retry do + match try_to_steal_work () with + | Some task -> + run_task task; + raise_notrace Exit + | None -> + incr unsuccessful_steal_attempts; + Domain_.relax () + done; + false + with Exit -> true + ) in - let get_task_from_main_queue_block () : task = + let get_task_from_main_queue_block () : task option = try Mutex.lock self.mc.mutex; - while A.get self.active do + while true do match Queue.pop self.main_q with - | exception Queue.Empty -> Condition.wait self.mc.cond self.mc.mutex + | exception Queue.Empty -> + if A.get self.active then + Condition.wait self.mc.cond self.mc.mutex + else ( + (* empty queue and we're closed, time to exit *) + Mutex.unlock self.mc.mutex; + raise_notrace Exit + ) | task -> Mutex.unlock self.mc.mutex; raise_notrace (Got_task task) done; - Mutex.unlock self.mc.mutex; - raise Shutdown - with Got_task t -> t + (* unreachable *) + assert false + with + | Got_task t -> Some t + | Exit -> None in let main_loop () = - (try - while true do - run_self_tasks_ (); + let continue = ref true in + while !continue do + run_self_tasks_ (); - if not (try_to_steal_work_loop ()) then ( - let task = get_task_from_main_queue_block () in - run_task task - ) - done - with Shutdown -> ()); - run_self_tasks_ () + let did_steal = try_to_steal_work_loop () in + if not did_steal then ( + match get_task_from_main_queue_block () with + | None -> + (* main queue is closed *) + continue := false + | Some task -> run_task task + ) + done; + assert (WSQ.size w.q = 0) in (* handle domain-local await *) @@ -259,6 +285,7 @@ let create ?(on_init_thread = default_thread_init_exit_) { active = A.make true; workers; + worker_by_id = Int_tbl.create 8; main_q = Queue.create (); mc = { mutex = Mutex.create (); cond = Condition.create () }; } @@ -324,7 +351,12 @@ let create ?(on_init_thread = default_thread_init_exit_) (* receive the newly created threads back from domains *) for _j = 1 to num_threads do let i, th = Bb_queue.pop receive_threads in - pool.workers.(i).thread <- th + let worker_state = pool.workers.(i) in + worker_state.thread <- th; + + Mutex.lock pool.mc.mutex; + Int_tbl.add pool.worker_by_id (Thread.id th) worker_state; + Mutex.unlock pool.mc.mutex done; runner