diff --git a/src/ws_pool.ml b/src/ws_pool.ml index 4d1e0c70..ca5d2500 100644 --- a/src/ws_pool.ml +++ b/src/ws_pool.ml @@ -8,7 +8,7 @@ let ( let@ ) = ( @@ ) type worker_state = { mutable thread: Thread.t; q: task WSQ.t; (** Work stealing queue *) - mutable work_steal_offset: int; (** Current offset for work stealing *) + rng: Random.State.t; } (** State for a given worker. Only this worker is allowed to push into the queue, but other workers @@ -111,39 +111,26 @@ let[@inline] wait_ (self : state) : unit = self.n_waiting <- self.n_waiting - 1; if self.n_waiting = 0 then self.n_waiting_nonzero <- false -(** Try to steal a task from the worker [w] *) +exception Got_task of task + +(** Try to steal a task *) let try_to_steal_work_once_ (self : state) (w : worker_state) : task option = - w.work_steal_offset <- (w.work_steal_offset + 1) mod Array.length self.workers; + let init = Random.State.int w.rng (Array.length self.workers) in - (* if we're pointing to [w], skip to the next worker as - it's useless to steal from oneself *) - if Array.unsafe_get self.workers w.work_steal_offset == w then - w.work_steal_offset <- - (w.work_steal_offset + 1) mod Array.length self.workers; + try + for i = 0 to Array.length self.workers - 1 do + let w' = + Array.unsafe_get self.workers ((i + init) mod Array.length self.workers) + in - let w' = Array.unsafe_get self.workers w.work_steal_offset in - WSQ.steal w'.q - -(** Try to steal work from several other workers. *) -let try_to_steal_work_loop (self : state) ~runner w : bool = - if size_ self = 1 then - (* no stealing for single thread pool *) - false - else ( - let has_stolen = ref false in - let n_retries_left = ref (size_ self - 1) in - - while !n_retries_left > 0 do - match try_to_steal_work_once_ self w with - | Some task -> - try_wake_someone_ self; - run_task_now_ self ~runner task; - has_stolen := true; - n_retries_left := 0 - | None -> decr n_retries_left + if w != w' then ( + match WSQ.steal w'.q with + | Some t -> raise_notrace (Got_task t) + | None -> () + ) done; - !has_stolen - ) + None + with Got_task t -> Some t (** Worker runs tasks from its queue until none remains *) let worker_run_self_tasks_ (self : state) ~runner w : unit = @@ -160,29 +147,41 @@ let worker_run_self_tasks_ (self : state) ~runner w : unit = let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit = TLS.get k_worker_state := Some w; - let main_loop () : unit = - let continue = ref true in - while !continue && A.get self.active do + let rec main () : unit = + if A.get self.active then ( worker_run_self_tasks_ self ~runner w; + try_steal () + ) + and run_task task : unit = + run_task_now_ self ~runner task; + main () + and try_steal () = + if A.get self.active then ( + match try_to_steal_work_once_ self w with + | Some task -> run_task task + | None -> wait () + ) + and wait () = + Mutex.lock self.mutex; + match Queue.pop self.main_q with + | task -> + Mutex.unlock self.mutex; + run_task task + | exception Queue.Empty -> + (* wait here *) + if A.get self.active then wait_ self; - let did_steal = try_to_steal_work_loop self ~runner w in - if not did_steal then ( - Mutex.lock self.mutex; - match Queue.pop self.main_q with - | task -> - Mutex.unlock self.mutex; - run_task_now_ self ~runner task - | exception Queue.Empty -> - if A.get self.active then wait_ self; - Mutex.unlock self.mutex - ) - done; - assert (WSQ.size w.q = 0) + (* see if a task became available *) + let task = try Some (Queue.pop self.main_q) with Queue.Empty -> None in + Mutex.unlock self.mutex; + + (match task with + | Some t -> run_task t + | None -> try_steal ()) in (* handle domain-local await *) - Dla_.using ~prepare_for_await:Suspend_.prepare_for_await - ~while_running:main_loop + Dla_.using ~prepare_for_await:Suspend_.prepare_for_await ~while_running:main let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = () @@ -226,11 +225,7 @@ let create ?(on_init_thread = default_thread_init_exit_) let workers : worker_state array = let dummy = Thread.self () in Array.init num_threads (fun i -> - { - thread = dummy; - q = WSQ.create (); - work_steal_offset = (i + 1) mod num_threads; - }) + { thread = dummy; q = WSQ.create (); rng = Random.State.make [| i |] }) in let pool =