diff --git a/src/pool.ml b/src/pool.ml index 8d0a8a20..df8992de 100644 --- a/src/pool.ml +++ b/src/pool.ml @@ -1,13 +1,5 @@ module WSQ = Ws_deque_ module A = Atomic_ - -module Int_tbl = Hashtbl.Make (struct - type t = int - - let equal : t -> t -> bool = ( = ) - let hash : t -> int = Hashtbl.hash -end) - include Runner let ( let@ ) = ( @@ ) @@ -31,7 +23,6 @@ type mut_cond = { type state = { active: bool Atomic.t; (** Becomes [false] when the pool is shutdown. *) workers: worker_state array; (** Fixed set of workers. *) - worker_by_id: worker_state Int_tbl.t; main_q: task Queue.t; (** Main queue for tasks coming from the outside *) mc: mut_cond; (** Used to block on [main_q] *) } @@ -47,9 +38,18 @@ let num_tasks_ (self : state) : int = Array.iter (fun w -> n := !n + WSQ.size w.q) self.workers; !n +exception Got_worker of worker_state + let[@inline] find_current_worker_ (self : state) : worker_state option = let self_id = Thread.id @@ Thread.self () in - Int_tbl.find_opt self.worker_by_id self_id + try + (* see if we're in one of the worker threads *) + for i = 0 to Array.length self.workers - 1 do + let w = self.workers.(i) in + if Thread.id w.thread = self_id then raise_notrace (Got_worker w) + done; + None + with Got_worker w -> Some w (** Run [task] as is, on the pool. *) let run_direct_ (self : state) (w : worker_state option) (task : task) : unit = @@ -285,7 +285,6 @@ let create ?(on_init_thread = default_thread_init_exit_) { active = A.make true; workers; - worker_by_id = Int_tbl.create 8; main_q = Queue.create (); mc = { mutex = Mutex.create (); cond = Condition.create () }; } @@ -352,11 +351,7 @@ let create ?(on_init_thread = default_thread_init_exit_) for _j = 1 to num_threads do let i, th = Bb_queue.pop receive_threads in let worker_state = pool.workers.(i) in - worker_state.thread <- th; - - Mutex.lock pool.mc.mutex; - Int_tbl.add pool.worker_by_id (Thread.id th) worker_state; - Mutex.unlock pool.mc.mutex + worker_state.thread <- th done; runner