mirror of
https://github.com/c-cube/moonpool.git
synced 2026-05-05 08:54:24 -04:00
ws pool: use ws queue in as_runner (#46)
fix a bug where the work stealing queue wasn't used in the `Runner.t` implementation. close #45
This commit is contained in:
parent
18701bfde4
commit
0d0db75f26
4 changed files with 70 additions and 18 deletions
2
Makefile
2
Makefile
|
|
@ -7,7 +7,7 @@ clean:
|
||||||
@dune clean
|
@dune clean
|
||||||
|
|
||||||
test:
|
test:
|
||||||
@dune runtest $(DUNE_OPTS)
|
@dune runtest $(DUNE_OPTS) --no-buffer
|
||||||
|
|
||||||
test-autopromote:
|
test-autopromote:
|
||||||
@dune runtest $(DUNE_OPTS) --auto-promote
|
@dune runtest $(DUNE_OPTS) --auto-promote
|
||||||
|
|
|
||||||
|
|
@ -6,18 +6,7 @@ include Runner
|
||||||
|
|
||||||
let ( let@ ) = ( @@ )
|
let ( let@ ) = ( @@ )
|
||||||
|
|
||||||
module Id = struct
|
|
||||||
type t = unit ref
|
|
||||||
(** Unique identifier for a pool *)
|
|
||||||
|
|
||||||
let create () : t = Sys.opaque_identity (ref ())
|
|
||||||
let equal : t -> t -> bool = ( == )
|
|
||||||
end
|
|
||||||
|
|
||||||
type state = {
|
type state = {
|
||||||
id_: Id.t;
|
|
||||||
(** Unique to this pool. Used to make sure tasks stay within the same
|
|
||||||
pool. *)
|
|
||||||
active: bool A.t; (** Becomes [false] when the pool is shutdown. *)
|
active: bool A.t; (** Becomes [false] when the pool is shutdown. *)
|
||||||
mutable workers: worker_state array; (** Fixed set of workers. *)
|
mutable workers: worker_state array; (** Fixed set of workers. *)
|
||||||
main_q: WL.task_full Queue.t;
|
main_q: WL.task_full Queue.t;
|
||||||
|
|
@ -99,12 +88,15 @@ let schedule_in_main_queue (self : state) task : unit =
|
||||||
longer permitted *)
|
longer permitted *)
|
||||||
raise Shutdown
|
raise Shutdown
|
||||||
|
|
||||||
let schedule_from_w (self : worker_state) (task : WL.task_full) : unit =
|
let schedule_from_anywhere_ (st : state) (task : WL.task_full) : unit =
|
||||||
match get_current_worker_ () with
|
match get_current_worker_ () with
|
||||||
| Some w when Id.equal self.st.id_ w.st.id_ ->
|
| Some w when st == w.st ->
|
||||||
(* use worker from the same pool *)
|
(* use worker from the same pool *)
|
||||||
schedule_on_current_worker w task
|
schedule_on_current_worker w task
|
||||||
| _ -> schedule_in_main_queue self.st task
|
| _ -> schedule_in_main_queue st task
|
||||||
|
|
||||||
|
let schedule_from_w (w : worker_state) task : unit =
|
||||||
|
schedule_from_anywhere_ w.st task
|
||||||
|
|
||||||
exception Got_task of WL.task_full
|
exception Got_task of WL.task_full
|
||||||
|
|
||||||
|
|
@ -223,7 +215,8 @@ let as_runner_ (self : state) : t =
|
||||||
Runner.For_runner_implementors.create
|
Runner.For_runner_implementors.create
|
||||||
~shutdown:(fun ~wait () -> shutdown_ self ~wait)
|
~shutdown:(fun ~wait () -> shutdown_ self ~wait)
|
||||||
~run_async:(fun ~fiber f ->
|
~run_async:(fun ~fiber f ->
|
||||||
schedule_in_main_queue self @@ T_start { fiber; f })
|
let task = WL.T_start { fiber; f } in
|
||||||
|
schedule_from_anywhere_ self task)
|
||||||
~size:(fun () -> size_ self)
|
~size:(fun () -> size_ self)
|
||||||
~num_tasks:(fun () -> num_tasks_ self)
|
~num_tasks:(fun () -> num_tasks_ self)
|
||||||
()
|
()
|
||||||
|
|
@ -240,7 +233,6 @@ type ('a, 'b) create_args =
|
||||||
let create ?(on_init_thread = default_thread_init_exit_)
|
let create ?(on_init_thread = default_thread_init_exit_)
|
||||||
?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ())
|
?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ())
|
||||||
?num_threads ?name () : t =
|
?num_threads ?name () : t =
|
||||||
let pool_id_ = Id.create () in
|
|
||||||
let num_domains = Domain_pool_.max_number_of_domains () in
|
let num_domains = Domain_pool_.max_number_of_domains () in
|
||||||
let num_threads = Util_pool_.num_threads ?num_threads () in
|
let num_threads = Util_pool_.num_threads ?num_threads () in
|
||||||
|
|
||||||
|
|
@ -249,7 +241,6 @@ let create ?(on_init_thread = default_thread_init_exit_)
|
||||||
|
|
||||||
let pool =
|
let pool =
|
||||||
{
|
{
|
||||||
id_ = pool_id_;
|
|
||||||
active = A.make true;
|
active = A.make true;
|
||||||
workers = [||];
|
workers = [||];
|
||||||
main_q = Queue.create ();
|
main_q = Queue.create ();
|
||||||
|
|
|
||||||
|
|
@ -20,3 +20,10 @@
|
||||||
unix
|
unix
|
||||||
trace-tef
|
trace-tef
|
||||||
trace))
|
trace))
|
||||||
|
|
||||||
|
(test
|
||||||
|
(name t_fib_await_mem)
|
||||||
|
(package moonpool)
|
||||||
|
(enabled_if
|
||||||
|
(= %{system} linux))
|
||||||
|
(libraries moonpool))
|
||||||
|
|
|
||||||
54
test/t_fib_await_mem.ml
Normal file
54
test/t_fib_await_mem.ml
Normal file
|
|
@ -0,0 +1,54 @@
|
||||||
|
(* regression test for #45 *)
|
||||||
|
|
||||||
|
open Moonpool
|
||||||
|
|
||||||
|
let ( let@ ) = ( @@ )
|
||||||
|
|
||||||
|
let rec fib_direct x =
|
||||||
|
if x <= 1 then
|
||||||
|
1
|
||||||
|
else
|
||||||
|
fib_direct (x - 1) + fib_direct (x - 2)
|
||||||
|
|
||||||
|
let cutoff = 8
|
||||||
|
|
||||||
|
let rec fib_await ~on x : int Fut.t =
|
||||||
|
if x <= cutoff then
|
||||||
|
Fut.spawn ~on (fun () -> fib_direct x)
|
||||||
|
else
|
||||||
|
Fut.spawn ~on (fun () ->
|
||||||
|
let n1 = fib_await ~on (x - 1) in
|
||||||
|
let n2 = fib_await ~on (x - 2) in
|
||||||
|
let n1 = Fut.await n1 in
|
||||||
|
let n2 = Fut.await n2 in
|
||||||
|
n1 + n2)
|
||||||
|
|
||||||
|
(** Read VmHWM (peak RSS in kB) from /proc/self/status. *)
|
||||||
|
let get_vmhwm_kb () : int option =
|
||||||
|
let path = "/proc/self/status" in
|
||||||
|
match In_channel.with_open_bin path In_channel.input_all with
|
||||||
|
| exception Sys_error _ -> None
|
||||||
|
| content ->
|
||||||
|
let lines = String.split_on_char '\n' content in
|
||||||
|
List.find_map
|
||||||
|
(fun line -> Scanf.sscanf_opt line "VmHWM: %d kB" Fun.id)
|
||||||
|
lines
|
||||||
|
|
||||||
|
let max_rss_bytes = 150_000_000
|
||||||
|
|
||||||
|
let () =
|
||||||
|
let@ pool = Ws_pool.with_ ~num_threads:4 () in
|
||||||
|
let result = fib_await ~on:pool 40 |> Fut.wait_block_exn in
|
||||||
|
assert (result = 165580141);
|
||||||
|
match get_vmhwm_kb () with
|
||||||
|
| None ->
|
||||||
|
Printf.printf "fib 40 = %d (skip RSS check: no /proc/self/status)\n%!"
|
||||||
|
result
|
||||||
|
| Some hwm_kb ->
|
||||||
|
let hwm_bytes = hwm_kb * 1024 in
|
||||||
|
Printf.printf "fib 40 = %d, peak RSS = %d bytes\n%!" result hwm_bytes;
|
||||||
|
if hwm_bytes > max_rss_bytes then (
|
||||||
|
Printf.eprintf "FAIL: peak RSS %d bytes exceeds limit %d bytes\n%!"
|
||||||
|
hwm_bytes max_rss_bytes;
|
||||||
|
exit 1
|
||||||
|
)
|
||||||
Loading…
Add table
Reference in a new issue