mirror of
https://github.com/c-cube/moonpool.git
synced 2025-12-06 03:05:30 -05:00
refactor: use a fixed size work-stealing deque
if it's full, tasks overflow into the main queue.
This commit is contained in:
parent
80031c0a54
commit
72f289af84
4 changed files with 72 additions and 108 deletions
|
|
@ -12,94 +12,68 @@ module A = Atomic_
|
|||
module CA : sig
|
||||
type 'a t
|
||||
|
||||
val create : log_size:int -> unit -> 'a t
|
||||
val size : _ t -> int
|
||||
val create : dummy:'a -> unit -> 'a t
|
||||
val size : 'a t -> int
|
||||
val get : 'a t -> int -> 'a
|
||||
val set : 'a t -> int -> 'a -> unit
|
||||
val grow : 'a t -> bottom:int -> top:int -> 'a t
|
||||
val shrink : 'a t -> bottom:int -> top:int -> 'a t
|
||||
end = struct
|
||||
type 'a t = {
|
||||
log_size: int;
|
||||
arr: 'a option array;
|
||||
}
|
||||
(** Base-2 logarithm of the array size: the array has size 256 (i.e. [1 lsl log_size]). *)
|
||||
let log_size = 8
|
||||
|
||||
let[@inline] size (self : _ t) = 1 lsl self.log_size
|
||||
type 'a t = { arr: 'a array } [@@unboxed]
|
||||
|
||||
let create ~log_size () : _ t =
|
||||
{ log_size; arr = Array.make (1 lsl log_size) None }
|
||||
let[@inline] size (_self : _ t) = 1 lsl log_size
|
||||
let create ~dummy () : _ t = { arr = Array.make (1 lsl log_size) dummy }
|
||||
|
||||
let[@inline] get (self : _ t) (i : int) : 'a =
|
||||
match Array.unsafe_get self.arr (i land ((1 lsl self.log_size) - 1)) with
|
||||
| Some x -> x
|
||||
| None -> assert false
|
||||
let[@inline] get (self : 'a t) (i : int) : 'a =
|
||||
Array.unsafe_get self.arr (i land ((1 lsl log_size) - 1))
|
||||
|
||||
let[@inline] set (self : 'a t) (i : int) (x : 'a) : unit =
|
||||
Array.unsafe_set self.arr (i land ((1 lsl self.log_size) - 1)) (Some x)
|
||||
|
||||
let grow (self : _ t) ~bottom ~top : 'a t =
|
||||
let new_arr = create ~log_size:(self.log_size + 1) () in
|
||||
for i = top to bottom - 1 do
|
||||
set new_arr i (get self i)
|
||||
done;
|
||||
new_arr
|
||||
|
||||
let shrink (self : _ t) ~bottom ~top : 'a t =
|
||||
let new_arr = create ~log_size:(self.log_size - 1) () in
|
||||
for i = top to bottom - 1 do
|
||||
set new_arr i (get self i)
|
||||
done;
|
||||
new_arr
|
||||
Array.unsafe_set self.arr (i land ((1 lsl log_size) - 1)) x
|
||||
end
|
||||
|
||||
type 'a t = {
|
||||
top: int A.t; (** Where we steal *)
|
||||
bottom: int A.t; (** Where we push/pop from the owning thread *)
|
||||
mutable top_cached: int; (** Last read value of [top] *)
|
||||
arr: 'a CA.t A.t; (** The circular array *)
|
||||
arr: 'a CA.t; (** The circular array *)
|
||||
}
|
||||
|
||||
let create () : _ t =
|
||||
let create ~dummy () : _ t =
|
||||
let top = A.make 0 in
|
||||
let arr = A.make @@ CA.create ~log_size:4 () in
|
||||
(* allocate far from top to avoid false sharing *)
|
||||
let arr = CA.create ~dummy () in
|
||||
(* allocate far from [top] to avoid false sharing *)
|
||||
let bottom = A.make 0 in
|
||||
{ top; top_cached = 0; bottom; arr }
|
||||
|
||||
let[@inline] size (self : _ t) : int = max 0 (A.get self.bottom - A.get self.top)
|
||||
|
||||
let push (self : 'a t) (x : 'a) : unit =
|
||||
exception Full
|
||||
|
||||
let push (self : 'a t) (x : 'a) : bool =
|
||||
try
|
||||
let b = A.get self.bottom in
|
||||
let t_approx = self.top_cached in
|
||||
let arr = ref (A.get self.arr) in
|
||||
|
||||
(* Section 2.3: over-approximation of size.
|
||||
Only if it seems too big do we actually read [t]. *)
|
||||
let size_approx = b - t_approx in
|
||||
if size_approx >= CA.size !arr - 1 then (
|
||||
if size_approx >= CA.size self.arr - 1 then (
|
||||
(* we need to read the actual value of [top], which might entail contention. *)
|
||||
let t = A.get self.top in
|
||||
self.top_cached <- t;
|
||||
let size = b - t in
|
||||
|
||||
if size >= CA.size !arr - 1 then (
|
||||
arr := CA.grow !arr ~top:t ~bottom:b;
|
||||
A.set self.arr !arr
|
||||
)
|
||||
if size >= CA.size self.arr - 1 then (* full! *) raise_notrace Full
|
||||
);
|
||||
|
||||
CA.set !arr b x;
|
||||
A.set self.bottom (b + 1)
|
||||
|
||||
let maybe_shrink_ (self : _ t) arr ~top ~bottom : unit =
|
||||
let size = bottom - top in
|
||||
let ca_size = CA.size arr in
|
||||
if ca_size >= 256 && size < ca_size / 3 then
|
||||
A.set self.arr (CA.shrink arr ~top ~bottom)
|
||||
CA.set self.arr b x;
|
||||
A.set self.bottom (b + 1);
|
||||
true
|
||||
with Full -> false
|
||||
|
||||
let pop (self : 'a t) : 'a option =
|
||||
let b = A.get self.bottom in
|
||||
let arr = A.get self.arr in
|
||||
let b = b - 1 in
|
||||
A.set self.bottom b;
|
||||
|
||||
|
|
@ -113,15 +87,14 @@ let pop (self : 'a t) : 'a option =
|
|||
None
|
||||
) else if size > 0 then (
|
||||
(* can pop without modifying [top] *)
|
||||
let x = CA.get arr b in
|
||||
maybe_shrink_ self arr ~bottom:b ~top:t;
|
||||
let x = CA.get self.arr b in
|
||||
Some x
|
||||
) else (
|
||||
assert (size = 0);
|
||||
(* there was exactly one slot, so we might be racing against stealers
|
||||
to update [self.top] *)
|
||||
if A.compare_and_set self.top t (t + 1) then (
|
||||
let x = CA.get arr b in
|
||||
let x = CA.get self.arr b in
|
||||
A.set self.bottom (t + 1);
|
||||
Some x
|
||||
) else (
|
||||
|
|
@ -135,13 +108,12 @@ let steal (self : 'a t) : 'a option =
|
|||
as we're in another thread *)
|
||||
let t = A.get self.top in
|
||||
let b = A.get self.bottom in
|
||||
let arr = A.get self.arr in
|
||||
|
||||
let size = b - t in
|
||||
if size <= 0 then
|
||||
None
|
||||
else (
|
||||
let x = CA.get arr t in
|
||||
let x = CA.get self.arr t in
|
||||
if A.compare_and_set self.top t (t + 1) then
|
||||
(* successfully increased top to consume [x] *)
|
||||
Some x
|
||||
|
|
|
|||
|
|
@ -6,14 +6,16 @@
|
|||
type 'a t
|
||||
(** Deque containing values of type ['a] *)
|
||||
|
||||
val create : unit -> _ t
|
||||
val create : dummy:'a -> unit -> 'a t
|
||||
(** Create a new deque. *)
|
||||
|
||||
val push : 'a t -> 'a -> unit
|
||||
(** Push value at the bottom of deque. This is not thread-safe. *)
|
||||
val push : 'a t -> 'a -> bool
|
||||
(** Push value at the bottom of the deque. Returns [true] if it succeeds, or [false] if the deque is full.
|
||||
This must be called only by the owner thread. *)
|
||||
|
||||
val pop : 'a t -> 'a option
|
||||
(** Pop value from the bottom of deque. This is not thread-safe. *)
|
||||
(** Pop value from the bottom of deque.
|
||||
This must be called only by the owner thread. *)
|
||||
|
||||
val steal : 'a t -> 'a option
|
||||
(** Try to steal from the top of deque. This is thread-safe. *)
|
||||
|
|
|
|||
|
|
@ -60,8 +60,16 @@ let schedule_task_ (self : state) (w : worker_state option) (task : task) : unit
|
|||
(* Printf.printf "schedule task now (%d)\n%!" (Thread.id @@ Thread.self ()); *)
|
||||
match w with
|
||||
| Some w ->
|
||||
WSQ.push w.q task;
|
||||
let pushed = WSQ.push w.q task in
|
||||
if pushed then
|
||||
try_wake_someone_ self
|
||||
else (
|
||||
(* overflow into main queue *)
|
||||
Mutex.lock self.mutex;
|
||||
Queue.push task self.main_q;
|
||||
if self.n_waiting_nonzero then Condition.signal self.cond;
|
||||
Mutex.unlock self.mutex
|
||||
)
|
||||
| None ->
|
||||
if A.get self.active then (
|
||||
(* push into the main queue *)
|
||||
|
|
@ -202,6 +210,8 @@ type ('a, 'b) create_args =
|
|||
'a
|
||||
(** Arguments used in {!create}. See {!create} for explanations. *)
|
||||
|
||||
let dummy_task_ () = assert false
|
||||
|
||||
let create ?(on_init_thread = default_thread_init_exit_)
|
||||
?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ())
|
||||
?around_task ?num_threads () : t =
|
||||
|
|
@ -221,7 +231,11 @@ let create ?(on_init_thread = default_thread_init_exit_)
|
|||
let workers : worker_state array =
|
||||
let dummy = Thread.self () in
|
||||
Array.init num_threads (fun i ->
|
||||
{ thread = dummy; q = WSQ.create (); rng = Random.State.make [| i |] })
|
||||
{
|
||||
thread = dummy;
|
||||
q = WSQ.create ~dummy:dummy_task_ ();
|
||||
rng = Random.State.make [| i |];
|
||||
})
|
||||
in
|
||||
|
||||
let pool =
|
||||
|
|
|
|||
|
|
@ -2,22 +2,23 @@ module A = Moonpool.Atomic
|
|||
module D = Moonpool.Private.Ws_deque_
|
||||
|
||||
let ( let@ ) = ( @@ )
|
||||
let dummy = -100
|
||||
|
||||
let t_simple () =
|
||||
let d = D.create () in
|
||||
let d = D.create ~dummy () in
|
||||
assert (D.steal d = None);
|
||||
assert (D.pop d = None);
|
||||
D.push d 1;
|
||||
D.push d 2;
|
||||
assert (D.push d 1);
|
||||
assert (D.push d 2);
|
||||
assert (D.pop d = Some 2);
|
||||
assert (D.steal d = Some 1);
|
||||
assert (D.steal d = None);
|
||||
assert (D.pop d = None);
|
||||
D.push d 3;
|
||||
assert (D.push d 3);
|
||||
assert (D.pop d = Some 3);
|
||||
D.push d 4;
|
||||
D.push d 5;
|
||||
D.push d 6;
|
||||
assert (D.push d 4);
|
||||
assert (D.push d 5);
|
||||
assert (D.push d 6);
|
||||
assert (D.steal d = Some 4);
|
||||
assert (D.steal d = Some 5);
|
||||
assert (D.pop d = Some 6);
|
||||
|
|
@ -35,7 +36,7 @@ let t_heavy () =
|
|||
|
||||
let active = A.make true in
|
||||
|
||||
let d = D.create () in
|
||||
let d = D.create ~dummy () in
|
||||
|
||||
let stealer_loop () =
|
||||
Trace.set_thread_name "stealer";
|
||||
|
|
@ -51,11 +52,13 @@ let t_heavy () =
|
|||
Trace.set_thread_name "producer";
|
||||
for _i = 1 to 100_000 do
|
||||
let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "main.outer" in
|
||||
|
||||
(* NOTE: we make sure to push fewer than 256 elements at once *)
|
||||
for j = 1 to 100 do
|
||||
ref_sum := !ref_sum + j;
|
||||
D.push d j;
|
||||
assert (D.push d j);
|
||||
ref_sum := !ref_sum + j;
|
||||
D.push d j;
|
||||
assert (D.push d j);
|
||||
|
||||
Option.iter (fun x -> add_to_sum x) (D.pop d);
|
||||
Option.iter (fun x -> add_to_sum x) (D.pop d)
|
||||
|
|
@ -92,35 +95,8 @@ let t_heavy () =
|
|||
assert (ref_sum = sum);
|
||||
()
|
||||
|
||||
let t_many () =
|
||||
print_endline "pushing many elements";
|
||||
let d = D.create () in
|
||||
|
||||
let push_and_pop count =
|
||||
for i = 1 to count do
|
||||
(* if i mod 100_000 = 0 then Printf.printf "push %d\n%!" i; *)
|
||||
D.push d i
|
||||
done;
|
||||
let n = ref 0 in
|
||||
|
||||
let continue = ref true in
|
||||
while !continue do
|
||||
match D.pop d with
|
||||
| None -> continue := false
|
||||
| Some _ -> incr n
|
||||
done;
|
||||
assert (!n = count)
|
||||
in
|
||||
|
||||
push_and_pop 10_000;
|
||||
push_and_pop 100_000;
|
||||
push_and_pop 1_000_000;
|
||||
print_endline "pushing many elements: ok";
|
||||
()
|
||||
|
||||
let () =
|
||||
let@ () = Trace_tef.with_setup () in
|
||||
t_simple ();
|
||||
t_heavy ();
|
||||
t_many ();
|
||||
()
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue