diff --git a/src/d_pool_.ml b/src/d_pool_.ml index 2dedc339..13468359 100644 --- a/src/d_pool_.ml +++ b/src/d_pool_.ml @@ -2,15 +2,14 @@ type domain = Domain_.t type event = | Run of (unit -> unit) (** Run this function *) - | Decr - (** decrement number of threads on this domain. If it reaches 0, - wind down *) + | Die (** Nudge the domain, asking it to die *) (* State for a domain worker. It should not do too much except for starting new threads for pools. *) type worker_state = { q: event Bb_queue.t; th_count: int Atomic_.t; (** Number of threads on this *) + mutable domain: domain option; } (** Array of (optional) workers. @@ -22,37 +21,53 @@ let domains_ : worker_state option Lock.t array = let n = max 1 (Domain_.recommended_number () - 1) in Array.init n (fun _ -> Lock.create None) -let work_ idx (st : worker_state) : unit = +let work_ (st : worker_state) : unit = Dla_.setup_domain (); - while Atomic_.get st.th_count > 0 do + let continue = ref true in + while !continue do match Bb_queue.pop st.q with | Run f -> (try f () with _ -> ()) - | Decr -> - if Atomic_.fetch_and_add st.th_count (-1) = 1 then - Lock.set domains_.(idx) None + | Die -> continue := false done let[@inline] n_domains () : int = Array.length domains_ let run_on (i : int) (f : unit -> unit) : unit = assert (i < Array.length domains_); + let w = + Lock.update_map domains_.(i) (function + | Some w as st -> + Atomic_.incr w.th_count; + st, w + | None -> + let w = + { th_count = Atomic_.make 1; q = Bb_queue.create (); domain = None } + in + let worker : domain = Domain_.spawn (fun () -> work_ w) in + w.domain <- Some worker; + Some w, w) + in + Bb_queue.push w.q (Run f) - Lock.update domains_.(i) (function - | Some w as st -> - Atomic_.incr w.th_count; - Bb_queue.push w.q (Run f); - st - | None -> - let st = { th_count = Atomic_.make 1; q = Bb_queue.create () } in - let _domain : domain = Domain_.spawn (fun () -> work_ i st) in - Bb_queue.push st.q (Run f); - Some st) - -let decr_on (i : int) : unit = +let decr_on (i : int) ~(domain_to_join : Domain_.t -> unit) : unit = assert (i < Array.length domains_); - match Lock.get domains_.(i) with + let st_to_kill = + Lock.update_map domains_.(i) (function + | None -> assert false + | Some st -> + if Atomic_.fetch_and_add st.th_count (-1) = 1 then + None, Some st + else + Some st, None) + in + + (* prepare for domain termination outside of critical section *) + match st_to_kill with | None -> () - | Some st -> Bb_queue.push st.q Decr + | Some st -> + (* ask the domain to die *) + Bb_queue.push st.q Die; + Option.iter domain_to_join st.domain let run_on_and_wait (i : int) (f : unit -> 'a) : 'a = let q = Bb_queue.create () in diff --git a/src/d_pool_.mli b/src/d_pool_.mli index dd0902af..378c065a 100644 --- a/src/d_pool_.mli +++ b/src/d_pool_.mli @@ -15,8 +15,10 @@ val run_on : int -> (unit -> unit) -> unit (** [run_on i f] runs [f()] on the domain with index [i]. Precondition: [0 <= i < n_domains()] *) -val decr_on : int -> unit -(** Signal that a thread is stopping on the domain with index [i] *) +val decr_on : int -> domain_to_join:(Domain_.t -> unit) -> unit +(** Signal that a thread is stopping on the domain with index [i]. + @param domain_to_join called with a domain if this domain shuts down + because no one is using it anymore *) val run_on_and_wait : int -> (unit -> 'a) -> 'a (** [run_on_and_wait i f] runs [f()] on the domain with index [i], diff --git a/src/domain_.ml b/src/domain_.ml index 1878ab6d..60d1e669 100644 --- a/src/domain_.ml +++ b/src/domain_.ml @@ -8,6 +8,7 @@ type t = unit Domain.t let get_id (self : t) : int = (Domain.get_id self :> int) let spawn : _ -> t = Domain.spawn let relax = Domain.cpu_relax +let join = Domain.join [@@@ocaml.alert "+unstable"] [@@@else_] @@ -19,5 +20,6 @@ type t = Thread.t let get_id (self : t) : int = Thread.id self let spawn f : t = Thread.create f () let relax () = Thread.yield () +let join = Thread.join [@@@endif] diff --git a/src/pool.ml b/src/pool.ml index eded6b24..5faefc4d 100644 --- a/src/pool.ml +++ b/src/pool.ml @@ -22,6 +22,7 @@ type state = { active: bool A.t; threads: Thread.t array; qs: task Bb_queue.t array; + domains_to_join: Domain_.t Bb_queue.t; cur_q: int A.t; (** Selects queue into which to push *) } (** internal state *) @@ -158,7 +159,19 @@ let shutdown_ ~wait (self : state) : unit = (* close the job queues, which will fail future calls to [run], and wake up the subset of [self.threads] that are waiting on them. *) if was_active then Array.iter Bb_queue.close self.qs; - if wait then Array.iter Thread.join self.threads + if wait then Array.iter Thread.join self.threads; + Bb_queue.close self.domains_to_join; + + (* now join domains which need to be joined *) + while + match Bb_queue.pop self.domains_to_join with + | exception Bb_queue.Closed -> false + | d -> + Domain_.join d; + true + do + () + done type ('a, 'b) create_args = ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) -> @@ -199,7 +212,13 @@ let create ?(on_init_thread = default_thread_init_exit_) let pool = let dummy = Thread.self () in - { active; threads = Array.make num_threads dummy; qs; cur_q = A.make 0 } + { + active; + threads = Array.make num_threads dummy; + qs; + cur_q = A.make 0; + domains_to_join = Bb_queue.create (); + } in let runner = @@ -241,7 +260,10 @@ let create ?(on_init_thread = default_thread_init_exit_) in (* now run the main loop *) - Fun.protect run' ~finally:(fun () -> D_pool_.decr_on dom_idx); + Fun.protect run' ~finally:(fun () -> + (* on termination, decrease refcount of underlying domain *) + D_pool_.decr_on dom_idx + ~domain_to_join:(Bb_queue.push pool.domains_to_join)); on_exit_thread ~dom_id:dom_idx ~t_id () in