diff --git a/src/ws_deque_.ml b/src/ws_deque_.ml index 137e1c15..6c5d1419 100644 --- a/src/ws_deque_.ml +++ b/src/ws_deque_.ml @@ -12,94 +12,68 @@ module A = Atomic_ module CA : sig type 'a t - val create : log_size:int -> unit -> 'a t - val size : _ t -> int + val create : dummy:'a -> unit -> 'a t + val size : 'a t -> int val get : 'a t -> int -> 'a val set : 'a t -> int -> 'a -> unit - val grow : 'a t -> bottom:int -> top:int -> 'a t - val shrink : 'a t -> bottom:int -> top:int -> 'a t end = struct - type 'a t = { - log_size: int; - arr: 'a option array; - } + (** The array has size 256. *) + let log_size = 8 - let[@inline] size (self : _ t) = 1 lsl self.log_size + type 'a t = { arr: 'a array } [@@unboxed] - let create ~log_size () : _ t = - { log_size; arr = Array.make (1 lsl log_size) None } + let[@inline] size (_self : _ t) = 1 lsl log_size + let create ~dummy () : _ t = { arr = Array.make (1 lsl log_size) dummy } - let[@inline] get (self : _ t) (i : int) : 'a = - match Array.unsafe_get self.arr (i land ((1 lsl self.log_size) - 1)) with - | Some x -> x - | None -> assert false + let[@inline] get (self : 'a t) (i : int) : 'a = + Array.unsafe_get self.arr (i land ((1 lsl log_size) - 1)) let[@inline] set (self : 'a t) (i : int) (x : 'a) : unit = - Array.unsafe_set self.arr (i land ((1 lsl self.log_size) - 1)) (Some x) - - let grow (self : _ t) ~bottom ~top : 'a t = - let new_arr = create ~log_size:(self.log_size + 1) () in - for i = top to bottom - 1 do - set new_arr i (get self i) - done; - new_arr - - let shrink (self : _ t) ~bottom ~top : 'a t = - let new_arr = create ~log_size:(self.log_size - 1) () in - for i = top to bottom - 1 do - set new_arr i (get self i) - done; - new_arr + Array.unsafe_set self.arr (i land ((1 lsl log_size) - 1)) x end type 'a t = { top: int A.t; (** Where we steal *) bottom: int A.t; (** Where we push/pop from the owning thread *) mutable top_cached: int; (** Last read value of [top] *) - arr: 'a CA.t A.t; (** The circular array *) + arr: 'a CA.t; (** The circular array *) } -let create () : _ t = +let create ~dummy () : _ t = let top = A.make 0 in - let arr = A.make @@ CA.create ~log_size:4 () in - (* allocate far from top to avoid false sharing *) + let arr = CA.create ~dummy () in + (* allocate far from [top] to avoid false sharing *) let bottom = A.make 0 in { top; top_cached = 0; bottom; arr } let[@inline] size (self : _ t) : int = max 0 (A.get self.bottom - A.get self.top) -let push (self : 'a t) (x : 'a) : unit = - let b = A.get self.bottom in - let t_approx = self.top_cached in - let arr = ref (A.get self.arr) in +exception Full - (* Section 2.3: over-approximation of size. - Only if it seems too big do we actually read [t]. *) - let size_approx = b - t_approx in - if size_approx >= CA.size !arr - 1 then ( - (* we need to read the actual value of [top], which might entail contention. *) - let t = A.get self.top in - self.top_cached <- t; - let size = b - t in +let push (self : 'a t) (x : 'a) : bool = + try + let b = A.get self.bottom in + let t_approx = self.top_cached in - if size >= CA.size !arr - 1 then ( - arr := CA.grow !arr ~top:t ~bottom:b; - A.set self.arr !arr - ) - ); + (* Section 2.3: over-approximation of size. + Only if it seems too big do we actually read [t]. *) + let size_approx = b - t_approx in + if size_approx >= CA.size self.arr - 1 then ( + (* we need to read the actual value of [top], which might entail contention. *) + let t = A.get self.top in + self.top_cached <- t; + let size = b - t in - CA.set !arr b x; - A.set self.bottom (b + 1) + if size >= CA.size self.arr - 1 then (* full! *) raise_notrace Full + ); -let maybe_shrink_ (self : _ t) arr ~top ~bottom : unit = - let size = bottom - top in - let ca_size = CA.size arr in - if ca_size >= 256 && size < ca_size / 3 then - A.set self.arr (CA.shrink arr ~top ~bottom) + CA.set self.arr b x; + A.set self.bottom (b + 1); + true + with Full -> false let pop (self : 'a t) : 'a option = let b = A.get self.bottom in - let arr = A.get self.arr in let b = b - 1 in A.set self.bottom b; @@ -113,15 +87,14 @@ let pop (self : 'a t) : 'a option = None ) else if size > 0 then ( (* can pop without modifying [top] *) - let x = CA.get arr b in - maybe_shrink_ self arr ~bottom:b ~top:t; + let x = CA.get self.arr b in Some x ) else ( assert (size = 0); (* there was exactly one slot, so we might be racing against stealers to update [self.top] *) if A.compare_and_set self.top t (t + 1) then ( - let x = CA.get arr b in + let x = CA.get self.arr b in A.set self.bottom (t + 1); Some x ) else ( @@ -135,13 +108,12 @@ let steal (self : 'a t) : 'a option = as we're in another thread *) let t = A.get self.top in let b = A.get self.bottom in - let arr = A.get self.arr in let size = b - t in if size <= 0 then None else ( - let x = CA.get arr t in + let x = CA.get self.arr t in if A.compare_and_set self.top t (t + 1) then (* successfully increased top to consume [x] *) Some x diff --git a/src/ws_deque_.mli b/src/ws_deque_.mli index 0b243f68..bead45aa 100644 --- a/src/ws_deque_.mli +++ b/src/ws_deque_.mli @@ -6,14 +6,16 @@ type 'a t (** Deque containing values of type ['a] *) -val create : unit -> _ t +val create : dummy:'a -> unit -> 'a t (** Create a new deque. *) -val push : 'a t -> 'a -> unit -(** Push value at the bottom of deque. This is not thread-safe. *) +val push : 'a t -> 'a -> bool +(** Push value at the bottom of deque. returns [true] if it succeeds. + This must be called only by the owner thread. *) val pop : 'a t -> 'a option -(** Pop value from the bottom of deque. This is not thread-safe. *) +(** Pop value from the bottom of deque. + This must be called only by the owner thread. *) val steal : 'a t -> 'a option (** Try to steal from the top of deque. This is thread-safe. *) diff --git a/src/ws_pool.ml b/src/ws_pool.ml index 44432112..4623a3e3 100644 --- a/src/ws_pool.ml +++ b/src/ws_pool.ml @@ -60,8 +60,16 @@ let schedule_task_ (self : state) (w : worker_state option) (task : task) : unit (* Printf.printf "schedule task now (%d)\n%!" (Thread.id @@ Thread.self ()); *) match w with | Some w -> - WSQ.push w.q task; - try_wake_someone_ self + let pushed = WSQ.push w.q task in + if pushed then + try_wake_someone_ self + else ( + (* overflow into main queue *) + Mutex.lock self.mutex; + Queue.push task self.main_q; + if self.n_waiting_nonzero then Condition.signal self.cond; + Mutex.unlock self.mutex + ) | None -> if A.get self.active then ( (* push into the main queue *) @@ -202,6 +210,8 @@ type ('a, 'b) create_args = 'a (** Arguments used in {!create}. See {!create} for explanations. *) +let dummy_task_ () = assert false + let create ?(on_init_thread = default_thread_init_exit_) ?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ()) ?around_task ?num_threads () : t = @@ -221,7 +231,11 @@ let create ?(on_init_thread = default_thread_init_exit_) let workers : worker_state array = let dummy = Thread.self () in Array.init num_threads (fun i -> - { thread = dummy; q = WSQ.create (); rng = Random.State.make [| i |] }) + { + thread = dummy; + q = WSQ.create ~dummy:dummy_task_ (); + rng = Random.State.make [| i |]; + }) in let pool = diff --git a/test/t_ws_deque.ml b/test/t_ws_deque.ml index 3377dcb6..88429a8d 100644 --- a/test/t_ws_deque.ml +++ b/test/t_ws_deque.ml @@ -2,22 +2,23 @@ module A = Moonpool.Atomic module D = Moonpool.Private.Ws_deque_ let ( let@ ) = ( @@ ) +let dummy = -100 let t_simple () = - let d = D.create () in + let d = D.create ~dummy () in assert (D.steal d = None); assert (D.pop d = None); - D.push d 1; - D.push d 2; + assert (D.push d 1); + assert (D.push d 2); assert (D.pop d = Some 2); assert (D.steal d = Some 1); assert (D.steal d = None); assert (D.pop d = None); - D.push d 3; + assert (D.push d 3); assert (D.pop d = Some 3); - D.push d 4; - D.push d 5; - D.push d 6; + assert (D.push d 4); + assert (D.push d 5); + assert (D.push d 6); assert (D.steal d = Some 4); assert (D.steal d = Some 5); assert (D.pop d = Some 6); @@ -35,7 +36,7 @@ let t_heavy () = let active = A.make true in - let d = D.create () in + let d = D.create ~dummy () in let stealer_loop () = Trace.set_thread_name "stealer"; @@ -51,11 +52,13 @@ let t_heavy () = Trace.set_thread_name "producer"; for _i = 1 to 100_000 do let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "main.outer" in + + (* NOTE: we make sure to push less than 256 elements at once *) for j = 1 to 100 do ref_sum := !ref_sum + j; - D.push d j; + assert (D.push d j); ref_sum := !ref_sum + j; - D.push d j; + assert (D.push d j); Option.iter (fun x -> add_to_sum x) (D.pop d); Option.iter (fun x -> add_to_sum x) (D.pop d) @@ -92,35 +95,8 @@ let t_heavy () = assert (ref_sum = sum); () -let t_many () = - print_endline "pushing many elements"; - let d = D.create () in - - let push_and_pop count = - for i = 1 to count do - (* if i mod 100_000 = 0 then Printf.printf "push %d\n%!" i; *) - D.push d i - done; - let n = ref 0 in - - let continue = ref true in - while !continue do - match D.pop d with - | None -> continue := false - | Some _ -> incr n - done; - assert (!n = count) - in - - push_and_pop 10_000; - push_and_pop 100_000; - push_and_pop 1_000_000; - print_endline "pushing many elements: ok"; - () - let () = let@ () = Trace_tef.with_setup () in t_simple (); t_heavy (); - t_many (); ()