diff --git a/src/pool.ml b/src/pool.ml index ebc1cab1..b44ce6cf 100644 --- a/src/pool.ml +++ b/src/pool.ml @@ -1,4 +1,5 @@ module WSQ = Ws_deque_ +module A = Atomic_ include Runner let ( let@ ) = ( @@ ) @@ -11,20 +12,30 @@ type worker_state = { q: task WSQ.t; (** Work stealing queue *) } +type mut_cond = { + mutex: Mutex.t; + cond: Condition.t; +} + type state = { + active: bool Atomic.t; workers: worker_state array; - main_q: task Bb_queue.t; (** Main queue to block on *) + main_q: task Queue.t; (** Main queue to block on *) + mc: mut_cond; } (** internal state *) let[@inline] size_ (self : state) = Array.length self.workers let num_tasks_ (self : state) : int = - let n = ref (Bb_queue.size self.main_q) in + Mutex.lock self.mc.mutex; + let n = ref (Queue.length self.main_q) in + Mutex.unlock self.mc.mutex; Array.iter (fun w -> n := !n + WSQ.size w.q) self.workers; !n exception Got_worker of worker_state +exception Closed = Bb_queue.Closed let find_current_worker_ (self : state) : worker_state option = let self_id = Thread.id @@ Thread.self () in @@ -41,11 +52,22 @@ let find_current_worker_ (self : state) : worker_state option = let run_direct_ (self : state) (w : worker_state option) (task : task) : unit = match w with | Some w -> - print_endline "push local"; - WSQ.push w.q task + WSQ.push w.q task; + + (* see if we need to wakeup other workers to come and steal from us *) + Mutex.lock self.mc.mutex; + if Queue.is_empty self.main_q then Condition.broadcast self.mc.cond; + Mutex.unlock self.mc.mutex | None -> - print_endline "push blocking"; - Bb_queue.push self.main_q task + if A.get self.active then ( + (* push into the main queue *) + Mutex.lock self.mc.mutex; + let was_empty = Queue.is_empty self.main_q in + Queue.push task self.main_q; + if was_empty then Condition.broadcast self.mc.cond; + Mutex.unlock self.mc.mutex + ) else + raise Bb_queue.Closed let run_async_ (self : state) (task : task) : unit = (* stay on current worker if possible *) @@ -74,7 +96,7 @@ type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task let run_self_task_max_retry = 5 (** How many times in a row do we try to do work-stealing? *) -let steal_attempt_max_retry = 5 +let steal_attempt_max_retry = 7 let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn ~around_task : unit = @@ -92,7 +114,6 @@ let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn in let run_self_tasks_ () = - print_endline "run self tasks"; let continue = ref true in let pop_retries = ref 0 in while !continue do @@ -107,23 +128,34 @@ let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn done in + let work_steal_offset = ref 0 in + (* get a task from another worker *) - let try_to_steal_work () : task option = - print_endline "try to steal work"; + let rec try_to_steal_work () : task option = + let i = !work_steal_offset in + work_steal_offset := (i + 1) mod Array.length self.workers; + let w' = self.workers.(i) in + if w == w' then + try_to_steal_work () + else + WSQ.steal w'.q + in + + (* try - for _retry = 1 to 3 do - Array.iter - (fun w' -> - if w != w' then ( - match WSQ.steal w'.q with - | None -> () - | Some task -> raise_notrace (Got_task task) - )) - self.workers + for _retry = 1 to 1 do + for i = 0 to Array.length self.workers - 1 do + let w' = self.workers.(i) in + if w != w' then ( + match WSQ.steal w'.q with + | None -> () + | Some task -> raise_notrace (Got_task task) + ) + done done; None with Got_task task -> Some task - in + *) (* try to steal work multiple times *) let try_to_steal_work_loop () : bool = @@ -142,19 +174,28 @@ let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn with Exit -> true in + let get_task_from_main_queue_block () : task = + try + Mutex.lock self.mc.mutex; + while A.get self.active do + match Queue.pop self.main_q with + | exception Queue.Empty -> Condition.wait self.mc.cond self.mc.mutex + | task -> + Mutex.unlock self.mc.mutex; + raise_notrace (Got_task task) + done; + Mutex.unlock self.mc.mutex; + raise Bb_queue.Closed + with Got_task t -> t + in + let main_loop () = (try while true do run_self_tasks_ (); if not (try_to_steal_work_loop ()) then ( - Array.iteri - (fun i w -> Printf.printf "w[%d].q.size=%d\n" i (WSQ.size w.q)) - self.workers; - Printf.printf "bq.size=%d\n%!" (Bb_queue.size self.main_q); - - print_endline "wait block"; - let task = Bb_queue.pop self.main_q in + let task = get_task_from_main_queue_block () in run_task task ) done @@ -169,8 +210,12 @@ let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = () let shutdown_ ~wait (self : state) : unit = - Bb_queue.close self.main_q; - if wait then Array.iter (fun w -> Thread.join w.thread) self.workers + if A.exchange self.active false then ( + Mutex.lock self.mc.mutex; + Condition.broadcast self.mc.cond; + Mutex.unlock self.mc.mutex; + if wait then Array.iter (fun w -> Thread.join w.thread) self.workers + ) type ('a, 'b) create_args = ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) -> @@ -208,7 +253,14 @@ let create ?(on_init_thread = default_thread_init_exit_) Array.init num_threads (fun _ -> { thread = dummy; q = WSQ.create () }) in - let pool = { workers; main_q = Bb_queue.create () } in + let pool = + { + active = A.make true; + workers; + main_q = Queue.create (); + mc = { mutex = Mutex.create (); cond = Condition.create () }; + } + in let runner = Runner.For_runner_implementors.create diff --git a/src/pool.mli b/src/pool.mli index f7a42633..ae6699b2 100644 --- a/src/pool.mli +++ b/src/pool.mli @@ -34,6 +34,8 @@ type ('a, 'b) create_args = 'a (** Arguments used in {!create}. See {!create} for explanations. *) +exception Closed + val create : (unit -> t, _) create_args (** [create ()] makes a new thread pool. @param on_init_thread called at the beginning of each new thread