fix: race condition in shutdown, we need to wait for domain to quit

risk is a tight loop of `Pool.with_`, where by not waiting for the pool
to entirely shutdown (including the domains, potentially) we risk
running out of domains in the next iterations.
This commit is contained in:
Simon Cruanes 2023-08-13 22:24:18 -04:00
parent 6c4d2cbc79
commit ed531e68e1
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
4 changed files with 68 additions and 27 deletions

View file

@ -2,15 +2,14 @@ type domain = Domain_.t
type event = type event =
| Run of (unit -> unit) (** Run this function *) | Run of (unit -> unit) (** Run this function *)
| Decr | Die (** Nudge the domain, asking it to die *)
(** decrement number of threads on this domain. If it reaches 0,
wind down *)
(* State for a domain worker. It should not do too much except for starting (* State for a domain worker. It should not do too much except for starting
new threads for pools. *) new threads for pools. *)
type worker_state = { type worker_state = {
q: event Bb_queue.t; q: event Bb_queue.t;
th_count: int Atomic_.t; (** Number of threads on this *) th_count: int Atomic_.t; (** Number of threads on this *)
mutable domain: domain option;
} }
(** Array of (optional) workers. (** Array of (optional) workers.
@ -22,37 +21,53 @@ let domains_ : worker_state option Lock.t array =
let n = max 1 (Domain_.recommended_number () - 1) in let n = max 1 (Domain_.recommended_number () - 1) in
Array.init n (fun _ -> Lock.create None) Array.init n (fun _ -> Lock.create None)
let work_ idx (st : worker_state) : unit = let work_ (st : worker_state) : unit =
Dla_.setup_domain (); Dla_.setup_domain ();
while Atomic_.get st.th_count > 0 do let continue = ref true in
while !continue do
match Bb_queue.pop st.q with match Bb_queue.pop st.q with
| Run f -> (try f () with _ -> ()) | Run f -> (try f () with _ -> ())
| Decr -> | Die -> continue := false
if Atomic_.fetch_and_add st.th_count (-1) = 1 then
Lock.set domains_.(idx) None
done done
let[@inline] n_domains () : int = Array.length domains_ let[@inline] n_domains () : int = Array.length domains_
let run_on (i : int) (f : unit -> unit) : unit = let run_on (i : int) (f : unit -> unit) : unit =
assert (i < Array.length domains_); assert (i < Array.length domains_);
let w =
Lock.update_map domains_.(i) (function
| Some w as st ->
Atomic_.incr w.th_count;
st, w
| None ->
let w =
{ th_count = Atomic_.make 1; q = Bb_queue.create (); domain = None }
in
let worker : domain = Domain_.spawn (fun () -> work_ w) in
w.domain <- Some worker;
Some w, w)
in
Bb_queue.push w.q (Run f)
Lock.update domains_.(i) (function let decr_on (i : int) ~(domain_to_join : Domain_.t -> unit) : unit =
| Some w as st ->
Atomic_.incr w.th_count;
Bb_queue.push w.q (Run f);
st
| None ->
let st = { th_count = Atomic_.make 1; q = Bb_queue.create () } in
let _domain : domain = Domain_.spawn (fun () -> work_ i st) in
Bb_queue.push st.q (Run f);
Some st)
let decr_on (i : int) : unit =
assert (i < Array.length domains_); assert (i < Array.length domains_);
match Lock.get domains_.(i) with let st_to_kill =
Lock.update_map domains_.(i) (function
| None -> assert false
| Some st ->
if Atomic_.fetch_and_add st.th_count (-1) = 1 then
None, Some st
else
Some st, None)
in
(* prepare for domain termination outside of critical section *)
match st_to_kill with
| None -> () | None -> ()
| Some st -> Bb_queue.push st.q Decr | Some st ->
(* ask the domain to die *)
Bb_queue.push st.q Die;
Option.iter domain_to_join st.domain
let run_on_and_wait (i : int) (f : unit -> 'a) : 'a = let run_on_and_wait (i : int) (f : unit -> 'a) : 'a =
let q = Bb_queue.create () in let q = Bb_queue.create () in

View file

@ -15,8 +15,10 @@ val run_on : int -> (unit -> unit) -> unit
(** [run_on i f] runs [f()] on the domain with index [i]. (** [run_on i f] runs [f()] on the domain with index [i].
Precondition: [0 <= i < n_domains()] *) Precondition: [0 <= i < n_domains()] *)
val decr_on : int -> unit val decr_on : int -> domain_to_join:(Domain_.t -> unit) -> unit
(** Signal that a thread is stopping on the domain with index [i] *) (** Signal that a thread is stopping on the domain with index [i].
@param domain_to_join called with a domain if this domain shuts down
because no one is using it anymore *)
val run_on_and_wait : int -> (unit -> 'a) -> 'a val run_on_and_wait : int -> (unit -> 'a) -> 'a
(** [run_on_and_wait i f] runs [f()] on the domain with index [i], (** [run_on_and_wait i f] runs [f()] on the domain with index [i],

View file

@ -8,6 +8,7 @@ type t = unit Domain.t
let get_id (self : t) : int = (Domain.get_id self :> int) let get_id (self : t) : int = (Domain.get_id self :> int)
let spawn : _ -> t = Domain.spawn let spawn : _ -> t = Domain.spawn
let relax = Domain.cpu_relax let relax = Domain.cpu_relax
let join = Domain.join
[@@@ocaml.alert "+unstable"] [@@@ocaml.alert "+unstable"]
[@@@else_] [@@@else_]
@ -19,5 +20,6 @@ type t = Thread.t
let get_id (self : t) : int = Thread.id self let get_id (self : t) : int = Thread.id self
let spawn f : t = Thread.create f () let spawn f : t = Thread.create f ()
let relax () = Thread.yield () let relax () = Thread.yield ()
let join = Thread.join
[@@@endif] [@@@endif]

View file

@ -22,6 +22,7 @@ type state = {
active: bool A.t; active: bool A.t;
threads: Thread.t array; threads: Thread.t array;
qs: task Bb_queue.t array; qs: task Bb_queue.t array;
domains_to_join: Domain_.t Bb_queue.t;
cur_q: int A.t; (** Selects queue into which to push *) cur_q: int A.t; (** Selects queue into which to push *)
} }
(** internal state *) (** internal state *)
@ -158,7 +159,19 @@ let shutdown_ ~wait (self : state) : unit =
(* close the job queues, which will fail future calls to [run], (* close the job queues, which will fail future calls to [run],
and wake up the subset of [self.threads] that are waiting on them. *) and wake up the subset of [self.threads] that are waiting on them. *)
if was_active then Array.iter Bb_queue.close self.qs; if was_active then Array.iter Bb_queue.close self.qs;
if wait then Array.iter Thread.join self.threads if wait then Array.iter Thread.join self.threads;
Bb_queue.close self.domains_to_join;
(* now join domains which need to be joined *)
while
match Bb_queue.pop self.domains_to_join with
| exception Bb_queue.Closed -> false
| d ->
Domain_.join d;
true
do
()
done
type ('a, 'b) create_args = type ('a, 'b) create_args =
?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) -> ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
@ -199,7 +212,13 @@ let create ?(on_init_thread = default_thread_init_exit_)
let pool = let pool =
let dummy = Thread.self () in let dummy = Thread.self () in
{ active; threads = Array.make num_threads dummy; qs; cur_q = A.make 0 } {
active;
threads = Array.make num_threads dummy;
qs;
cur_q = A.make 0;
domains_to_join = Bb_queue.create ();
}
in in
let runner = let runner =
@ -241,7 +260,10 @@ let create ?(on_init_thread = default_thread_init_exit_)
in in
(* now run the main loop *) (* now run the main loop *)
Fun.protect run' ~finally:(fun () -> D_pool_.decr_on dom_idx); Fun.protect run' ~finally:(fun () ->
(* on termination, decrease refcount of underlying domain *)
D_pool_.decr_on dom_idx
~domain_to_join:(Bb_queue.push pool.domains_to_join));
on_exit_thread ~dom_id:dom_idx ~t_id () on_exit_thread ~dom_id:dom_idx ~t_id ()
in in