From 69faea0bcbc269a104665e6dd3811269e2ff8b68 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Tue, 24 Oct 2023 12:53:19 -0400 Subject: [PATCH] wip: have only one condition in pool --- src/pool.ml | 133 +++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 100 insertions(+), 33 deletions(-) diff --git a/src/pool.ml b/src/pool.ml index e8e4c863..e4685b9c 100644 --- a/src/pool.ml +++ b/src/pool.ml @@ -5,6 +5,48 @@ include Runner let ( let@ ) = ( @@ ) +(** Thread safe queue, non blocking *) +module TS_queue = struct + type 'a t = { + mutex: Mutex.t; + q: 'a Queue.t; + } + + let create () : _ t = { mutex = Mutex.create (); q = Queue.create () } + + let try_push (self : _ t) x : bool = + if Mutex.try_lock self.mutex then ( + Queue.push x self.q; + Mutex.unlock self.mutex; + true + ) else + false + + let push (self : _ t) x : unit = + Mutex.lock self.mutex; + Queue.push x self.q; + Mutex.unlock self.mutex + + let try_pop ~force_lock (self : _ t) : _ option = + let has_lock = + if force_lock then ( + Mutex.lock self.mutex; + true + ) else + Mutex.try_lock self.mutex + in + if has_lock then ( + match Queue.pop self.q with + | x -> + Mutex.unlock self.mutex; + Some x + | exception Queue.Empty -> + Mutex.unlock self.mutex; + None + ) else + None +end + type thread_loop_wrapper = thread:Thread.t -> pool:t -> (unit -> unit) -> unit -> unit @@ -21,11 +63,22 @@ let add_global_thread_loop_wrapper f : unit = type state = { active: bool A.t; threads: Thread.t array; - qs: task Bb_queue.t array; + qs: task TS_queue.t array; + num_tasks: int A.t; + mutex: Mutex.t; + cond: Condition.t; cur_q: int A.t; (** Selects queue into which to push *) } (** internal state *) +let[@inline] size_ (self : state) = Array.length self.threads +let[@inline] num_tasks_ (self : state) : int = A.get self.num_tasks + +let awake_workers_ (self : state) : unit = + Mutex.lock self.mutex; + Condition.broadcast self.cond; + Mutex.unlock self.mutex + (** Run [task] as is, on the pool. *) let run_direct_ (self : state) (task : task) : unit = let n_qs = Array.length self.qs in @@ -35,22 +88,22 @@ let run_direct_ (self : state) (task : task) : unit = let[@inline] push_wait f = let q_idx = offset mod Array.length self.qs in let q = self.qs.(q_idx) in - Bb_queue.push q f + TS_queue.push q f in + let old_num_tasks = A.fetch_and_add self.num_tasks 1 in + try (* try each queue with a round-robin initial offset *) for _retry = 1 to 10 do for i = 0 to n_qs - 1 do let q_idx = (i + offset) mod Array.length self.qs in let q = self.qs.(q_idx) in - if Bb_queue.try_push q task then raise_notrace Exit + if TS_queue.try_push q task then raise_notrace Exit done done; push_wait task - with - | Exit -> () - | Bb_queue.Closed -> raise Shutdown + with Exit -> if old_num_tasks < size_ self then awake_workers_ self let rec run_async_ (self : state) (task : task) : unit = let task' () = @@ -64,12 +117,6 @@ let rec run_async_ (self : state) (task : task) : unit = run_direct_ self task' let run = run_async -let size_ (self : state) = Array.length self.threads - -let num_tasks_ (self : state) : int = - let n = ref 0 in - Array.iter (fun q -> n := !n + Bb_queue.size q) self.qs; - !n [@@@ifge 5.0] @@ -103,27 +150,41 @@ exception Got_task of task type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task -let worker_thread_ (runner : t) ~on_exn ~around_task (active : bool A.t) - (qs : task Bb_queue.t array) ~(offset : int) : unit = - let num_qs = Array.length qs in +exception Closed + +let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task + ~(offset : int) : unit = + let num_qs = Array.length self.qs in let (AT_pair (before_task, after_task)) = around_task in let get_task_without_blocking () : _ option = try for i = 0 to num_qs - 1 do - let q = qs.((offset + i) mod num_qs) in - match Bb_queue.try_pop ~force_lock:false q with + let q = self.qs.((offset + i) mod num_qs) in + match TS_queue.try_pop ~force_lock:false q with | Some f -> raise_notrace (Got_task f) | None -> () done; None - with Got_task f -> Some f + with Got_task f -> + A.decr self.num_tasks; + Some f in - (* last resort: block on my queue *) - let[@inline] pop_blocking () = - let my_q = qs.(offset mod num_qs) in - Bb_queue.pop my_q + (* last resort: block on condition or raise Closed *) + let pop_blocking () : task = + Mutex.lock self.mutex; + + try + while A.get self.active do + match get_task_without_blocking () with + | Some t -> + Mutex.unlock self.mutex; + raise_notrace (Got_task t) + | None -> Condition.wait self.cond self.mutex + done; + raise Closed + with Got_task t -> t in let run_task task : unit = @@ -147,12 +208,13 @@ let worker_thread_ (runner : t) ~on_exn ~around_task (active : bool A.t) in let main_loop () = - while A.get active do + while A.get self.active do run_tasks_already_present (); (* no task available, block until one comes *) - let task = pop_blocking () in - run_task task + match pop_blocking () with + | exception Closed -> () + | task -> run_task task done; (* cleanup *) @@ -176,9 +238,8 @@ let max_queues = 32 let shutdown_ ~wait (self : state) : unit = let was_active = A.exchange self.active false in - (* close the job queues, which will fail future calls to [run], - and wake up the subset of [self.threads] that are waiting on them. *) - if was_active then Array.iter Bb_queue.close self.qs; + (* wake up the subset of [self.threads] that are waiting on new tasks *) + if was_active then awake_workers_ self; if wait then Array.iter Thread.join self.threads type ('a, 'b) create_args = @@ -215,12 +276,20 @@ let create ?(on_init_thread = default_thread_init_exit_) let active = A.make true in let qs = let num_qs = min (min num_domains num_threads) max_queues in - Array.init num_qs (fun _ -> Bb_queue.create ()) + Array.init num_qs (fun _ -> TS_queue.create ()) in let pool = let dummy = Thread.self () in - { active; threads = Array.make num_threads dummy; qs; cur_q = A.make 0 } + { + active; + threads = Array.make num_threads dummy; + num_tasks = A.make 0; + qs; + mutex = Mutex.create (); + cond = Condition.create (); + cur_q = A.make 0; + } in let runner = @@ -250,9 +319,7 @@ let create ?(on_init_thread = default_thread_init_exit_) List.rev_append thread_wrappers (A.get global_thread_wrappers_) in - let run () = - worker_thread_ runner ~on_exn ~around_task active qs ~offset:i - in + let run () = worker_thread_ pool runner ~on_exn ~around_task ~offset:i in (* the actual worker loop is [worker_thread_], with all wrappers for this pool and for all pools (global_thread_wrappers_) *) let run' =