module WSQ = Ws_deque_ module A = Atomic_ module Int_tbl = Hashtbl.Make (struct type t = int let equal : t -> t -> bool = ( = ) let hash : t -> int = Hashtbl.hash end) include Runner let ( let@ ) = ( @@ ) type thread_loop_wrapper = thread:Thread.t -> pool:t -> (unit -> unit) -> unit -> unit type worker_state = { mutable thread: Thread.t; q: task WSQ.t; (** Work stealing queue *) } (** State for a given worker. Only this worker is allowed to push into the queue, but other workers can come and steal from it if they're idle. *) type mut_cond = { mutex: Mutex.t; cond: Condition.t; } type state = { active: bool Atomic.t; (** Becomes [false] when the pool is shutdown. *) workers: worker_state array; (** Fixed set of workers. *) worker_by_id: worker_state Int_tbl.t; main_q: task Queue.t; (** Main queue for tasks coming from the outside *) mc: mut_cond; (** Used to block on [main_q] *) } (** internal state *) let[@inline] size_ (self : state) = Array.length self.workers let num_tasks_ (self : state) : int = let n = ref 0 in Mutex.lock self.mc.mutex; n := Queue.length self.main_q; Mutex.unlock self.mc.mutex; Array.iter (fun w -> n := !n + WSQ.size w.q) self.workers; !n let[@inline] find_current_worker_ (self : state) : worker_state option = let self_id = Thread.id @@ Thread.self () in Int_tbl.find_opt self.worker_by_id self_id (** Run [task] as is, on the pool. *) let run_direct_ (self : state) (w : worker_state option) (task : task) : unit = match w with | Some w -> WSQ.push w.q task; (* see if we need to wakeup other workers to come and steal from us *) Mutex.lock self.mc.mutex; if Queue.is_empty self.main_q then Condition.broadcast self.mc.cond; Mutex.unlock self.mc.mutex | None -> if A.get self.active then ( (* push into the main queue *) Mutex.lock self.mc.mutex; let was_empty = Queue.is_empty self.main_q in Queue.push task self.main_q; if was_empty then Condition.broadcast self.mc.cond; Mutex.unlock self.mc.mutex ) else (* notify the caller that scheduling tasks is no longer permitted *) raise Shutdown let run_async_ (self : state) (task : task) : unit = (* stay on current worker if possible *) let w = find_current_worker_ self in let rec run_async_rec_ (task : task) = let task_with_suspend_ () = (* run [f()] and handle [suspend] in it *) Suspend_.with_suspend task ~run:(fun ~with_handler task' -> if with_handler then run_async_rec_ task' else run_direct_ self w task') in run_direct_ self w task_with_suspend_ in run_async_rec_ task let run = run_async exception Got_task of task type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task (** How many times in a row do we try to read the next local task? *) let run_self_task_max_retry = 5 (** How many times in a row do we try to do work-stealing? *) let steal_attempt_max_retry = 7 let worker_thread_ (self : state) (runner : t) (w : worker_state) ~on_exn ~around_task : unit = let (AT_pair (before_task, after_task)) = around_task in (* run this task. *) let run_task task : unit = let _ctx = before_task runner in (* run the task now, catching errors *) (try task () with e -> let bt = Printexc.get_raw_backtrace () in on_exn e bt); after_task runner _ctx in let run_self_tasks_ () = let continue = ref true in let pop_retries = ref 0 in while !continue do match WSQ.pop w.q with | Some task -> pop_retries := 0; run_task task | None -> Domain_.relax (); incr pop_retries; if !pop_retries > run_self_task_max_retry then continue := false done in let work_steal_offset = ref 0 in (* get a task from another worker *) let try_to_steal_work () : task option = assert (size_ self > 1); work_steal_offset := (!work_steal_offset + 1) mod Array.length self.workers; (* if we're pointing to [w], skip to the next worker as it's useless to steal from oneself *) if self.workers.(!work_steal_offset) == w then work_steal_offset := (!work_steal_offset + 1) mod Array.length self.workers; let w' = self.workers.(!work_steal_offset) in assert (w != w'); WSQ.steal w'.q in (* try for _retry = 1 to 1 do for i = 0 to Array.length self.workers - 1 do let w' = self.workers.(i) in if w != w' then ( match WSQ.steal w'.q with | None -> () | Some task -> raise_notrace (Got_task task) ) done done; None with Got_task task -> Some task *) (* try to steal work multiple times *) let try_to_steal_work_loop () : bool = if size_ self = 1 then (* no stealing for single thread pool *) false else ( try let unsuccessful_steal_attempts = ref 0 in while !unsuccessful_steal_attempts < steal_attempt_max_retry do match try_to_steal_work () with | Some task -> run_task task; raise_notrace Exit | None -> incr unsuccessful_steal_attempts; Domain_.relax () done; false with Exit -> true ) in let get_task_from_main_queue_block () : task option = try Mutex.lock self.mc.mutex; while true do match Queue.pop self.main_q with | exception Queue.Empty -> if A.get self.active then Condition.wait self.mc.cond self.mc.mutex else ( (* empty queue and we're closed, time to exit *) Mutex.unlock self.mc.mutex; raise_notrace Exit ) | task -> Mutex.unlock self.mc.mutex; raise_notrace (Got_task task) done; (* unreachable *) assert false with | Got_task t -> Some t | Exit -> None in let main_loop () = let continue = ref true in while !continue do run_self_tasks_ (); let did_steal = try_to_steal_work_loop () in if not did_steal then ( match get_task_from_main_queue_block () with | None -> (* main queue is closed *) continue := false | Some task -> run_task task ) done; assert (WSQ.size w.q = 0) in (* handle domain-local await *) Dla_.using ~prepare_for_await:Suspend_.prepare_for_await ~while_running:main_loop let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = () let shutdown_ ~wait (self : state) : unit = if A.exchange self.active false then ( Mutex.lock self.mc.mutex; Condition.broadcast self.mc.cond; Mutex.unlock self.mc.mutex; if wait then Array.iter (fun w -> Thread.join w.thread) self.workers ) type ('a, 'b) create_args = ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) -> ?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) -> ?thread_wrappers:thread_loop_wrapper list -> ?on_exn:(exn -> Printexc.raw_backtrace -> unit) -> ?around_task:(t -> 'b) * (t -> 'b -> unit) -> ?min:int -> ?per_domain:int -> 'a (** Arguments used in {!create}. See {!create} for explanations. *) let create ?(on_init_thread = default_thread_init_exit_) ?(on_exit_thread = default_thread_init_exit_) ?(thread_wrappers = []) ?(on_exn = fun _ _ -> ()) ?around_task ?min:(min_threads = 1) ?(per_domain = 0) () : t = (* wrapper *) let around_task = match around_task with | Some (f, g) -> AT_pair (f, g) | None -> AT_pair (ignore, fun _ _ -> ()) in (* number of threads to run *) let min_threads = max 1 min_threads in let num_domains = D_pool_.n_domains () in assert (num_domains >= 1); let num_threads = max min_threads (num_domains * per_domain) in (* make sure we don't bias towards the first domain(s) in {!D_pool_} *) let offset = Random.int num_domains in let workers : worker_state array = let dummy = Thread.self () in Array.init num_threads (fun _ -> { thread = dummy; q = WSQ.create () }) in let pool = { active = A.make true; workers; worker_by_id = Int_tbl.create 8; main_q = Queue.create (); mc = { mutex = Mutex.create (); cond = Condition.create () }; } in let runner = Runner.For_runner_implementors.create ~shutdown:(fun ~wait () -> shutdown_ pool ~wait) ~run_async:(fun f -> run_async_ pool f) ~size:(fun () -> size_ pool) ~num_tasks:(fun () -> num_tasks_ pool) () in (* temporary queue used to obtain thread handles from domains on which the thread are started. *) let receive_threads = Bb_queue.create () in (* start the thread with index [i] *) let start_thread_with_idx i = let w = pool.workers.(i) in let dom_idx = (offset + i) mod num_domains in (* function run in the thread itself *) let main_thread_fun () : unit = let thread = Thread.self () in let t_id = Thread.id thread in on_init_thread ~dom_id:dom_idx ~t_id (); let run () = worker_thread_ pool runner w ~on_exn ~around_task in (* the actual worker loop is [worker_thread_], with all wrappers for this pool and for all pools (global_thread_wrappers_) *) let run' = List.fold_left (fun run f -> f ~thread ~pool:runner run) run thread_wrappers in (* now run the main loop *) Fun.protect run' ~finally:(fun () -> (* on termination, decrease refcount of underlying domain *) D_pool_.decr_on dom_idx); on_exit_thread ~dom_id:dom_idx ~t_id () in (* function called in domain with index [i], to create the thread and push it into [receive_threads] *) let create_thread_in_domain () = let thread = Thread.create main_thread_fun () in (* send the thread from the domain back to us *) Bb_queue.push receive_threads (i, thread) in D_pool_.run_on dom_idx create_thread_in_domain in (* start all threads, placing them on the domains according to their index and [offset] in a round-robin fashion. *) for i = 0 to num_threads - 1 do start_thread_with_idx i done; (* receive the newly created threads back from domains *) for _j = 1 to num_threads do let i, th = Bb_queue.pop receive_threads in let worker_state = pool.workers.(i) in worker_state.thread <- th; Mutex.lock pool.mc.mutex; Int_tbl.add pool.worker_by_id (Thread.id th) worker_state; Mutex.unlock pool.mc.mutex done; runner let with_ ?on_init_thread ?on_exit_thread ?thread_wrappers ?on_exn ?around_task ?min ?per_domain () f = let pool = create ?on_init_thread ?on_exit_thread ?thread_wrappers ?on_exn ?around_task ?min ?per_domain () in let@ () = Fun.protect ~finally:(fun () -> shutdown pool) in f pool