diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index e54322e6..56eecb58 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -9,6 +9,7 @@ on: jobs: run: name: build + timeout-minutes: 10 strategy: fail-fast: true matrix: diff --git a/Makefile b/Makefile index a7308673..1d7ec227 100644 --- a/Makefile +++ b/Makefile @@ -22,22 +22,41 @@ watch: DUNE_OPTS_BENCH?=--profile=release N?=40 -NITER?=3 +NITER?=2 BENCH_PSIZE?=1,4,8,20 +BENCH_KIND?=fifo,pool BENCH_CUTOFF?=20 bench-fib: @echo running for N=$(N) dune build $(DUNE_OPTS_BENCH) benchs/fib_rec.exe - hyperfine -L psize $(BENCH_PSIZE) \ - './_build/default/benchs/fib_rec.exe -cutoff $(BENCH_CUTOFF) -niter $(NITER) -psize={psize} -n $(N)' + + hyperfine --warmup=1 \ + './_build/default/benchs/fib_rec.exe -n $(N) -cutoff $(BENCH_CUTOFF) -niter $(NITER) -seq' \ + './_build/default/benchs/fib_rec.exe -n $(N) -cutoff $(BENCH_CUTOFF) -niter $(NITER) -dl' \ + './_build/default/benchs/fib_rec.exe -n $(N) -cutoff $(BENCH_CUTOFF) -niter $(NITER) -psize=20 -kind=pool -fj' \ + './_build/default/benchs/fib_rec.exe -n $(N) -cutoff $(BENCH_CUTOFF) -niter $(NITER) -psize=20 -kind=pool -await' \ + './_build/default/benchs/fib_rec.exe -n $(N) -cutoff $(BENCH_CUTOFF) -niter $(NITER) -psize=4 -kind=fifo' \ + './_build/default/benchs/fib_rec.exe -n $(N) -cutoff $(BENCH_CUTOFF) -niter $(NITER) -psize=4 -kind=pool' \ + './_build/default/benchs/fib_rec.exe -n $(N) -cutoff $(BENCH_CUTOFF) -niter $(NITER) -psize=8 -kind=fifo' \ + './_build/default/benchs/fib_rec.exe -n $(N) -cutoff $(BENCH_CUTOFF) -niter $(NITER) -psize=16 -kind=pool' + + #hyperfine -L psize $(BENCH_PSIZE) -L kind $(BENCH_KIND) --warmup=1 \ + # './_build/default/benchs/fib_rec.exe -cutoff $(BENCH_CUTOFF) -niter $(NITER) -psize={psize} -kind={kind} -n $(N)' + #'./_build/default/benchs/fib_rec.exe -seq -cutoff $(BENCH_CUTOFF) -niter $(NITER) -n $(N)' \ + #'./_build/default/benchs/fib_rec.exe -dl -cutoff $(BENCH_CUTOFF) -niter $(NITER) -n $(N)' \ PI_NSTEPS?=100_000_000 PI_MODES?=seq,par1,forkjoin +PI_KIND?=fifo,pool bench-pi: @echo running for N=$(PI_NSTEPS) dune build $(DUNE_OPTS_BENCH) benchs/pi.exe - hyperfine -L mode $(PI_MODES) \ - './_build/default/benchs/pi.exe -mode={mode} -n $(PI_NSTEPS)' + hyperfine --warmup=1 \ + './_build/default/benchs/pi.exe -n $(PI_NSTEPS) -mode=seq' \ + './_build/default/benchs/pi.exe -n $(PI_NSTEPS) -j 8 -mode par1 -kind=pool' \ + './_build/default/benchs/pi.exe -n $(PI_NSTEPS) -j 8 -mode par1 -kind=fifo' \ + './_build/default/benchs/pi.exe -n $(PI_NSTEPS) -j 16 -mode forkjoin -kind=pool' \ + './_build/default/benchs/pi.exe -n $(PI_NSTEPS) -j 20 -mode forkjoin -kind=pool' .PHONY: test clean bench-fib bench-pi diff --git a/README.md b/README.md index 60f478d3..5135d00b 100644 --- a/README.md +++ b/README.md @@ -24,22 +24,31 @@ In addition, some concurrency and parallelism primitives are provided: ## Usage -The user can create several thread pools. These pools use regular posix threads, -but the threads are spread across multiple domains (on OCaml 5), which enables -parallelism. +The user can create several thread pools (implementing the interface `Runner.t`). +These pools use regular posix threads, but the threads are spread across +multiple domains (on OCaml 5), which enables parallelism. -The function `Pool.run_async pool task` runs `task()` on one of the workers -of `pool`, as soon as one is available. No result is returned. +Current we provide these pool implementations: +- `Fifo_pool` is a thread pool that uses a blocking queue to schedule tasks, + which means they're picked in the same order they've been scheduled ("fifo"). + This pool is simple and will behave fine for coarse-granularity concurrency, + but will slow down under heavy contention. +- `Ws_pool` is a work-stealing pool, where each thread has its own local queue + in addition to a global queue of tasks. This is efficient for workloads + with many short tasks that spawn other tasks, but the order in which + tasks are run is less predictable. This is useful when throughput is + the important thing to optimize. + +The function `Runner.run_async pool task` schedules `task()` to run on one of +the workers of `pool`, as soon as one is available. No result is returned by `run_async`. ```ocaml # #require "threads";; -# let pool = Moonpool.Pool.create ~min:4 ();; -val pool : Moonpool.Runner.t = - {Moonpool.Pool.run_async = ; shutdown = ; size = ; - num_tasks = } +# let pool = Moonpool.Fifo_pool.create ~num_threads:4 ();; +val pool : Moonpool.Runner.t = # begin - Moonpool.Pool.run_async pool + Moonpool.Runner.run_async pool (fun () -> Thread.delay 0.1; print_endline "running from the pool"); @@ -51,11 +60,13 @@ running from the pool - : unit = () ``` -To wait until the task is done, you can use `Pool.run_wait_block` instead: +To wait until the task is done, you can use `Runner.run_wait_block`[^1] instead: + +[^1]: beware of deadlock! See documentation for more details. ```ocaml # begin - Moonpool.Pool.run_wait_block pool + Moonpool.Runner.run_wait_block pool (fun () -> Thread.delay 0.1; print_endline "running from the pool"); @@ -157,7 +168,11 @@ val expected_sum : int = 5050 On OCaml 5, again using effect handlers, the module `Fork_join` implements the [fork-join model](https://en.wikipedia.org/wiki/Fork%E2%80%93join_model). -It must run on a pool (using [Pool.run] or inside a future via [Future.spawn]). +It must run on a pool (using [Runner.run_async] or inside a future via [Fut.spawn]). + +It is generally better to use the work-stealing pool for workloads that rely on +fork-join for better performance, because fork-join will tend to spawn lots of +shorter tasks. ```ocaml # let rec select_sort arr i len = @@ -259,7 +274,7 @@ This works for OCaml >= 4.08. the same pool, too — this is useful for threads blocking on IO). A useful analogy is that each domain is a bit like a CPU core, and `Thread.t` is a logical thread running on a core. - Multiple threads have to share a single core and do not run in parallel on it[^1]. + Multiple threads have to share a single core and do not run in parallel on it[^2]. We can therefore build pools that spread their worker threads on multiple cores to enable parallelism within each pool. TODO: actually use https://github.com/haesbaert/ocaml-processor to pin domains to cores, @@ -275,4 +290,4 @@ MIT license. $ opam install moonpool ``` -[^1]: let's not talk about hyperthreading. +[^2]: let's not talk about hyperthreading. diff --git a/bench_fib.sh b/bench_fib.sh new file mode 100755 index 00000000..e9996d53 --- /dev/null +++ b/bench_fib.sh @@ -0,0 +1,3 @@ +#!/bin/sh +OPTS="--profile=release --display=quiet" +exec dune exec $OPTS -- benchs/fib_rec.exe $@ diff --git a/benchs/dune b/benchs/dune index 2c798176..ff0f878b 100644 --- a/benchs/dune +++ b/benchs/dune @@ -3,4 +3,4 @@ (names fib_rec pi) (preprocess (action (run %{project_root}/src/cpp/cpp.exe %{input-file}))) - (libraries moonpool unix)) + (libraries moonpool unix trace trace-tef domainslib)) diff --git a/benchs/fib_rec.ml b/benchs/fib_rec.ml index 4ff984f4..b0f1623d 100644 --- a/benchs/fib_rec.ml +++ b/benchs/fib_rec.ml @@ -12,42 +12,127 @@ let rec fib ~on x : int Fut.t = if x <= !cutoff then Fut.spawn ~on (fun () -> fib_direct x) else - let open Fut.Infix_local in + let open Fut.Infix in let+ t1 = fib ~on (x - 1) and+ t2 = fib ~on (x - 2) in t1 + t2 +let fib_fj ~on x : int Fut.t = + let rec fib_rec x : int = + if x <= !cutoff then + fib_direct x + else ( + let n1, n2 = + Fork_join.both (fun () -> fib_rec (x - 1)) (fun () -> fib_rec (x - 2)) + in + n1 + n2 + ) + in + Fut.spawn ~on (fun () -> fib_rec x) + +let fib_await ~on x : int Fut.t = + let rec fib_rec x : int Fut.t = + if x <= !cutoff then + Fut.spawn ~on (fun () -> fib_direct x) + else + Fut.spawn ~on (fun () -> + let n1 = fib_rec (x - 1) in + let n2 = fib_rec (x - 2) in + let n1 = Fut.await n1 in + let n2 = Fut.await n2 in + n1 + n2) + in + fib_rec x + +let rec fib_dl ~pool x : int Domainslib.Task.promise = + if x <= !cutoff then + Domainslib.Task.async pool (fun () -> fib_direct x) + else + Domainslib.Task.async pool (fun () -> + let t1 = fib_dl ~pool (x - 1) and t2 = fib_dl ~pool (x - 2) in + let t1 = Domainslib.Task.await pool t1 in + let t2 = Domainslib.Task.await pool t2 in + t1 + t2) + let () = assert (List.init 10 fib_direct = [ 1; 1; 2; 3; 5; 8; 13; 21; 34; 55 ]) -let run ~psize ~n ~seq ~niter () : unit = - let pool = lazy (Pool.create ~min:psize ()) in +let create_pool ~psize ~kind () = + match kind with + | "fifo" -> Fifo_pool.create ?num_threads:psize () + | "pool" -> Ws_pool.create ?num_threads:psize () + | _ -> assert false + +let str_of_int_opt = function + | None -> "None" + | Some i -> Printf.sprintf "Some %d" i + +let run ~psize ~n ~seq ~dl ~fj ~await ~niter ~kind () : unit = + let pool = lazy (create_pool ~kind ~psize ()) in + let dl_pool = + lazy + (let n = Domain.recommended_domain_count () in + Printf.printf "use %d domains\n%!" n; + Domainslib.Task.setup_pool ~num_domains:n ()) + in for _i = 1 to niter do let res = if seq then ( Printf.printf "compute fib %d sequentially\n%!" n; fib_direct n + ) else if dl then ( + Printf.printf "compute fib %d with domainslib\n%!" n; + let (lazy pool) = dl_pool in + Domainslib.Task.run pool (fun () -> + Domainslib.Task.await pool @@ fib_dl ~pool n) + ) else if fj then ( + Printf.printf "compute fib %d using fork-join with pool size=%s\n%!" n + (str_of_int_opt psize); + fib_fj ~on:(Lazy.force pool) n |> Fut.wait_block_exn + ) else if await then ( + Printf.printf "compute fib %d using await with pool size=%s\n%!" n + (str_of_int_opt psize); + fib_await ~on:(Lazy.force pool) n |> Fut.wait_block_exn ) else ( - Printf.printf "compute fib %d with pool size=%d\n%!" n psize; + Printf.printf "compute fib %d with pool size=%s\n%!" n + (str_of_int_opt psize); fib ~on:(Lazy.force pool) n |> Fut.wait_block_exn ) in Printf.printf "fib %d = %d\n%!" n res - done + done; + + if seq then + () + else if dl then + Domainslib.Task.teardown_pool (Lazy.force dl_pool) + else + Ws_pool.shutdown (Lazy.force pool) let () = let n = ref 40 in - let psize = ref 16 in + let psize = ref None in let seq = ref false in let niter = ref 3 in + let kind = ref "pool" in + let dl = ref false in + let await = ref false in + let fj = ref false in let opts = [ - "-psize", Arg.Set_int psize, " pool size"; + "-psize", Arg.Int (fun i -> psize := Some i), " pool size"; "-n", Arg.Set_int n, " fib "; "-seq", Arg.Set seq, " sequential"; + "-dl", Arg.Set dl, " domainslib"; + "-fj", Arg.Set fj, " fork join"; "-niter", Arg.Set_int niter, " number of iterations"; + "-await", Arg.Set await, " use await"; "-cutoff", Arg.Set_int cutoff, " cutoff for sequential computation"; + ( "-kind", + Arg.Symbol ([ "pool"; "fifo" ], ( := ) kind), + " pick pool implementation" ); ] |> Arg.align in Arg.parse opts ignore ""; - run ~psize:!psize ~n:!n ~seq:!seq ~niter:!niter () + run ~psize:!psize ~n:!n ~fj:!fj ~seq:!seq ~await:!await ~dl:!dl ~niter:!niter + ~kind:!kind () diff --git a/benchs/pi.ml b/benchs/pi.ml index 1dd55fb9..c8ef57b5 100644 --- a/benchs/pi.ml +++ b/benchs/pi.ml @@ -17,17 +17,25 @@ let run_sequential (num_steps : int) : float = pi (** Create a pool *) -let with_pool f = - if !j = 0 then - Pool.with_ ~per_domain:1 f - else - Pool.with_ ~min:!j f +let with_pool ~kind f = + match kind with + | "pool" -> + if !j = 0 then + Ws_pool.with_ f + else + Ws_pool.with_ ~num_threads:!j f + | "fifo" -> + if !j = 0 then + Fifo_pool.with_ f + else + Fifo_pool.with_ ~num_threads:!j f + | _ -> assert false (** Run in parallel using {!Fut.for_} *) -let run_par1 (num_steps : int) : float = - let@ pool = with_pool () in +let run_par1 ~kind (num_steps : int) : float = + let@ pool = with_pool ~kind () in - let num_tasks = Pool.size pool in + let num_tasks = Ws_pool.size pool in let step = 1. /. float num_steps in let global_sum = Lock.create 0. in @@ -53,15 +61,15 @@ let run_par1 (num_steps : int) : float = [@@@ifge 5.0] -let run_fork_join num_steps : float = - let@ pool = with_pool () in +let run_fork_join ~kind num_steps : float = + let@ pool = with_pool ~kind () in - let num_tasks = Pool.size pool in + let num_tasks = Ws_pool.size pool in let step = 1. /. float num_steps in let global_sum = Lock.create 0. in - Pool.run_wait_block pool (fun () -> + Ws_pool.run_wait_block pool (fun () -> Fork_join.for_ ~chunk_size:(3 + (num_steps / num_tasks)) num_steps @@ -90,9 +98,11 @@ type mode = | Fork_join let () = + let@ () = Trace_tef.with_setup () in let mode = ref Sequential in let n = ref 1000 in let time = ref false in + let kind = ref "pool" in let set_mode = function | "seq" -> mode := Sequential @@ -109,6 +119,9 @@ let () = " mode of execution" ); "-j", Arg.Set_int j, " number of threads"; "-t", Arg.Set time, " printing timing"; + ( "-kind", + Arg.Symbol ([ "pool"; "fifo" ], ( := ) kind), + " pick pool implementation" ); ] |> Arg.align in @@ -118,8 +131,8 @@ let () = let res = match !mode with | Sequential -> run_sequential !n - | Par1 -> run_par1 !n - | Fork_join -> run_fork_join !n + | Par1 -> run_par1 ~kind:!kind !n + | Fork_join -> run_fork_join ~kind:!kind !n in let elapsed : float = Unix.gettimeofday () -. t_start in diff --git a/dune-project b/dune-project index 3cd1b85e..f6ace0a6 100644 --- a/dune-project +++ b/dune-project @@ -20,6 +20,7 @@ dune (either (>= 1.0)) (trace :with-test) + (trace-tef :with-test) (qcheck-core (and :with-test (>= 0.19))) (odoc :with-doc) (mdx @@ -27,8 +28,9 @@ (>= 1.9.0) :with-test))) (depopts + thread-local-storage (domain-local-await (>= 0.2))) (tags - (thread pool domain))) + (thread pool domain futures fork-join))) ; See the complete stanza docs at https://dune.readthedocs.io/en/stable/dune-files.html#dune-project diff --git a/moonpool.opam b/moonpool.opam index 3f411cfa..62bdcf6e 100644 --- a/moonpool.opam +++ b/moonpool.opam @@ -5,7 +5,7 @@ synopsis: "Pools of threads supported by a pool of domains" maintainer: ["Simon Cruanes"] authors: ["Simon Cruanes"] license: "MIT" -tags: ["thread" "pool" "domain"] +tags: ["thread" "pool" "domain" "futures" "fork-join"] homepage: "https://github.com/c-cube/moonpool" bug-reports: "https://github.com/c-cube/moonpool/issues" depends: [ @@ -13,11 +13,13 @@ depends: [ "dune" {>= "3.0"} "either" {>= "1.0"} "trace" {with-test} + "trace-tef" {with-test} "qcheck-core" {with-test & >= "0.19"} "odoc" {with-doc} "mdx" {>= "1.9.0" & with-test} ] depopts: [ + "thread-local-storage" "domain-local-await" {>= "0.2"} ] build: [ diff --git a/src/d_pool_.ml b/src/d_pool_.ml index fb78535b..d12a4f6a 100644 --- a/src/d_pool_.ml +++ b/src/d_pool_.ml @@ -18,9 +18,7 @@ type worker_state = { including a work queue and a thread refcount; and the domain itself, if any, in a separate option because it might outlive its own state. *) let domains_ : (worker_state option * Domain_.t option) Lock.t array = - (* number of domains we spawn. Note that we spawn n-1 domains - because there already is the main domain running. *) - let n = max 1 (Domain_.recommended_number () - 1) in + let n = max 1 (Domain_.recommended_number ()) in Array.init n (fun _ -> Lock.create (None, None)) (** main work loop for a domain worker. @@ -84,6 +82,14 @@ let work_ idx (st : worker_state) : unit = done; () +(* special case for main domain: we start a worker immediately *) +let () = + assert (Domain_.is_main_domain ()); + let w = { th_count = Atomic_.make 1; q = Bb_queue.create () } in + (* thread that stays alive *) + ignore (Thread.create (fun () -> work_ 0 w) () : Thread.t); + domains_.(0) <- Lock.create (Some w, None) + let[@inline] n_domains () : int = Array.length domains_ let run_on (i : int) (f : unit -> unit) : unit = diff --git a/src/domain_.ml b/src/domain_.ml index 60d1e669..3050282f 100644 --- a/src/domain_.ml +++ b/src/domain_.ml @@ -9,6 +9,7 @@ let get_id (self : t) : int = (Domain.get_id self :> int) let spawn : _ -> t = Domain.spawn let relax = Domain.cpu_relax let join = Domain.join +let is_main_domain = Domain.is_main_domain [@@@ocaml.alert "+unstable"] [@@@else_] @@ -21,5 +22,6 @@ let get_id (self : t) : int = Thread.id self let spawn f : t = Thread.create f () let relax () = Thread.yield () let join = Thread.join +let is_main_domain () = true [@@@endif] diff --git a/src/dune b/src/dune index d65920e8..59005b54 100644 --- a/src/dune +++ b/src/dune @@ -1,11 +1,14 @@ (library (public_name moonpool) (name moonpool) - (private_modules d_pool_) + (private_modules d_pool_ dla_) (preprocess (action (run %{project_root}/src/cpp/cpp.exe %{input-file}))) (libraries threads either + (select thread_local_storage_.ml from + (thread-local-storage -> thread_local_storage_.stub.ml) + (-> thread_local_storage_.real.ml)) (select dla_.ml from (domain-local-await -> dla_.real.ml) ( -> dla_.dummy.ml)))) diff --git a/src/fifo_pool.ml b/src/fifo_pool.ml new file mode 100644 index 00000000..c4cc59ac --- /dev/null +++ b/src/fifo_pool.ml @@ -0,0 +1,150 @@ +module TLS = Thread_local_storage_ +include Runner + +let ( let@ ) = ( @@ ) + +type state = { + threads: Thread.t array; + q: task Bb_queue.t; (** Queue for tasks. *) +} +(** internal state *) + +let[@inline] size_ (self : state) = Array.length self.threads +let[@inline] num_tasks_ (self : state) : int = Bb_queue.size self.q + +(** Run [task] as is, on the pool. *) +let schedule_ (self : state) (task : task) : unit = + try Bb_queue.push self.q task with Bb_queue.Closed -> raise Shutdown + +type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task + +let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit = + TLS.get Runner.For_runner_implementors.k_cur_runner := Some runner; + let (AT_pair (before_task, after_task)) = around_task in + + let run_task task : unit = + let _ctx = before_task runner in + (* run the task now, catching errors *) + (try Suspend_.with_suspend task ~run:(fun task' -> schedule_ self task') + with e -> + let bt = Printexc.get_raw_backtrace () in + on_exn e bt); + after_task runner _ctx + in + + let main_loop () = + let continue = ref true in + while !continue do + match Bb_queue.pop self.q with + | task -> run_task task + | exception Bb_queue.Closed -> continue := false + done + in + + try + (* handle domain-local await *) + Dla_.using ~prepare_for_await:Suspend_.prepare_for_await + ~while_running:main_loop + with Bb_queue.Closed -> () + +let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = () + +let shutdown_ ~wait (self : state) : unit = + Bb_queue.close self.q; + if wait then Array.iter Thread.join self.threads + +type ('a, 'b) create_args = + ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) -> + ?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) -> + ?on_exn:(exn -> Printexc.raw_backtrace -> unit) -> + ?around_task:(t -> 'b) * (t -> 'b -> unit) -> + ?num_threads:int -> + 'a + +let create ?(on_init_thread = default_thread_init_exit_) + ?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ()) + ?around_task ?num_threads () : t = + (* wrapper *) + let around_task = + match around_task with + | Some (f, g) -> AT_pair (f, g) + | None -> AT_pair (ignore, fun _ _ -> ()) + in + + let num_domains = D_pool_.n_domains () in + + (* number of threads to run *) + let num_threads = Util_pool_.num_threads ?num_threads () in + + (* make sure we don't bias towards the first domain(s) in {!D_pool_} *) + let offset = Random.int num_domains in + + let pool = + let dummy = Thread.self () in + { threads = Array.make num_threads dummy; q = Bb_queue.create () } + in + + let runner = + Runner.For_runner_implementors.create + ~shutdown:(fun ~wait () -> shutdown_ pool ~wait) + ~run_async:(fun f -> schedule_ pool f) + ~size:(fun () -> size_ pool) + ~num_tasks:(fun () -> num_tasks_ pool) + () + in + + (* temporary queue used to obtain thread handles from domains + on which the thread are started. *) + let receive_threads = Bb_queue.create () in + + (* start the thread with index [i] *) + let start_thread_with_idx i = + let dom_idx = (offset + i) mod num_domains in + + (* function run in the thread itself *) + let main_thread_fun () : unit = + let thread = Thread.self () in + let t_id = Thread.id thread in + on_init_thread ~dom_id:dom_idx ~t_id (); + + let run () = worker_thread_ pool runner ~on_exn ~around_task in + + (* now run the main loop *) + Fun.protect run ~finally:(fun () -> + (* on termination, decrease refcount of underlying domain *) + D_pool_.decr_on dom_idx); + on_exit_thread ~dom_id:dom_idx ~t_id () + in + + (* function called in domain with index [i], to + create the thread and push it into [receive_threads] *) + let create_thread_in_domain () = + let thread = Thread.create main_thread_fun () in + (* send the thread from the domain back to us *) + Bb_queue.push receive_threads (i, thread) + in + + D_pool_.run_on dom_idx create_thread_in_domain + in + + (* start all threads, placing them on the domains + according to their index and [offset] in a round-robin fashion. *) + for i = 0 to num_threads - 1 do + start_thread_with_idx i + done; + + (* receive the newly created threads back from domains *) + for _j = 1 to num_threads do + let i, th = Bb_queue.pop receive_threads in + pool.threads.(i) <- th + done; + + runner + +let with_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads () f + = + let pool = + create ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads () + in + let@ () = Fun.protect ~finally:(fun () -> shutdown pool) in + f pool diff --git a/src/fifo_pool.mli b/src/fifo_pool.mli new file mode 100644 index 00000000..4371db58 --- /dev/null +++ b/src/fifo_pool.mli @@ -0,0 +1,44 @@ +(** A simple thread pool in FIFO order. + + FIFO: first-in, first-out. Basically tasks are put into a queue, + and worker threads pull them out of the queue at the other end. + + Since this uses a single blocking queue to manage tasks, it's very + simple and reliable. The number of worker threads is fixed, but + they are spread over several domains to enable parallelism. + + This can be useful for latency-sensitive applications (e.g. as a + pool of workers for network servers). Work-stealing pools might + have higher throughput but they're very unfair to some tasks; by + contrast, here, older tasks have priority over younger tasks. + + @since NEXT_RELEASE *) + +include module type of Runner + +type ('a, 'b) create_args = + ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) -> + ?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) -> + ?on_exn:(exn -> Printexc.raw_backtrace -> unit) -> + ?around_task:(t -> 'b) * (t -> 'b -> unit) -> + ?num_threads:int -> + 'a +(** Arguments used in {!create}. See {!create} for explanations. *) + +val create : (unit -> t, _) create_args +(** [create ()] makes a new thread pool. + @param on_init_thread called at the beginning of each new thread in the pool. + @param min minimum size of the pool. See {!Pool.create_args}. + The default is [Domain.recommended_domain_count()], ie one worker per + CPU core. + On OCaml 4 the default is [4] (since there is only one domain). + @param on_exit_thread called at the end of each worker thread in the pool. + @param around_task a pair of [before, after] functions + ran around each task. See {!Pool.create_args}. + *) + +val with_ : (unit -> (t -> 'a) -> 'a, _) create_args +(** [with_ () f] calls [f pool], where [pool] is obtained via {!create}. + When [f pool] returns or fails, [pool] is shutdown and its resources + are released. + Most parameters are the same as in {!create}. *) diff --git a/src/fork_join.ml b/src/fork_join.ml index f1733514..8ad61cec 100644 --- a/src/fork_join.ml +++ b/src/fork_join.ml @@ -3,91 +3,100 @@ module A = Atomic_ module State_ = struct - type 'a single_res = - | St_none - | St_some of 'a - | St_fail of exn * Printexc.raw_backtrace + type error = exn * Printexc.raw_backtrace + type 'a or_error = ('a, error) result - type ('a, 'b) t = { - mutable suspension: - ((unit, exn * Printexc.raw_backtrace) result -> unit) option; - (** suspended caller *) - left: 'a single_res; - right: 'b single_res; - } + type ('a, 'b) t = + | Init + | Left_solved of 'a or_error + | Right_solved of 'b or_error * Suspend_.suspension + | Both_solved of 'a or_error * 'b or_error - let get_exn (self : _ t A.t) = + let get_exn_ (self : _ t A.t) = match A.get self with - | { left = St_fail (e, bt); _ } | { right = St_fail (e, bt); _ } -> - Printexc.raise_with_backtrace e bt - | { left = St_some x; right = St_some y; _ } -> x, y + | Both_solved (Ok a, Ok b) -> a, b + | Both_solved (Error (exn, bt), _) | Both_solved (_, Error (exn, bt)) -> + Printexc.raise_with_backtrace exn bt | _ -> assert false - let check_if_state_complete_ (self : _ t) : unit = - match self.left, self.right, self.suspension with - | St_some _, St_some _, Some f -> f (Ok ()) - | St_fail (e, bt), _, Some f | _, St_fail (e, bt), Some f -> - f (Error (e, bt)) - | _ -> () - - let set_left_ (self : _ t A.t) (x : _ single_res) = - while - let old_st = A.get self in - let new_st = { old_st with left = x } in - if A.compare_and_set self old_st new_st then ( - check_if_state_complete_ new_st; - false + let rec set_left_ (self : _ t A.t) (left : _ or_error) = + let old_st = A.get self in + match old_st with + | Init -> + let new_st = Left_solved left in + if not (A.compare_and_set self old_st new_st) then ( + Domain_.relax (); + set_left_ self left + ) + | Right_solved (right, cont) -> + let new_st = Both_solved (left, right) in + if not (A.compare_and_set self old_st new_st) then ( + Domain_.relax (); + set_left_ self left ) else - true - do - Domain_.relax () - done + cont (Ok ()) + | Left_solved _ | Both_solved _ -> assert false - let set_right_ (self : _ t A.t) (y : _ single_res) = - while - let old_st = A.get self in - let new_st = { old_st with right = y } in - if A.compare_and_set self old_st new_st then ( - check_if_state_complete_ new_st; - false - ) else - true - do - Domain_.relax () - done + let rec set_right_ (self : _ t A.t) (right : _ or_error) : unit = + let old_st = A.get self in + match old_st with + | Left_solved left -> + let new_st = Both_solved (left, right) in + if not (A.compare_and_set self old_st new_st) then set_right_ self right + | Init -> + (* we are first arrived, we suspend until the left computation is done *) + Suspend_.suspend + { + Suspend_.handle = + (fun ~run:_ suspension -> + while + let old_st = A.get self in + match old_st with + | Init -> + not + (A.compare_and_set self old_st + (Right_solved (right, suspension))) + | Left_solved left -> + (* other thread is done, no risk of race condition *) + A.set self (Both_solved (left, right)); + suspension (Ok ()); + false + | Right_solved _ | Both_solved _ -> assert false + do + () + done); + } + | Right_solved _ | Both_solved _ -> assert false end let both f g : _ * _ = - let open State_ in - let st = A.make { suspension = None; left = St_none; right = St_none } in + let module ST = State_ in + let st = A.make ST.Init in - let start_tasks ~run () : unit = - run ~with_handler:true (fun () -> - try - let res = f () in - set_left_ st (St_some res) - with e -> - let bt = Printexc.get_raw_backtrace () in - set_left_ st (St_fail (e, bt))); - - run ~with_handler:true (fun () -> - try - let res = g () in - set_right_ st (St_some res) - with e -> - let bt = Printexc.get_raw_backtrace () in - set_right_ st (St_fail (e, bt))) + let runner = + match Runner.get_current_runner () with + | None -> invalid_arg "Fork_join.both must be run from within a runner" + | Some r -> r in - Suspend_.suspend - { - Suspend_.handle = - (fun ~run suspension -> - (* nothing else is started, no race condition possible *) - (A.get st).suspension <- Some suspension; - start_tasks ~run ()); - }; - get_exn st + (* start computing [f] in the background *) + Runner.run_async runner (fun () -> + try + let res = f () in + ST.set_left_ st (Ok res) + with exn -> + let bt = Printexc.get_raw_backtrace () in + ST.set_left_ st (Error (exn, bt))); + + let res_right = + try Ok (g ()) + with exn -> + let bt = Printexc.get_raw_backtrace () in + Error (exn, bt) + in + + ST.set_right_ st res_right; + ST.get_exn_ st let both_ignore f g = ignore (both f g : _ * _) @@ -126,7 +135,7 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit = let len_range = min chunk_size (n - offset) in assert (offset + len_range <= n); - run ~with_handler:true (fun () -> task_for ~offset ~len_range); + run (fun () -> task_for ~offset ~len_range); i := !i + len_range done in diff --git a/src/fut.ml b/src/fut.ml index 42767b61..0661ffa2 100644 --- a/src/fut.ml +++ b/src/fut.ml @@ -97,9 +97,14 @@ let spawn ~on f : _ t = fulfill promise res in - Pool.run_async on task; + Runner.run_async on task; fut +let spawn_on_current_runner f : _ t = + match Runner.get_current_runner () with + | None -> failwith "Fut.spawn_on_current_runner: not running on a runner" + | Some on -> spawn ~on f + let reify_error (f : 'a t) : 'a or_error t = match peek f with | Some res -> return res @@ -108,8 +113,13 @@ let reify_error (f : 'a t) : 'a or_error t = on_result f (fun r -> fulfill promise (Ok r)); fut +let get_runner_ ?on () : Runner.t option = + match on with + | Some _ as r -> r + | None -> Runner.get_current_runner () + let map ?on ~f fut : _ t = - let map_res r = + let map_immediate_ r : _ result = match r with | Ok x -> (try Ok (f x) @@ -119,20 +129,32 @@ let map ?on ~f fut : _ t = | Error e_bt -> Error e_bt in + match peek fut, get_runner_ ?on () with + | Some res, None -> of_result @@ map_immediate_ res + | Some res, Some runner -> + let fut2, promise = make () in + Runner.run_async runner (fun () -> fulfill promise @@ map_immediate_ res); + fut2 + | None, None -> + let fut2, promise = make () in + on_result fut (fun res -> fulfill promise @@ map_immediate_ res); + fut2 + | None, Some runner -> + let fut2, promise = make () in + on_result fut (fun res -> + Runner.run_async runner (fun () -> + fulfill promise @@ map_immediate_ res)); + fut2 + +let join (fut : 'a t t) : 'a t = match peek fut with - | Some r -> of_result (map_res r) + | Some (Ok f) -> f + | Some (Error (e, bt)) -> fail e bt | None -> let fut2, promise = make () in - on_result fut (fun r -> - let map_and_fulfill () = - let res = map_res r in - fulfill promise res - in - - match on with - | None -> map_and_fulfill () - | Some on -> Pool.run_async on map_and_fulfill); - + on_result fut (function + | Ok sub_fut -> on_result sub_fut (fulfill promise) + | Error _ as e -> fulfill promise e); fut2 let bind ?on ~f fut : _ t = @@ -146,33 +168,31 @@ let bind ?on ~f fut : _ t = | Error (e, bt) -> fail e bt in - let bind_and_fulfill r promise () = + let bind_and_fulfill (r : _ result) promise () : unit = let f_res_fut = apply_f_to_res r in (* forward result *) on_result f_res_fut (fun r -> fulfill promise r) in - match peek fut with - | Some r -> - (match on with - | None -> apply_f_to_res r - | Some on -> - let fut2, promise = make () in - Pool.run_async on (bind_and_fulfill r promise); - fut2) - | None -> + match peek fut, get_runner_ ?on () with + | Some res, Some runner -> + let fut2, promise = make () in + Runner.run_async runner (bind_and_fulfill res promise); + fut2 + | Some res, None -> apply_f_to_res res + | None, Some runner -> let fut2, promise = make () in on_result fut (fun r -> - match on with - | None -> bind_and_fulfill r promise () - | Some on -> Pool.run_async on (bind_and_fulfill r promise)); - + Runner.run_async runner (bind_and_fulfill r promise)); + fut2 + | None, None -> + let fut2, promise = make () in + on_result fut (fun res -> bind_and_fulfill res promise ()); fut2 -let bind_reify_error ?on ~f fut : _ t = bind ?on ~f (reify_error fut) -let join ?on fut = bind ?on fut ~f:(fun x -> x) +let[@inline] bind_reify_error ?on ~f fut : _ t = bind ?on ~f (reify_error fut) -let update_ (st : 'a A.t) f : 'a = +let update_atomic_ (st : 'a A.t) f : 'a = let rec loop () = let x = A.get st in let y = f x in @@ -197,7 +217,7 @@ let both a b : _ t = | Error err -> fulfill_idempotent promise (Error err) | Ok x -> (match - update_ st (function + update_atomic_ st (function | `Neither -> `Left x | `Right y -> `Both (x, y) | _ -> assert false) @@ -208,7 +228,7 @@ let both a b : _ t = | Error err -> fulfill_idempotent promise (Error err) | Ok y -> (match - update_ st (function + update_atomic_ st (function | `Left x -> `Both (x, y) | `Neither -> `Right y | _ -> assert false) @@ -381,9 +401,7 @@ let await (fut : 'a t) : 'a = Suspend_.handle = (fun ~run k -> on_result fut (function - | Ok _ -> - (* run without handler, we're already in a deep effect *) - run ~with_handler:false (fun () -> k (Ok ())) + | Ok _ -> run (fun () -> k (Ok ())) | Error (exn, bt) -> (* fail continuation immediately *) k (Error (exn, bt)))); @@ -393,41 +411,14 @@ let await (fut : 'a t) : 'a = [@@@endif] -module type INFIX = sig - val ( >|= ) : 'a t -> ('a -> 'b) -> 'b t - val ( >>= ) : 'a t -> ('a -> 'b t) -> 'b t - val ( let+ ) : 'a t -> ('a -> 'b) -> 'b t - val ( and+ ) : 'a t -> 'b t -> ('a * 'b) t - val ( let* ) : 'a t -> ('a -> 'b t) -> 'b t - val ( and* ) : 'a t -> 'b t -> ('a * 'b) t -end - -module Infix_ (X : sig - val pool : Pool.t option -end) : INFIX = struct - let[@inline] ( >|= ) x f = map ?on:X.pool ~f x - let[@inline] ( >>= ) x f = bind ?on:X.pool ~f x +module Infix = struct + let[@inline] ( >|= ) x f = map ~f x + let[@inline] ( >>= ) x f = bind ~f x let ( let+ ) = ( >|= ) let ( let* ) = ( >>= ) let ( and+ ) = both let ( and* ) = both end -module Infix_local = Infix_ (struct - let pool = None -end) - -include Infix_local - -module Infix (X : sig - val pool : Pool.t -end) = -Infix_ (struct - let pool = Some X.pool -end) - -let[@inline] infix pool : (module INFIX) = - let module M = Infix (struct - let pool = pool - end) in - (module M) +include Infix +module Infix_local = Infix [@@deprecated "use Infix"] diff --git a/src/fut.mli b/src/fut.mli index 944a9525..aa4515f5 100644 --- a/src/fut.mli +++ b/src/fut.mli @@ -85,6 +85,15 @@ val spawn : on:Runner.t -> (unit -> 'a) -> 'a t (** [spaw ~on f] runs [f()] on the given runner [on], and return a future that will hold its result. *) +val spawn_on_current_runner : (unit -> 'a) -> 'a t +(** This must be run from inside a runner, and schedules + the new task on it as well. + + See {!Runner.get_current_runner} to see how the runner is found. + + @since NEXT_RELEASE + @raise Failure if run from outside a runner. *) + val reify_error : 'a t -> 'a or_error t (** [reify_error fut] turns a failing future into a non-failing one that contain [Error (exn, bt)]. A non-failing future @@ -111,7 +120,7 @@ val bind_reify_error : ?on:Runner.t -> f:('a or_error -> 'b t) -> 'a t -> 'b t @param on if provided, [f] runs on the given runner @since 0.4 *) -val join : ?on:Runner.t -> 'a t t -> 'a t +val join : 'a t t -> 'a t (** [join fut] is [fut >>= Fun.id]. It joins the inner layer of the future. @since 0.2 *) @@ -200,7 +209,19 @@ val wait_block : 'a t -> 'a or_error val wait_block_exn : 'a t -> 'a (** Same as {!wait_block} but re-raises the exception if the future failed. *) -module type INFIX = sig +(** {2 Infix operators} + + These combinators run on either the current pool (if present), + or on the same thread that just fulfilled the previous future + if not. + + They were previously present as [module Infix_local] and [val infix], + but are now simplified. + + @since NEXT_RELEASE *) + +(** @since NEXT_RELEASE *) +module Infix : sig val ( >|= ) : 'a t -> ('a -> 'b) -> 'b t val ( >>= ) : 'a t -> ('a -> 'b t) -> 'b t val ( let+ ) : 'a t -> ('a -> 'b) -> 'b t @@ -209,17 +230,8 @@ module type INFIX = sig val ( and* ) : 'a t -> 'b t -> ('a * 'b) t end -module Infix_local : INFIX -(** Operators that run on the same thread as the first future. *) +include module type of Infix -include INFIX - -(** Make infix combinators, with intermediate computations running on the given pool. *) -module Infix (_ : sig - val pool : Runner.t -end) : INFIX - -val infix : Runner.t -> (module INFIX) -(** [infix runner] makes a new infix module with intermediate computations - running on the given runner.. - @since 0.2 *) +module Infix_local = Infix +[@@deprecated "Use Infix"] +(** @deprecated use Infix instead *) diff --git a/src/immediate_runner.ml b/src/immediate_runner.ml new file mode 100644 index 00000000..d5e11284 --- /dev/null +++ b/src/immediate_runner.ml @@ -0,0 +1,9 @@ +include Runner + +let runner : t = + Runner.For_runner_implementors.create + ~size:(fun () -> 0) + ~num_tasks:(fun () -> 0) + ~shutdown:(fun ~wait:_ () -> ()) + ~run_async:(fun f -> f ()) + () diff --git a/src/immediate_runner.mli b/src/immediate_runner.mli new file mode 100644 index 00000000..ed017eba --- /dev/null +++ b/src/immediate_runner.mli @@ -0,0 +1,20 @@ +(** Runner that runs tasks immediately in the caller thread. + + Whenever a task is submitted to this runner via [Runner.run_async r task], + the task is run immediately in the caller thread as [task()]. + There are no background threads, no resource, this is just a trivial + implementation of the interface. + + This can be useful when an implementation needs a runner, but there isn't + enough work to justify starting an actual full thread pool. + + Another situation is when threads cannot be used at all (e.g. because you + plan to call [Unix.fork] later). + + @since NEXT_RELEASE +*) + +include module type of Runner + +val runner : t +(** The trivial runner that actually runs tasks at the calling point. *) diff --git a/src/moonpool.ml b/src/moonpool.ml index 83ae22a8..f2cf0174 100644 --- a/src/moonpool.ml +++ b/src/moonpool.ml @@ -2,13 +2,32 @@ let start_thread_on_some_domain f x = let did = Random.int (D_pool_.n_domains ()) in D_pool_.run_on_and_wait did (fun () -> Thread.create f x) +let run_async = Runner.run_async +let recommended_thread_count () = Domain_.recommended_number () +let spawn = Fut.spawn +let spawn_on_current_runner = Fut.spawn_on_current_runner + +[@@@ifge 5.0] + +let await = Fut.await + +[@@@endif] + module Atomic = Atomic_ module Blocking_queue = Bb_queue module Bounded_queue = Bounded_queue module Chan = Chan +module Fifo_pool = Fifo_pool module Fork_join = Fork_join module Fut = Fut module Lock = Lock -module Pool = Pool +module Immediate_runner = Immediate_runner +module Pool = Fifo_pool module Runner = Runner -module Suspend_ = Suspend_ +module Thread_local_storage = Thread_local_storage_ +module Ws_pool = Ws_pool + +module Private = struct + module Ws_deque_ = Ws_deque_ + module Suspend_ = Suspend_ +end diff --git a/src/moonpool.mli b/src/moonpool.mli index 1d300665..60c0ede6 100644 --- a/src/moonpool.mli +++ b/src/moonpool.mli @@ -1,21 +1,63 @@ (** Moonpool A pool within a bigger pool (ie the ocean). Here, we're talking about - pools of [Thread.t] which live within a fixed pool of [Domain.t]. + pools of [Thread.t] that are dispatched over several [Domain.t] to + enable parallelism. + + We provide several implementations of pools + with distinct scheduling strategies, alongside some concurrency + primitives such as guarding locks ({!Lock.t}) and futures ({!Fut.t}). *) -module Pool = Pool +module Ws_pool = Ws_pool +module Fifo_pool = Fifo_pool module Runner = Runner +module Immediate_runner = Immediate_runner + +module Pool = Fifo_pool +[@@deprecated "use Fifo_pool or Ws_pool to be more explicit"] +(** Default pool. Please explicitly pick an implementation instead. *) val start_thread_on_some_domain : ('a -> unit) -> 'a -> Thread.t (** Similar to {!Thread.create}, but it picks a background domain at random to run the thread. This ensures that we don't always pick the same domain to run all the various threads needed in an application (timers, event loops, etc.) *) +val run_async : Runner.t -> (unit -> unit) -> unit +(** [run_async runner task] schedules the task to run + on the given runner. This means [task()] will be executed + at some point in the future, possibly in another thread. + @since NEXT_RELEASE *) + +val recommended_thread_count : unit -> int +(** Number of threads recommended to saturate the CPU. + For IO pools this makes little sense (you might want more threads than + this because many of them will be blocked most of the time). + @since NEXT_RELEASE *) + +val spawn : on:Runner.t -> (unit -> 'a) -> 'a Fut.t +(** [spawn ~on f] runs [f()] on the runner (a thread pool typically) + and returns a future result for it. See {!Fut.spawn}. + @since NEXT_RELEASE *) + +val spawn_on_current_runner : (unit -> 'a) -> 'a Fut.t +(** See {!Fut.spawn_on_current_runner}. + @since NEXT_RELEASE *) + +[@@@ifge 5.0] + +val await : 'a Fut.t -> 'a +(** Await a future. See {!Fut.await}. + Only on OCaml >= 5.0. + @since NEXT_RELEASE *) + +[@@@endif] + module Lock = Lock module Fut = Fut module Chan = Chan module Fork_join = Fork_join +module Thread_local_storage = Thread_local_storage_ (** A simple blocking queue. @@ -141,12 +183,19 @@ module Atomic = Atomic_ This is either a shim using [ref], on pre-OCaml 5, or the standard [Atomic] module on OCaml 5. *) -(** {2 Suspensions} *) +(**/**) -module Suspend_ = Suspend_ -[@@alert unstable "this module is an implementation detail of moonpool for now"] -(** Suspensions. +module Private : sig + module Ws_deque_ = Ws_deque_ + + (** {2 Suspensions} *) + + module Suspend_ = Suspend_ + [@@alert + unstable "this module is an implementation detail of moonpool for now"] + (** Suspensions. This is only going to work on OCaml 5.x. {b NOTE}: this is not stable for now. *) +end diff --git a/src/pool.ml b/src/pool.ml deleted file mode 100644 index 43cda564..00000000 --- a/src/pool.ml +++ /dev/null @@ -1,282 +0,0 @@ -(* TODO: use a better queue for the tasks *) - -module A = Atomic_ -include Runner - -let ( let@ ) = ( @@ ) - -type thread_loop_wrapper = - thread:Thread.t -> pool:t -> (unit -> unit) -> unit -> unit - -let global_thread_wrappers_ : thread_loop_wrapper list A.t = A.make [] - -let add_global_thread_loop_wrapper f : unit = - while - let l = A.get global_thread_wrappers_ in - not (A.compare_and_set global_thread_wrappers_ l (f :: l)) - do - Domain_.relax () - done - -type state = { - active: bool A.t; - threads: Thread.t array; - qs: task Bb_queue.t array; - cur_q: int A.t; (** Selects queue into which to push *) -} -(** internal state *) - -(** Run [task] as is, on the pool. *) -let run_direct_ (self : state) (task : task) : unit = - let n_qs = Array.length self.qs in - let offset = A.fetch_and_add self.cur_q 1 in - - (* blocking push, last resort *) - let[@inline] push_wait f = - let q_idx = offset mod Array.length self.qs in - let q = self.qs.(q_idx) in - Bb_queue.push q f - in - - try - (* try each queue with a round-robin initial offset *) - for _retry = 1 to 10 do - for i = 0 to n_qs - 1 do - let q_idx = (i + offset) mod Array.length self.qs in - let q = self.qs.(q_idx) in - if Bb_queue.try_push q task then raise_notrace Exit - done - done; - push_wait task - with - | Exit -> () - | Bb_queue.Closed -> raise Shutdown - -let rec run_async_ (self : state) (task : task) : unit = - let task' () = - (* run [f()] and handle [suspend] in it *) - Suspend_.with_suspend task ~run:(fun ~with_handler task -> - if with_handler then - run_async_ self task - else - run_direct_ self task) - in - run_direct_ self task' - -let run = run_async -let size_ (self : state) = Array.length self.threads - -let num_tasks_ (self : state) : int = - let n = ref 0 in - Array.iter (fun q -> n := !n + Bb_queue.size q) self.qs; - !n - -[@@@ifge 5.0] - -(* DLA interop *) -let prepare_for_await () : Dla_.t = - (* current state *) - let st : - ((with_handler:bool -> task -> unit) * Suspend_.suspension) option A.t = - A.make None - in - - let release () : unit = - match A.exchange st None with - | None -> () - | Some (run, k) -> run ~with_handler:true (fun () -> k (Ok ())) - and await () : unit = - Suspend_.suspend - { Suspend_.handle = (fun ~run k -> A.set st (Some (run, k))) } - in - - let t = { Dla_.release; await } in - t - -[@@@else_] - -let prepare_for_await () = { Dla_.release = ignore; await = ignore } - -[@@@endif] - -exception Got_task of task - -type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task - -let worker_thread_ (runner : t) ~on_exn ~around_task (active : bool A.t) - (qs : task Bb_queue.t array) ~(offset : int) : unit = - let num_qs = Array.length qs in - let (AT_pair (before_task, after_task)) = around_task in - - let main_loop () = - while A.get active do - (* last resort: block on my queue *) - let pop_blocking () = - let my_q = qs.(offset mod num_qs) in - Bb_queue.pop my_q - in - - let task = - try - for i = 0 to num_qs - 1 do - let q = qs.((offset + i) mod num_qs) in - match Bb_queue.try_pop ~force_lock:false q with - | Some f -> raise_notrace (Got_task f) - | None -> () - done; - pop_blocking () - with Got_task f -> f - in - - let _ctx = before_task runner in - (* run the task now, catching errors *) - (try task () - with e -> - let bt = Printexc.get_raw_backtrace () in - on_exn e bt); - after_task runner _ctx - done - in - - try - (* handle domain-local await *) - Dla_.using ~prepare_for_await ~while_running:main_loop - with Bb_queue.Closed -> () - -let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = () - -(** We want a reasonable number of queues. Even if your system is - a beast with hundreds of cores, trying - to work-steal through hundreds of queues will have a cost. - - Hence, we limit the number of queues to at most 32 (number picked - via the ancestral technique of the pifomètre). *) -let max_queues = 32 - -let shutdown_ ~wait (self : state) : unit = - let was_active = A.exchange self.active false in - (* close the job queues, which will fail future calls to [run], - and wake up the subset of [self.threads] that are waiting on them. *) - if was_active then Array.iter Bb_queue.close self.qs; - if wait then Array.iter Thread.join self.threads - -type ('a, 'b) create_args = - ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) -> - ?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) -> - ?thread_wrappers:thread_loop_wrapper list -> - ?on_exn:(exn -> Printexc.raw_backtrace -> unit) -> - ?around_task:(t -> 'b) * (t -> 'b -> unit) -> - ?min:int -> - ?per_domain:int -> - 'a -(** Arguments used in {!create}. See {!create} for explanations. *) - -let create ?(on_init_thread = default_thread_init_exit_) - ?(on_exit_thread = default_thread_init_exit_) ?(thread_wrappers = []) - ?(on_exn = fun _ _ -> ()) ?around_task ?min:(min_threads = 1) - ?(per_domain = 0) () : t = - (* wrapper *) - let around_task = - match around_task with - | Some (f, g) -> AT_pair (f, g) - | None -> AT_pair (ignore, fun _ _ -> ()) - in - - (* number of threads to run *) - let min_threads = max 1 min_threads in - let num_domains = D_pool_.n_domains () in - assert (num_domains >= 1); - let num_threads = max min_threads (num_domains * per_domain) in - - (* make sure we don't bias towards the first domain(s) in {!D_pool_} *) - let offset = Random.int num_domains in - - let active = A.make true in - let qs = - let num_qs = min (min num_domains num_threads) max_queues in - Array.init num_qs (fun _ -> Bb_queue.create ()) - in - - let pool = - let dummy = Thread.self () in - { active; threads = Array.make num_threads dummy; qs; cur_q = A.make 0 } - in - - let runner = - Runner.For_runner_implementors.create - ~shutdown:(fun ~wait () -> shutdown_ pool ~wait) - ~run_async:(fun f -> run_async_ pool f) - ~size:(fun () -> size_ pool) - ~num_tasks:(fun () -> num_tasks_ pool) - () - in - - (* temporary queue used to obtain thread handles from domains - on which the thread are started. *) - let receive_threads = Bb_queue.create () in - - (* start the thread with index [i] *) - let start_thread_with_idx i = - let dom_idx = (offset + i) mod num_domains in - - (* function run in the thread itself *) - let main_thread_fun () : unit = - let thread = Thread.self () in - let t_id = Thread.id thread in - on_init_thread ~dom_id:dom_idx ~t_id (); - - let all_wrappers = - List.rev_append thread_wrappers (A.get global_thread_wrappers_) - in - - let run () = - worker_thread_ runner ~on_exn ~around_task active qs ~offset:i - in - (* the actual worker loop is [worker_thread_], with all - wrappers for this pool and for all pools (global_thread_wrappers_) *) - let run' = - List.fold_left - (fun run f -> f ~thread ~pool:runner run) - run all_wrappers - in - - (* now run the main loop *) - Fun.protect run' ~finally:(fun () -> - (* on termination, decrease refcount of underlying domain *) - D_pool_.decr_on dom_idx); - on_exit_thread ~dom_id:dom_idx ~t_id () - in - - (* function called in domain with index [i], to - create the thread and push it into [receive_threads] *) - let create_thread_in_domain () = - let thread = Thread.create main_thread_fun () in - (* send the thread from the domain back to us *) - Bb_queue.push receive_threads (i, thread) - in - - D_pool_.run_on dom_idx create_thread_in_domain - in - - (* start all threads, placing them on the domains - according to their index and [offset] in a round-robin fashion. *) - for i = 0 to num_threads - 1 do - start_thread_with_idx i - done; - - (* receive the newly created threads back from domains *) - for _j = 1 to num_threads do - let i, th = Bb_queue.pop receive_threads in - pool.threads.(i) <- th - done; - - runner - -let with_ ?on_init_thread ?on_exit_thread ?thread_wrappers ?on_exn ?around_task - ?min ?per_domain () f = - let pool = - create ?on_init_thread ?on_exit_thread ?thread_wrappers ?on_exn ?around_task - ?min ?per_domain () - in - let@ () = Fun.protect ~finally:(fun () -> shutdown pool) in - f pool diff --git a/src/runner.ml b/src/runner.ml index 91cde5a2..0fcf2392 100644 --- a/src/runner.ml +++ b/src/runner.ml @@ -1,3 +1,5 @@ +module TLS = Thread_local_storage_ + type task = unit -> unit type t = { @@ -34,4 +36,9 @@ let run_wait_block self (f : unit -> 'a) : 'a = module For_runner_implementors = struct let create ~size ~num_tasks ~shutdown ~run_async () : t = { size; num_tasks; shutdown; run_async } + + let k_cur_runner : t option ref TLS.key = TLS.new_key (fun () -> ref None) end + +let[@inline] get_current_runner () : _ option = + !(TLS.get For_runner_implementors.k_cur_runner) diff --git a/src/runner.mli b/src/runner.mli index cda20720..471d21af 100644 --- a/src/runner.mli +++ b/src/runner.mli @@ -1,17 +1,13 @@ -(** Abstract runner. +(** Interface for runners. - This provides an abstraction for running tasks in the background. + This provides an abstraction for running tasks in the background, + which is implemented by various thread pools. @since 0.3 *) type task = unit -> unit -type t = private { - run_async: task -> unit; - shutdown: wait:bool -> unit -> unit; - size: unit -> int; - num_tasks: unit -> int; -} +type t (** A runner. If a runner is no longer needed, {!shutdown} can be used to signal all @@ -50,8 +46,11 @@ val run_wait_block : t -> (unit -> 'a) -> 'a and returns its result. If [f()] raises an exception, then [run_wait_block pool f] will raise it as well. - {b NOTE} be careful with deadlocks (see notes in {!Fut.wait_block}). *) + {b NOTE} be careful with deadlocks (see notes in {!Fut.wait_block} + about the required discipline to avoid deadlocks). *) +(** This module is specifically intended for users who implement their + own runners. Regular users of Moonpool should not need to look at it. *) module For_runner_implementors : sig val create : size:(unit -> int) -> @@ -64,4 +63,11 @@ module For_runner_implementors : sig {b NOTE}: the runner should support DLA and {!Suspend_} on OCaml 5.x, so that {!Fork_join} and other 5.x features work properly. *) + + val k_cur_runner : t option ref Thread_local_storage_.key end + +val get_current_runner : unit -> t option +(** Access the current runner. This returns [Some r] if the call + happens on a thread that belongs in a runner. + @since NEXT_RELEASE *) diff --git a/src/suspend_.ml b/src/suspend_.ml index 1a0b55df..6555b6bc 100644 --- a/src/suspend_.ml +++ b/src/suspend_.ml @@ -1,20 +1,19 @@ type suspension = (unit, exn * Printexc.raw_backtrace) result -> unit type task = unit -> unit -type suspension_handler = { - handle: run:(with_handler:bool -> task -> unit) -> suspension -> unit; -} +type suspension_handler = { handle: run:(task -> unit) -> suspension -> unit } [@@unboxed] [@@@ifge 5.0] [@@@ocaml.alert "-unstable"] +module A = Atomic_ + type _ Effect.t += Suspend : suspension_handler -> unit Effect.t let[@inline] suspend h = Effect.perform (Suspend h) -let with_suspend ~(run : with_handler:bool -> task -> unit) (f : unit -> unit) : - unit = +let with_suspend ~(run : task -> unit) (f : unit -> unit) : unit = let module E = Effect.Deep in (* effect handler *) let effc : type e. e Effect.t -> ((e, _) E.continuation -> _) option = @@ -32,9 +31,26 @@ let with_suspend ~(run : with_handler:bool -> task -> unit) (f : unit -> unit) : E.try_with f () { E.effc } +(* DLA interop *) +let prepare_for_await () : Dla_.t = + (* current state *) + let st : ((task -> unit) * suspension) option A.t = A.make None in + + let release () : unit = + match A.exchange st None with + | None -> () + | Some (run, k) -> run (fun () -> k (Ok ())) + and await () : unit = + suspend { handle = (fun ~run k -> A.set st (Some (run, k))) } + in + + let t = { Dla_.release; await } in + t + [@@@ocaml.alert "+unstable"] [@@@else_] -let with_suspend ~run:_ f = f () +let[@inline] with_suspend ~run:_ f = f () +let[@inline] prepare_for_await () = { Dla_.release = ignore; await = ignore } [@@@endif] diff --git a/src/suspend_.mli b/src/suspend_.mli index 032bc3e0..77cc06af 100644 --- a/src/suspend_.mli +++ b/src/suspend_.mli @@ -8,9 +8,7 @@ type suspension = (unit, exn * Printexc.raw_backtrace) result -> unit type task = unit -> unit -type suspension_handler = { - handle: run:(with_handler:bool -> task -> unit) -> suspension -> unit; -} +type suspension_handler = { handle: run:(task -> unit) -> suspension -> unit } [@@unboxed] (** The handler that knows what to do with the suspended computation. @@ -50,8 +48,10 @@ val suspend : suspension_handler -> unit [@@@endif] -val with_suspend : - run:(with_handler:bool -> task -> unit) -> (unit -> unit) -> unit +val prepare_for_await : unit -> Dla_.t +(** Our stub for DLA. Unstable. *) + +val with_suspend : run:(task -> unit) -> (unit -> unit) -> unit (** [with_suspend ~run f] runs [f()] in an environment where [suspend] will work. If [f()] suspends with suspension handler [h], this calls [h ~run k] where [k] is the suspension. diff --git a/src/thread_local_storage_.mli b/src/thread_local_storage_.mli new file mode 100644 index 00000000..b7b50706 --- /dev/null +++ b/src/thread_local_storage_.mli @@ -0,0 +1,21 @@ +(** Thread local storage *) + +(* TODO: alias this to the library if present *) + +type 'a key +(** A TLS key for values of type ['a]. This allows the + storage of a single value of type ['a] per thread. *) + +val new_key : (unit -> 'a) -> 'a key +(** Allocate a new, generative key. + When the key is used for the first time on a thread, + the function is called to produce it. + + This should only ever be called at toplevel to produce + constants, do not use it in a loop. *) + +val get : 'a key -> 'a +(** Get the value for the current thread. *) + +val set : 'a key -> 'a -> unit +(** Set the value for the current thread. *) diff --git a/src/thread_local_storage_.real.ml b/src/thread_local_storage_.real.ml new file mode 100644 index 00000000..70d7a558 --- /dev/null +++ b/src/thread_local_storage_.real.ml @@ -0,0 +1,82 @@ +(* see: https://discuss.ocaml.org/t/a-hack-to-implement-efficient-tls-thread-local-storage/13264 *) + +module A = Atomic_ + +(* sanity check *) +let () = assert (Obj.field (Obj.repr (Thread.self ())) 1 = Obj.repr ()) + +type 'a key = { + index: int; (** Unique index for this key. *) + compute: unit -> 'a; + (** Initializer for values for this key. Called at most + once per thread. *) +} + +(** Counter used to allocate new keys *) +let counter = A.make 0 + +(** Value used to detect a TLS slot that was not initialized yet *) +let[@inline] sentinel_value_for_uninit_tls_ () : Obj.t = Obj.repr counter + +let new_key compute : _ key = + let index = A.fetch_and_add counter 1 in + { index; compute } + +type thread_internal_state = { + _id: int; (** Thread ID (here for padding reasons) *) + mutable tls: Obj.t; (** Our data, stowed away in this unused field *) +} +(** A partial representation of the internal type [Thread.t], allowing + us to access the second field (unused after the thread + has started) and stash TLS data in it. *) + +let ceil_pow_2_minus_1 (n : int) : int = + let n = n lor (n lsr 1) in + let n = n lor (n lsr 2) in + let n = n lor (n lsr 4) in + let n = n lor (n lsr 8) in + let n = n lor (n lsr 16) in + if Sys.int_size > 32 then + n lor (n lsr 32) + else + n + +(** Grow the array so that [index] is valid. *) +let[@inline never] grow_tls (old : Obj.t array) (index : int) : Obj.t array = + let new_length = ceil_pow_2_minus_1 (index + 1) in + let new_ = Array.make new_length (sentinel_value_for_uninit_tls_ ()) in + Array.blit old 0 new_ 0 (Array.length old); + new_ + +let[@inline] get_tls_ (index : int) : Obj.t array = + let thread : thread_internal_state = Obj.magic (Thread.self ()) in + let tls = thread.tls in + if Obj.is_int tls then ( + let new_tls = grow_tls [||] index in + thread.tls <- Obj.magic new_tls; + new_tls + ) else ( + let tls = (Obj.magic tls : Obj.t array) in + if index < Array.length tls then + tls + else ( + let new_tls = grow_tls tls index in + thread.tls <- Obj.magic new_tls; + new_tls + ) + ) + +let get key = + let tls = get_tls_ key.index in + let value = Array.unsafe_get tls key.index in + if value != sentinel_value_for_uninit_tls_ () then + Obj.magic value + else ( + let value = key.compute () in + Array.unsafe_set tls key.index (Obj.repr (Sys.opaque_identity value)); + value + ) + +let set key value = + let tls = get_tls_ key.index in + Array.unsafe_set tls key.index (Obj.repr (Sys.opaque_identity value)) diff --git a/src/thread_local_storage_.stub.ml b/src/thread_local_storage_.stub.ml new file mode 100644 index 00000000..88712b6d --- /dev/null +++ b/src/thread_local_storage_.stub.ml @@ -0,0 +1,3 @@ + +(* just defer to library *) +include Thread_local_storage diff --git a/src/util_pool_.ml b/src/util_pool_.ml new file mode 100644 index 00000000..8207062a --- /dev/null +++ b/src/util_pool_.ml @@ -0,0 +1,11 @@ +let num_threads ?num_threads () : int = + let n_domains = D_pool_.n_domains () in + + (* number of threads to run *) + let num_threads = + match num_threads with + | Some j -> max 1 j + | None -> n_domains + in + + num_threads diff --git a/src/util_pool_.mli b/src/util_pool_.mli new file mode 100644 index 00000000..68fdde22 --- /dev/null +++ b/src/util_pool_.mli @@ -0,0 +1,5 @@ +(** Utils for pools *) + +val num_threads : ?num_threads:int -> unit -> int +(** Number of threads a pool should have. + @param num_threads user-specified number of threads *) diff --git a/src/ws_deque_.ml b/src/ws_deque_.ml new file mode 100644 index 00000000..6c5d1419 --- /dev/null +++ b/src/ws_deque_.ml @@ -0,0 +1,122 @@ +module A = Atomic_ + +(* terminology: + + - Bottom: where we push/pop normally. Only one thread can do that. + - top: where work stealing happens (older values). + This only ever grows. + + Elements are always added on the bottom end. *) + +(** Circular array (size is [2 ^ log_size]) *) +module CA : sig + type 'a t + + val create : dummy:'a -> unit -> 'a t + val size : 'a t -> int + val get : 'a t -> int -> 'a + val set : 'a t -> int -> 'a -> unit +end = struct + (** The array has size 256. *) + let log_size = 8 + + type 'a t = { arr: 'a array } [@@unboxed] + + let[@inline] size (_self : _ t) = 1 lsl log_size + let create ~dummy () : _ t = { arr = Array.make (1 lsl log_size) dummy } + + let[@inline] get (self : 'a t) (i : int) : 'a = + Array.unsafe_get self.arr (i land ((1 lsl log_size) - 1)) + + let[@inline] set (self : 'a t) (i : int) (x : 'a) : unit = + Array.unsafe_set self.arr (i land ((1 lsl log_size) - 1)) x +end + +type 'a t = { + top: int A.t; (** Where we steal *) + bottom: int A.t; (** Where we push/pop from the owning thread *) + mutable top_cached: int; (** Last read value of [top] *) + arr: 'a CA.t; (** The circular array *) +} + +let create ~dummy () : _ t = + let top = A.make 0 in + let arr = CA.create ~dummy () in + (* allocate far from [top] to avoid false sharing *) + let bottom = A.make 0 in + { top; top_cached = 0; bottom; arr } + +let[@inline] size (self : _ t) : int = max 0 (A.get self.bottom - A.get self.top) + +exception Full + +let push (self : 'a t) (x : 'a) : bool = + try + let b = A.get self.bottom in + let t_approx = self.top_cached in + + (* Section 2.3: over-approximation of size. + Only if it seems too big do we actually read [t]. *) + let size_approx = b - t_approx in + if size_approx >= CA.size self.arr - 1 then ( + (* we need to read the actual value of [top], which might entail contention. *) + let t = A.get self.top in + self.top_cached <- t; + let size = b - t in + + if size >= CA.size self.arr - 1 then (* full! *) raise_notrace Full + ); + + CA.set self.arr b x; + A.set self.bottom (b + 1); + true + with Full -> false + +let pop (self : 'a t) : 'a option = + let b = A.get self.bottom in + let b = b - 1 in + A.set self.bottom b; + + let t = A.get self.top in + self.top_cached <- t; + + let size = b - t in + if size < 0 then ( + (* reset to basic empty state *) + A.set self.bottom t; + None + ) else if size > 0 then ( + (* can pop without modifying [top] *) + let x = CA.get self.arr b in + Some x + ) else ( + assert (size = 0); + (* there was exactly one slot, so we might be racing against stealers + to update [self.top] *) + if A.compare_and_set self.top t (t + 1) then ( + let x = CA.get self.arr b in + A.set self.bottom (t + 1); + Some x + ) else ( + A.set self.bottom (t + 1); + None + ) + ) + +let steal (self : 'a t) : 'a option = + (* read [top], but do not update [top_cached] + as we're in another thread *) + let t = A.get self.top in + let b = A.get self.bottom in + + let size = b - t in + if size <= 0 then + None + else ( + let x = CA.get self.arr t in + if A.compare_and_set self.top t (t + 1) then + (* successfully increased top to consume [x] *) + Some x + else + None + ) diff --git a/src/ws_deque_.mli b/src/ws_deque_.mli new file mode 100644 index 00000000..b696224e --- /dev/null +++ b/src/ws_deque_.mli @@ -0,0 +1,27 @@ +(** Work-stealing deque. + + Adapted from "Dynamic circular work stealing deque", Chase & Lev. + + However note that this one is not dynamic in the sense that there + is no resizing. Instead we return [false] when [push] fails, which + keeps the implementation fairly lightweight. + *) + +type 'a t +(** Deque containing values of type ['a] *) + +val create : dummy:'a -> unit -> 'a t +(** Create a new deque. *) + +val push : 'a t -> 'a -> bool +(** Push value at the bottom of deque. returns [true] if it succeeds. + This must be called only by the owner thread. *) + +val pop : 'a t -> 'a option +(** Pop value from the bottom of deque. + This must be called only by the owner thread. *) + +val steal : 'a t -> 'a option +(** Try to steal from the top of deque. This is thread-safe. *) + +val size : _ t -> int diff --git a/src/ws_pool.ml b/src/ws_pool.ml new file mode 100644 index 00000000..d32c71f8 --- /dev/null +++ b/src/ws_pool.ml @@ -0,0 +1,337 @@ +module WSQ = Ws_deque_ +module A = Atomic_ +module TLS = Thread_local_storage_ +include Runner + +let ( let@ ) = ( @@ ) + +module Id = struct + type t = unit ref + (** Unique identifier for a pool *) + + let create () : t = Sys.opaque_identity (ref ()) + let equal : t -> t -> bool = ( == ) +end + +type worker_state = { + pool_id_: Id.t; (** Unique per pool *) + mutable thread: Thread.t; + q: task WSQ.t; (** Work stealing queue *) + rng: Random.State.t; +} +(** State for a given worker. Only this worker is + allowed to push into the queue, but other workers + can come and steal from it if they're idle. *) + +type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task + +type state = { + id_: Id.t; + active: bool A.t; (** Becomes [false] when the pool is shutdown. *) + workers: worker_state array; (** Fixed set of workers. *) + main_q: task Queue.t; (** Main queue for tasks coming from the outside *) + mutable n_waiting: int; (* protected by mutex *) + mutable n_waiting_nonzero: bool; (** [n_waiting > 0] *) + mutex: Mutex.t; + cond: Condition.t; + on_exn: exn -> Printexc.raw_backtrace -> unit; + around_task: around_task; +} +(** internal state *) + +let[@inline] size_ (self : state) = Array.length self.workers + +let num_tasks_ (self : state) : int = + let n = ref 0 in + n := Queue.length self.main_q; + Array.iter (fun w -> n := !n + WSQ.size w.q) self.workers; + !n + +(** TLS, used by worker to store their specific state + and be able to retrieve it from tasks when we schedule new + sub-tasks. *) +let k_worker_state : worker_state option ref TLS.key = + TLS.new_key (fun () -> ref None) + +let[@inline] find_current_worker_ () : worker_state option = + !(TLS.get k_worker_state) + +(** Try to wake up a waiter, if there's any. *) +let[@inline] try_wake_someone_ (self : state) : unit = + if self.n_waiting_nonzero then ( + Mutex.lock self.mutex; + Condition.signal self.cond; + Mutex.unlock self.mutex + ) + +(** Run [task] as is, on the pool. *) +let schedule_task_ (self : state) (w : worker_state option) (task : task) : unit + = + (* Printf.printf "schedule task now (%d)\n%!" (Thread.id @@ Thread.self ()); *) + match w with + | Some w when Id.equal self.id_ w.pool_id_ -> + (* we're on this same pool, schedule in the worker's state. Otherwise + we might also be on pool A but asking to schedule on pool B, + so we have to check that identifiers match. *) + let pushed = WSQ.push w.q task in + if pushed then + try_wake_someone_ self + else ( + (* overflow into main queue *) + Mutex.lock self.mutex; + Queue.push task self.main_q; + if self.n_waiting_nonzero then Condition.signal self.cond; + Mutex.unlock self.mutex + ) + | _ -> + if A.get self.active then ( + (* push into the main queue *) + Mutex.lock self.mutex; + Queue.push task self.main_q; + if self.n_waiting_nonzero then Condition.signal self.cond; + Mutex.unlock self.mutex + ) else + (* notify the caller that scheduling tasks is no + longer permitted *) + raise Shutdown + +(** Run this task, now. Must be called from a worker. *) +let run_task_now_ (self : state) ~runner task : unit = + (* Printf.printf "run task now (%d)\n%!" (Thread.id @@ Thread.self ()); *) + let (AT_pair (before_task, after_task)) = self.around_task in + let _ctx = before_task runner in + (* run the task now, catching errors *) + (try + (* run [task()] and handle [suspend] in it *) + Suspend_.with_suspend task ~run:(fun task' -> + let w = find_current_worker_ () in + schedule_task_ self w task') + with e -> + let bt = Printexc.get_raw_backtrace () in + self.on_exn e bt); + after_task runner _ctx + +let[@inline] run_async_ (self : state) (task : task) : unit = + let w = find_current_worker_ () in + schedule_task_ self w task + +(* TODO: function to schedule many tasks from the outside. + - build a queue + - lock + - queue transfer + - wakeup all (broadcast) + - unlock *) + +let run = run_async + +(** Wait on condition. Precondition: we hold the mutex. *) +let[@inline] wait_ (self : state) : unit = + self.n_waiting <- self.n_waiting + 1; + if self.n_waiting = 1 then self.n_waiting_nonzero <- true; + Condition.wait self.cond self.mutex; + self.n_waiting <- self.n_waiting - 1; + if self.n_waiting = 0 then self.n_waiting_nonzero <- false + +exception Got_task of task + +(** Try to steal a task *) +let try_to_steal_work_once_ (self : state) (w : worker_state) : task option = + let init = Random.State.int w.rng (Array.length self.workers) in + + try + for i = 0 to Array.length self.workers - 1 do + let w' = + Array.unsafe_get self.workers ((i + init) mod Array.length self.workers) + in + + if w != w' then ( + match WSQ.steal w'.q with + | Some t -> raise_notrace (Got_task t) + | None -> () + ) + done; + None + with Got_task t -> Some t + +(** Worker runs tasks from its queue until none remains *) +let worker_run_self_tasks_ (self : state) ~runner w : unit = + let continue = ref true in + while !continue && A.get self.active do + match WSQ.pop w.q with + | Some task -> + try_wake_someone_ self; + run_task_now_ self ~runner task + | None -> continue := false + done + +(** Main loop for a worker thread. *) +let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit = + TLS.get Runner.For_runner_implementors.k_cur_runner := Some runner; + TLS.get k_worker_state := Some w; + + let rec main () : unit = + if A.get self.active then ( + worker_run_self_tasks_ self ~runner w; + try_steal () + ) + and run_task task : unit = + run_task_now_ self ~runner task; + main () + and try_steal () = + if A.get self.active then ( + match try_to_steal_work_once_ self w with + | Some task -> run_task task + | None -> wait () + ) + and wait () = + Mutex.lock self.mutex; + match Queue.pop self.main_q with + | task -> + Mutex.unlock self.mutex; + run_task task + | exception Queue.Empty -> + (* wait here *) + if A.get self.active then wait_ self; + + (* see if a task became available *) + let task = try Some (Queue.pop self.main_q) with Queue.Empty -> None in + Mutex.unlock self.mutex; + + (match task with + | Some t -> run_task t + | None -> try_steal ()) + in + + (* handle domain-local await *) + Dla_.using ~prepare_for_await:Suspend_.prepare_for_await ~while_running:main + +let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = () + +let shutdown_ ~wait (self : state) : unit = + if A.exchange self.active false then ( + Mutex.lock self.mutex; + Condition.broadcast self.cond; + Mutex.unlock self.mutex; + if wait then Array.iter (fun w -> Thread.join w.thread) self.workers + ) + +type ('a, 'b) create_args = + ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) -> + ?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) -> + ?on_exn:(exn -> Printexc.raw_backtrace -> unit) -> + ?around_task:(t -> 'b) * (t -> 'b -> unit) -> + ?num_threads:int -> + 'a +(** Arguments used in {!create}. See {!create} for explanations. *) + +let dummy_task_ () = assert false + +let create ?(on_init_thread = default_thread_init_exit_) + ?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ()) + ?around_task ?num_threads () : t = + let pool_id_ = Id.create () in + (* wrapper *) + let around_task = + match around_task with + | Some (f, g) -> AT_pair (f, g) + | None -> AT_pair (ignore, fun _ _ -> ()) + in + + let num_domains = D_pool_.n_domains () in + let num_threads = Util_pool_.num_threads ?num_threads () in + + (* make sure we don't bias towards the first domain(s) in {!D_pool_} *) + let offset = Random.int num_domains in + + let workers : worker_state array = + let dummy = Thread.self () in + Array.init num_threads (fun i -> + { + pool_id_; + thread = dummy; + q = WSQ.create ~dummy:dummy_task_ (); + rng = Random.State.make [| i |]; + }) + in + + let pool = + { + id_ = pool_id_; + active = A.make true; + workers; + main_q = Queue.create (); + n_waiting = 0; + n_waiting_nonzero = true; + mutex = Mutex.create (); + cond = Condition.create (); + around_task; + on_exn; + } + in + + let runner = + Runner.For_runner_implementors.create + ~shutdown:(fun ~wait () -> shutdown_ pool ~wait) + ~run_async:(fun f -> run_async_ pool f) + ~size:(fun () -> size_ pool) + ~num_tasks:(fun () -> num_tasks_ pool) + () + in + + (* temporary queue used to obtain thread handles from domains + on which the thread are started. *) + let receive_threads = Bb_queue.create () in + + (* start the thread with index [i] *) + let start_thread_with_idx i = + let w = pool.workers.(i) in + let dom_idx = (offset + i) mod num_domains in + + (* function run in the thread itself *) + let main_thread_fun () : unit = + let thread = Thread.self () in + let t_id = Thread.id thread in + on_init_thread ~dom_id:dom_idx ~t_id (); + + let run () = worker_thread_ pool ~runner w in + + (* now run the main loop *) + Fun.protect run ~finally:(fun () -> + (* on termination, decrease refcount of underlying domain *) + D_pool_.decr_on dom_idx); + on_exit_thread ~dom_id:dom_idx ~t_id () + in + + (* function called in domain with index [i], to + create the thread and push it into [receive_threads] *) + let create_thread_in_domain () = + let thread = Thread.create main_thread_fun () in + (* send the thread from the domain back to us *) + Bb_queue.push receive_threads (i, thread) + in + + D_pool_.run_on dom_idx create_thread_in_domain + in + + (* start all threads, placing them on the domains + according to their index and [offset] in a round-robin fashion. *) + for i = 0 to num_threads - 1 do + start_thread_with_idx i + done; + + (* receive the newly created threads back from domains *) + for _j = 1 to num_threads do + let i, th = Bb_queue.pop receive_threads in + let worker_state = pool.workers.(i) in + worker_state.thread <- th + done; + + runner + +let with_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads () f + = + let pool = + create ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads () + in + let@ () = Fun.protect ~finally:(fun () -> shutdown pool) in + f pool diff --git a/src/pool.mli b/src/ws_pool.mli similarity index 50% rename from src/pool.mli rename to src/ws_pool.mli index 11cac88b..c13e4c75 100644 --- a/src/pool.mli +++ b/src/ws_pool.mli @@ -1,7 +1,13 @@ -(** Thread pool. +(** Work-stealing thread pool. - A pool of threads. The pool contains a fixed number of threads that - wait for work items to come, process these, and loop. + A pool of threads with a worker-stealing scheduler. + The pool contains a fixed number of threads that wait for work + items to come, process these, and loop. + + This is good for CPU-intensive tasks that feature a lot of small tasks. + Note that tasks will not always be processed in the order they are + scheduled, so this is not great for workloads where the latency + of individual tasks matter (for that see {!Fifo_pool}). This implements {!Runner.t} since 0.3. @@ -15,27 +21,12 @@ include module type of Runner -type thread_loop_wrapper = - thread:Thread.t -> pool:t -> (unit -> unit) -> unit -> unit -(** A thread wrapper [f] takes the current thread, the current pool, - and the worker function [loop : unit -> unit] which is - the worker's main loop, and returns a new loop function. - By default it just returns the same loop function but it can be used - to install tracing, effect handlers, etc. *) - -val add_global_thread_loop_wrapper : thread_loop_wrapper -> unit -(** [add_global_thread_loop_wrapper f] installs [f] to be installed in every new pool worker - thread, for all existing pools, and all new pools created with [create]. - These wrappers accumulate: they all apply, but their order is not specified. *) - type ('a, 'b) create_args = ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) -> ?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) -> - ?thread_wrappers:thread_loop_wrapper list -> ?on_exn:(exn -> Printexc.raw_backtrace -> unit) -> ?around_task:(t -> 'b) * (t -> 'b -> unit) -> - ?min:int -> - ?per_domain:int -> + ?num_threads:int -> 'a (** Arguments used in {!create}. See {!create} for explanations. *) @@ -43,17 +34,12 @@ val create : (unit -> t, _) create_args (** [create ()] makes a new thread pool. @param on_init_thread called at the beginning of each new thread in the pool. - @param min minimum size of the pool. It will be at least [1] internally, - so [0] or negative values make no sense. - @param per_domain is the number of threads allocated per domain in the fixed - domain pool. The default value is [0], but setting, say, [~per_domain:2] - means that if there are [8] domains (which might be the case on an 8-core machine) - then the minimum size of the pool is [16]. - If both [min] and [per_domain] are specified, the maximum of both - [min] and [per_domain * num_of_domains] is used. + @param num_threads size of the pool, ie. number of worker threads. + It will be at least [1] internally, so [0] or negative values make no sense. + The default is [Domain.recommended_domain_count()], ie one worker + thread per CPU core. + On OCaml 4 the default is [4] (since there is only one domain). @param on_exit_thread called at the end of each thread in the pool - @param thread_wrappers a list of {!thread_loop_wrapper} functions - to use for this pool's workers. @param around_task a pair of [before, after], where [before pool] is called before a task is processed, on the worker thread about to run it, and returns [x]; and [after pool x] is called by diff --git a/test/dune b/test/dune index 72c44bbf..43955ec6 100644 --- a/test/dune +++ b/test/dune @@ -1,6 +1,7 @@ (tests (names t_fib + t_ws_pool_confusion t_bench1 t_fib_rec t_futs1 @@ -8,10 +9,14 @@ t_props t_chan_train t_resource + t_unfair + t_ws_deque t_bounded_queue) (libraries moonpool qcheck-core qcheck-core.runner ;tracy-client.trace + unix + trace-tef trace)) diff --git a/test/effect-based/dune b/test/effect-based/dune index 1d4898d3..9989823f 100644 --- a/test/effect-based/dune +++ b/test/effect-based/dune @@ -1,11 +1,11 @@ (tests (names t_fib1 t_futs1 t_many t_fib_fork_join - t_fib_fork_join_all t_sort t_fork_join) + t_fib_fork_join_all t_sort t_fork_join t_fork_join_heavy) (preprocess (action (run %{project_root}/src/cpp/cpp.exe %{input-file}))) (enabled_if (>= %{ocaml_version} 5.0)) - (libraries moonpool trace + (libraries moonpool trace trace-tef qcheck-core qcheck-core.runner ;tracy-client.trace )) diff --git a/test/effect-based/t_fib1.ml b/test/effect-based/t_fib1.ml index e8d2f534..a7c8ebee 100644 --- a/test/effect-based/t_fib1.ml +++ b/test/effect-based/t_fib1.ml @@ -26,13 +26,13 @@ let fib ~on x : int Fut.t = let () = assert (List.init 10 fib_direct = [ 1; 1; 2; 3; 5; 8; 13; 21; 34; 55 ]) let fib_40 : int = - let pool = Pool.create ~min:8 () in + let pool = Ws_pool.create ~num_threads:8 () in fib ~on:pool 40 |> Fut.wait_block_exn let () = Printf.printf "fib 40 = %d\n%!" fib_40 let run_test () = - let pool = Pool.create ~min:8 () in + let pool = Ws_pool.create ~num_threads:8 () in assert ( List.init 10 (fib ~on:pool) @@ -42,7 +42,7 @@ let run_test () = let fibs = Array.init 3 (fun _ -> fib ~on:pool 40) in let res = Fut.join_array fibs |> Fut.wait_block in - Pool.shutdown pool; + Ws_pool.shutdown pool; assert (res = Ok (Array.make 3 fib_40)) diff --git a/test/effect-based/t_fib_fork_join.ml b/test/effect-based/t_fib_fork_join.ml index c6898833..4e6639b2 100644 --- a/test/effect-based/t_fib_fork_join.ml +++ b/test/effect-based/t_fib_fork_join.ml @@ -27,13 +27,13 @@ let fib ~on x : int Fut.t = let () = assert (List.init 10 fib_direct = [ 1; 1; 2; 3; 5; 8; 13; 21; 34; 55 ]) let fib_40 : int = - let pool = Pool.create ~min:8 () in + let pool = Ws_pool.create ~num_threads:8 () in fib ~on:pool 40 |> Fut.wait_block_exn let () = Printf.printf "fib 40 = %d\n%!" fib_40 let run_test () = - let pool = Pool.create ~min:8 () in + let pool = Ws_pool.create ~num_threads:8 () in assert ( List.init 10 (fib ~on:pool) @@ -43,7 +43,7 @@ let run_test () = let fibs = Array.init 3 (fun _ -> fib ~on:pool 40) in let res = Fut.join_array fibs |> Fut.wait_block in - Pool.shutdown pool; + Ws_pool.shutdown pool; assert (res = Ok (Array.make 3 fib_40)) diff --git a/test/effect-based/t_fib_fork_join_all.ml b/test/effect-based/t_fib_fork_join_all.ml index e1ae83f4..3caee9b9 100644 --- a/test/effect-based/t_fib_fork_join_all.ml +++ b/test/effect-based/t_fib_fork_join_all.ml @@ -22,13 +22,13 @@ let rec fib x : int = ) let fib_40 : int = - let@ pool = Pool.with_ ~min:8 () in + let@ pool = Ws_pool.with_ ~num_threads:8 () in Fut.spawn ~on:pool (fun () -> fib 40) |> Fut.wait_block_exn let () = Printf.printf "fib 40 = %d\n%!" fib_40 let run_test () = - let@ pool = Pool.with_ ~min:8 () in + let@ pool = Ws_pool.with_ ~num_threads:8 () in let fut = Fut.spawn ~on:pool (fun () -> @@ -37,7 +37,7 @@ let run_test () = in let res = Fut.wait_block_exn fut in - Pool.shutdown pool; + Ws_pool.shutdown pool; assert (res = (Array.make 3 fib_40 |> Array.to_list)) diff --git a/test/effect-based/t_fork_join.ml b/test/effect-based/t_fork_join.ml index 7fc8fa31..5c7134ca 100644 --- a/test/effect-based/t_fork_join.ml +++ b/test/effect-based/t_fork_join.ml @@ -5,11 +5,11 @@ let ( let@ ) = ( @@ ) open! Moonpool -let pool = Pool.create ~min:4 () +let pool = Ws_pool.create ~num_threads:4 () let () = let x = - Pool.run_wait_block pool (fun () -> + Ws_pool.run_wait_block pool (fun () -> let x, y = Fork_join.both (fun () -> @@ -25,7 +25,7 @@ let () = let () = try - Pool.run_wait_block pool (fun () -> + Ws_pool.run_wait_block pool (fun () -> Fork_join.both_ignore (fun () -> Thread.delay 0.005) (fun () -> @@ -36,21 +36,21 @@ let () = let () = let par_sum = - Pool.run_wait_block pool (fun () -> + Ws_pool.run_wait_block pool (fun () -> Fork_join.all_init 42 (fun i -> i * i) |> List.fold_left ( + ) 0) in let exp_sum = List.init 42 (fun x -> x * x) |> List.fold_left ( + ) 0 in assert (par_sum = exp_sum) let () = - Pool.run_wait_block pool (fun () -> + Ws_pool.run_wait_block pool (fun () -> Fork_join.for_ 0 (fun _ _ -> assert false)); () let () = let total_sum = Atomic.make 0 in - Pool.run_wait_block pool (fun () -> + Ws_pool.run_wait_block pool (fun () -> Fork_join.for_ ~chunk_size:5 100 (fun low high -> (* iterate on the range sequentially. The range should have 5 items or less. *) let local_sum = ref 0 in @@ -63,7 +63,7 @@ let () = let () = let total_sum = Atomic.make 0 in - Pool.run_wait_block pool (fun () -> + Ws_pool.run_wait_block pool (fun () -> Fork_join.for_ ~chunk_size:1 100 (fun low high -> assert (low = high); ignore (Atomic.fetch_and_add total_sum low : int))); @@ -270,7 +270,7 @@ end let t_eval = let arb = Q.set_stats [ "size", Evaluator.size ] Evaluator.arb in Q.Test.make ~name:"same eval" arb (fun e -> - let@ pool = Pool.with_ ~min:4 () in + let@ pool = Ws_pool.with_ ~num_threads:4 () in (* Printf.eprintf "eval %s\n%!" (Evaluator.show e); *) let x = Evaluator.eval_seq e in let y = Evaluator.eval_fork_join ~pool e in @@ -288,8 +288,8 @@ let t_for_nested ~min ~chunk_size () = let ref_l2 = List.map (List.map neg) ref_l1 in let l1, l2 = - let@ pool = Pool.with_ ~min () in - let@ () = Pool.run_wait_block pool in + let@ pool = Ws_pool.with_ ~num_threads:min () in + let@ () = Ws_pool.run_wait_block pool in let l1 = Fork_join.map_list ~chunk_size (Fork_join.map_list ~chunk_size neg) l in @@ -310,8 +310,8 @@ let t_map ~chunk_size () = Q.Test.make ~name:"map1" Q.(small_list small_int |> Q.set_stats [ "len", List.length ]) (fun l -> - let@ pool = Pool.with_ ~min:4 () in - let@ () = Pool.run_wait_block pool in + let@ pool = Ws_pool.with_ ~num_threads:4 () in + let@ () = Ws_pool.run_wait_block pool in let a1 = Fork_join.map_list ~chunk_size string_of_int l |> Array.of_list diff --git a/test/effect-based/t_fork_join_heavy.ml b/test/effect-based/t_fork_join_heavy.ml new file mode 100644 index 00000000..a981bee1 --- /dev/null +++ b/test/effect-based/t_fork_join_heavy.ml @@ -0,0 +1,57 @@ +[@@@ifge 5.0] + +module Q = QCheck + +let spf = Printf.sprintf +let ( let@ ) = ( @@ ) +let ppl = Q.Print.(list @@ list int) + +open! Moonpool + +let run ~min () = + let@ _sp = + Trace.with_span ~__FILE__ ~__LINE__ "run" ~data:(fun () -> + [ "min", `Int min ]) + in + + Printf.printf "run with min=%d\n%!" min; + let neg x = -x in + + let chunk_size = 100 in + let l = List.init 300 (fun _ -> List.init 15 (fun i -> i)) in + + let ref_l1 = List.map (List.map neg) l in + let ref_l2 = List.map (List.map neg) ref_l1 in + + for _i = 1 to 800 do + let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "step" in + + let l1, l2 = + let@ pool = Ws_pool.with_ ~num_threads:min () in + let@ () = Ws_pool.run_wait_block pool in + + let l1, l2 = + Fork_join.both + (fun () -> + Fork_join.map_list ~chunk_size + (Fork_join.map_list ~chunk_size neg) + l) + (fun () -> + Fork_join.map_list ~chunk_size + (Fork_join.map_list ~chunk_size neg) + ref_l1) + in + l1, l2 + in + + if l1 <> ref_l1 then failwith (spf "l1=%s, ref_l1=%s" (ppl l1) (ppl ref_l1)); + if l2 <> ref_l2 then failwith (spf "l1=%s, ref_l1=%s" (ppl l2) (ppl ref_l2)) + done + +let () = + let@ () = Trace_tef.with_setup () in + run ~min:4 (); + run ~min:1 (); + Printf.printf "done\n%!" + +[@@@endif] diff --git a/test/effect-based/t_futs1.ml b/test/effect-based/t_futs1.ml index be58f50b..4df18226 100644 --- a/test/effect-based/t_futs1.ml +++ b/test/effect-based/t_futs1.ml @@ -2,7 +2,7 @@ open! Moonpool -let pool = Pool.create ~min:4 () +let pool = Ws_pool.create ~num_threads:4 () let () = let fut = Array.init 10 (fun i -> Fut.spawn ~on:pool (fun () -> i)) in diff --git a/test/effect-based/t_many.ml b/test/effect-based/t_many.ml index c9cad3c6..b4a2c8da 100644 --- a/test/effect-based/t_many.ml +++ b/test/effect-based/t_many.ml @@ -2,9 +2,9 @@ open Moonpool -let pool = Pool.create ~min:4 () +let ( let@ ) = ( @@ ) -let run () = +let run ~pool () = let t1 = Unix.gettimeofday () in let n = 200_000 in @@ -14,20 +14,35 @@ let run () = Fut.spawn ~on:pool (fun () -> List.fold_left (fun n x -> - let _res = Fut.await x in + let _res = Sys.opaque_identity (Fut.await x) in n + 1) 0 l) in - let futs = - List.init n_tasks (fun _ -> Fut.spawn ~on:pool task |> Fut.join ~on:pool) - in + let futs = List.init n_tasks (fun _ -> Fut.spawn ~on:pool task |> Fut.join) in let lens = List.map Fut.wait_block_exn futs in Printf.printf "awaited %d items (%d times)\n%!" (List.hd lens) n_tasks; Printf.printf "in %.4fs\n%!" (Unix.gettimeofday () -. t1); assert (List.for_all (fun s -> s = n) lens) -let () = run () +let () = + (print_endline "with fifo"; + let@ pool = Fifo_pool.with_ ~num_threads:4 () in + run ~pool ()); + + (print_endline "with WS(1)"; + let@ pool = Ws_pool.with_ ~num_threads:1 () in + run ~pool ()); + + (print_endline "with WS(2)"; + let@ pool = Ws_pool.with_ ~num_threads:2 () in + run ~pool ()); + + (print_endline "with WS(4)"; + let@ pool = Ws_pool.with_ ~num_threads:4 () in + run ~pool ()); + + () [@@@endif] diff --git a/test/effect-based/t_sort.ml b/test/effect-based/t_sort.ml index a732c740..8ccc372f 100644 --- a/test/effect-based/t_sort.ml +++ b/test/effect-based/t_sort.ml @@ -59,7 +59,7 @@ let rec quicksort arr i len : unit = (fun () -> quicksort arr !low (len - (!low - i))) ) -let pool = Moonpool.Pool.create ~min:8 () +let pool = Moonpool.Ws_pool.create ~num_threads:8 () let () = let arr = Array.init 400_000 (fun _ -> Random.int 300_000) in diff --git a/test/t_bench1.ml b/test/t_bench1.ml index abf4a7f2..cd1a8bfd 100644 --- a/test/t_bench1.ml +++ b/test/t_bench1.ml @@ -8,7 +8,7 @@ let rec fib x = let run ~psize ~n ~j () : _ Fut.t = Printf.printf "pool size=%d, n=%d, j=%d\n%!" psize n j; - let pool = Pool.create ~min:psize ~per_domain:0 () in + let pool = Ws_pool.create ~num_threads:psize () in (* TODO: a ppx for tracy so we can use instrumentation *) let loop () = diff --git a/test/t_chan_train.ml b/test/t_chan_train.ml index 5d1c40ef..20645a73 100644 --- a/test/t_chan_train.ml +++ b/test/t_chan_train.ml @@ -1,9 +1,9 @@ open Moonpool (* large pool, some of our tasks below are long lived *) -let pool = Pool.create ~min:30 () +let pool = Ws_pool.create ~num_threads:30 () -open (val Fut.infix pool) +open Fut.Infix type event = | E_int of int diff --git a/test/t_fib.ml b/test/t_fib.ml index 38e3cb50..3fc53bf9 100644 --- a/test/t_fib.ml +++ b/test/t_fib.ml @@ -1,5 +1,12 @@ open Moonpool +let ( let@ ) = ( @@ ) + +let with_pool ~kind () f = + match kind with + | `Fifo_pool -> Fifo_pool.with_ ~num_threads:4 () f + | `Ws_pool -> Ws_pool.with_ ~num_threads:4 () f + let rec fib x = if x <= 1 then 1 @@ -8,11 +15,10 @@ let rec fib x = let () = assert (List.init 10 fib = [ 1; 1; 2; 3; 5; 8; 13; 21; 34; 55 ]) -let run_test () = - let pool = Pool.create ~min:4 () in +let run_test ~pool () = let fibs = Array.init 30 (fun n -> Fut.spawn ~on:pool (fun () -> fib n)) in let res = Fut.join_array fibs |> Fut.wait_block in - Pool.shutdown pool; + Ws_pool.shutdown pool; assert ( res @@ -50,11 +56,23 @@ let run_test () = 832040; |]) -let () = +let run ~kind () = for _i = 1 to 4 do - run_test () + let@ pool = with_pool ~kind () in + run_test ~pool () done; (* now make sure we can do this with multiple pools in parallel *) - let jobs = Array.init 4 (fun _ -> Thread.create run_test ()) in + let jobs = + Array.init 4 (fun _ -> + Thread.create + (fun () -> + let@ pool = with_pool ~kind () in + run_test ~pool ()) + ()) + in Array.iter Thread.join jobs + +let () = + run ~kind:`Ws_pool (); + run ~kind:`Fifo_pool () diff --git a/test/t_fib_rec.ml b/test/t_fib_rec.ml index b76fe875..3495fcae 100644 --- a/test/t_fib_rec.ml +++ b/test/t_fib_rec.ml @@ -1,4 +1,6 @@ -open Moonpool +open! Moonpool + +let ( let@ ) = ( @@ ) let rec fib_direct x = if x <= 1 then @@ -6,24 +8,32 @@ let rec fib_direct x = else fib_direct (x - 1) + fib_direct (x - 2) +let n_calls_fib_direct = Atomic.make 0 + let rec fib ~on x : int Fut.t = if x <= 18 then - Fut.spawn ~on (fun () -> fib_direct x) + Fut.spawn ~on (fun () -> + Atomic.incr n_calls_fib_direct; + fib_direct x) else - let open Fut.Infix_local in + let open Fut.Infix in let+ t1 = fib ~on (x - 1) and+ t2 = fib ~on (x - 2) in t1 + t2 let () = assert (List.init 10 fib_direct = [ 1; 1; 2; 3; 5; 8; 13; 21; 34; 55 ]) -let fib_40 : int = - let pool = Pool.create ~min:8 () in - let r = fib ~on:pool 40 |> Fut.wait_block_exn in - Pool.shutdown pool; - r +let fib_40 : int lazy_t = + lazy + (let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "fib40" in + let pool = Fifo_pool.create ~num_threads:8 () in + let r = fib ~on:pool 40 |> Fut.wait_block_exn in + Ws_pool.shutdown pool; + r) -let run_test () = - let pool = Pool.create ~min:8 () in +let run_test ~pool () = + let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "run-test" in + + let (lazy fib_40) = fib_40 in assert ( List.init 10 (fib ~on:pool) @@ -34,16 +44,42 @@ let run_test () = let fibs = Array.init n_fibs (fun _ -> fib ~on:pool 40) in let res = Fut.join_array fibs |> Fut.wait_block in - Pool.shutdown pool; assert (res = Ok (Array.make n_fibs fib_40)) +let run_test_size ~size () = + Printf.printf "test pool(%d)\n%!" size; + let@ pool = Ws_pool.with_ ~num_threads:size () in + run_test ~pool () + +let run_test_fifo ~size () = + Printf.printf "test fifo(%d)\n%!" size; + let@ pool = Fifo_pool.with_ ~num_threads:size () in + run_test ~pool () + +let setup_counter () = + if Trace.enabled () then + ignore + (Thread.create + (fun () -> + while true do + Thread.delay 0.01; + Trace.counter_int "n-fib-direct" (Atomic.get n_calls_fib_direct) + done) + () + : Thread.t) + let () = + let@ () = Trace_tef.with_setup () in + setup_counter (); + + let (lazy fib_40) = fib_40 in Printf.printf "fib 40 = %d\n%!" fib_40; - for _i = 1 to 2 do - run_test () - done; + + run_test_fifo ~size:4 (); + + List.iter (fun size -> run_test_size ~size ()) [ 1; 2; 4; 8 ]; (* now make sure we can do this with multiple pools in parallel *) - let jobs = Array.init 4 (fun _ -> Thread.create run_test ()) in + let jobs = Array.init 4 (fun _ -> Thread.create (run_test_size ~size:4) ()) in Array.iter Thread.join jobs diff --git a/test/t_futs1.ml b/test/t_futs1.ml index 930c8bdc..03a1ac13 100644 --- a/test/t_futs1.ml +++ b/test/t_futs1.ml @@ -1,7 +1,7 @@ open! Moonpool -let pool = Pool.create ~min:4 () -let pool2 = Pool.create ~min:2 () +let pool = Ws_pool.create ~num_threads:4 () +let pool2 = Ws_pool.create ~num_threads:2 () let () = let fut = Fut.return 1 in diff --git a/test/t_props.ml b/test/t_props.ml index ae6638ae..fe187073 100644 --- a/test/t_props.ml +++ b/test/t_props.ml @@ -1,49 +1,54 @@ module Q = QCheck open Moonpool +let ( let@ ) = ( @@ ) let tests = ref [] let add_test t = tests := t :: !tests -(* main pool *) -let pool = Pool.create ~min:4 ~per_domain:1 () - -(* pool for future combinators *) -let pool_fut = Pool.create ~min:2 () - -module Fut2 = (val Fut.infix pool_fut) +let with_pool ~kind () f = + match kind with + | `Fifo_pool -> Fifo_pool.with_ () f + | `Ws_pool -> Ws_pool.with_ () f let () = - add_test - @@ Q.Test.make ~name:"map then join_list" - Q.(small_list small_int) - (fun l -> - let l' = List.map (fun x -> Fut.spawn ~on:pool (fun () -> x + 1)) l in - let l' = Fut.join_list l' |> Fut.wait_block_exn in - if l' <> List.map succ l then Q.Test.fail_reportf "bad list"; - true) + add_test @@ fun ~kind -> + Q.Test.make ~name:"map then join_list" + Q.(small_list small_int) + (fun l -> + let@ pool = with_pool ~kind () in + let l' = List.map (fun x -> Fut.spawn ~on:pool (fun () -> x + 1)) l in + let l' = Fut.join_list l' |> Fut.wait_block_exn in + if l' <> List.map succ l then Q.Test.fail_reportf "bad list"; + true) let () = - add_test - @@ Q.Test.make ~name:"map bind" - Q.(small_list small_int) - (fun l -> - let open Fut2 in - let l' = - l - |> List.map (fun x -> - let* x = Fut.spawn ~on:pool_fut (fun () -> x + 1) in - let* y = Fut.return (x - 1) in - let+ z = Fut.spawn ~on:pool_fut (fun () -> string_of_int y) in - z) - in + add_test @@ fun ~kind -> + Q.Test.make ~name:"map bind" + Q.(small_list small_int) + (fun l -> + let@ pool = with_pool ~kind () in + let open Fut.Infix in + let l' = + l + |> List.map (fun x -> + let* x = Fut.spawn ~on:pool (fun () -> x + 1) in + let* y = Fut.return (x - 1) in + let+ z = Fut.spawn ~on:pool (fun () -> string_of_int y) in + z) + in - Fut.wait_list l' |> Fut.wait_block_exn; + Fut.wait_list l' |> Fut.wait_block_exn; - let l_res = List.map Fut.get_or_fail_exn l' in - if l_res <> List.map string_of_int l then - Q.Test.fail_reportf "bad list: from %s, to %s" - Q.Print.(list int l) - Q.Print.(list string l_res); - true) + let l_res = List.map Fut.get_or_fail_exn l' in + if l_res <> List.map string_of_int l then + Q.Test.fail_reportf "bad list: from %s, to %s" + Q.Print.(list int l) + Q.Print.(list string l_res); + true) -let () = QCheck_base_runner.run_tests_main !tests +let () = + let tests = + List.map (fun t -> [ t ~kind:`Fifo_pool; t ~kind:`Ws_pool ]) !tests + |> List.flatten + in + QCheck_base_runner.run_tests_main tests diff --git a/test/t_resource.ml b/test/t_resource.ml index 5845c520..4c20e9fb 100644 --- a/test/t_resource.ml +++ b/test/t_resource.ml @@ -2,15 +2,26 @@ open! Moonpool let ( let@ ) = ( @@ ) +let with_pool ~kind () f = + match kind with + | `Fifo_pool -> Fifo_pool.with_ () f + | `Ws_pool -> Ws_pool.with_ () f + (* test proper resource handling *) -let () = +let run ~kind () = + let@ () = Trace_tef.with_setup () in let a = Atomic.make 0 in for _i = 1 to 1_000 do + let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "loop.step" in (* give a chance to domains to die *) if _i mod 100 = 0 then Thread.delay 0.8; (* allocate a new pool at each iteration *) - let@ p = Pool.with_ ~min:4 () in - Pool.run_wait_block p (fun () -> Atomic.incr a) + let@ p = with_pool ~kind () in + Ws_pool.run_wait_block p (fun () -> Atomic.incr a) done; assert (Atomic.get a = 1_000) + +let () = + run ~kind:`Ws_pool (); + run ~kind:`Fifo_pool () diff --git a/test/t_tree_futs.ml b/test/t_tree_futs.ml index 5ebf2bff..c9905eae 100644 --- a/test/t_tree_futs.ml +++ b/test/t_tree_futs.ml @@ -2,6 +2,11 @@ open Moonpool let ( let@ ) = ( @@ ) +let with_pool ~kind ~j () f = + match kind with + | `Fifo_pool -> Fifo_pool.with_ ~num_threads:j () f + | `Ws_pool -> Ws_pool.with_ ~num_threads:j () f + type 'a tree = | Leaf of 'a | Node of 'a tree Fut.t * 'a tree Fut.t @@ -10,19 +15,16 @@ let rec mk_tree ~pool n : _ tree Fut.t = let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "mk-tree" in if n <= 1 then Fut.return (Leaf 1) - else - let open (val Fut.infix pool) in - let l = - Fut.spawn ~on:pool (fun () -> mk_tree ~pool (n - 1)) |> Fut.join ~on:pool - and r = - Fut.spawn ~on:pool (fun () -> mk_tree ~pool (n - 1)) |> Fut.join ~on:pool - in + else ( + let l = Fut.spawn ~on:pool (fun () -> mk_tree ~pool (n - 1)) |> Fut.join + and r = Fut.spawn ~on:pool (fun () -> mk_tree ~pool (n - 1)) |> Fut.join in Fut.return @@ Node (l, r) + ) let rec rev ~pool (t : 'a tree Fut.t) : 'a tree Fut.t = let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "rev" in - let open (val Fut.infix pool) in + let open Fut.Infix in t >>= function | Leaf n -> Fut.return (Leaf n) | Node (l, r) -> @@ -31,7 +33,7 @@ let rec rev ~pool (t : 'a tree Fut.t) : 'a tree Fut.t = let rec sum ~pool (t : int tree Fut.t) : int Fut.t = let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "sum" in - let open (val Fut.infix pool) in + let open Fut.Infix in t >>= function | Leaf n -> Fut.return n | Node (l, r) -> @@ -40,7 +42,7 @@ let rec sum ~pool (t : int tree Fut.t) : int Fut.t = let run ~pool n : (int * int) Fut.t = let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "run" in - let open (val Fut.infix pool) in + let open Fut.Infix in let t = Fut.return n >>= mk_tree ~pool in let t' = rev ~pool t in let sum_t = sum ~pool t in @@ -61,15 +63,13 @@ let stat_thread () = done) () -let () = - (* - Tracy_client_trace.setup (); - *) +let run_main ~kind () = + let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "run_main" in let start = Unix.gettimeofday () in let n = try int_of_string (Sys.getenv "N") with _ -> default_n in let j = try int_of_string (Sys.getenv "J") with _ -> 4 in - let pool = Pool.create ~min:j () in + let@ pool = with_pool ~kind ~j () in ignore (stat_thread () : Thread.t); Printf.printf "n=%d, j=%d\n%!" n j; @@ -79,3 +79,11 @@ let () = assert (n1 = 1 lsl (n - 1)); assert (n2 = 1 lsl (n - 1)); () + +let () = + let@ () = Trace_tef.with_setup () in + (* + Tracy_client_trace.setup (); + *) + run_main ~kind:`Ws_pool (); + run_main ~kind:`Fifo_pool () diff --git a/test/t_unfair.ml b/test/t_unfair.ml new file mode 100644 index 00000000..cee4373e --- /dev/null +++ b/test/t_unfair.ml @@ -0,0 +1,52 @@ +(* exhibits unfairness *) + +open Moonpool + +let ( let@ ) = ( @@ ) + +let sleep_for f () = + let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "sleep" in + Thread.delay f + +let run ~kind () = + let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "run" in + + let pool = + let on_init_thread ~dom_id:_ ~t_id () = + Trace.set_thread_name (Printf.sprintf "pool worker %d" t_id) + and around_task = + ( (fun self -> Trace.counter_int "n_tasks" (Ws_pool.num_tasks self)), + fun self () -> Trace.counter_int "n_tasks" (Ws_pool.num_tasks self) ) + in + + match kind with + | `Simple -> Fifo_pool.create ~num_threads:3 ~on_init_thread ~around_task () + | `Ws_pool -> Ws_pool.create ~num_threads:3 ~on_init_thread ~around_task () + in + + (* make all threads busy *) + Ws_pool.run_async pool (sleep_for 0.01); + Ws_pool.run_async pool (sleep_for 0.01); + Ws_pool.run_async pool (sleep_for 0.05); + + let t = Unix.gettimeofday () in + for _i = 1 to 100 do + let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "schedule step" in + Ws_pool.run_async pool (sleep_for 0.001); + Ws_pool.run_async pool (sleep_for 0.001); + Ws_pool.run_async pool (sleep_for 0.01) + done; + + Printf.printf "pool size: %d\n%!" (Ws_pool.num_tasks pool); + (let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "shutdown" in + Ws_pool.shutdown pool); + Printf.printf "pool size after shutdown: %d\n%!" (Ws_pool.num_tasks pool); + + let elapsed = Unix.gettimeofday () -. t in + Printf.printf "elapsed: %.4fs\n%!" elapsed + +let () = + let@ () = Trace_tef.with_setup () in + let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "main" in + run ~kind:`Simple (); + run ~kind:`Ws_pool () diff --git a/test/t_ws_deque.ml b/test/t_ws_deque.ml new file mode 100644 index 00000000..88429a8d --- /dev/null +++ b/test/t_ws_deque.ml @@ -0,0 +1,102 @@ +module A = Moonpool.Atomic +module D = Moonpool.Private.Ws_deque_ + +let ( let@ ) = ( @@ ) +let dummy = -100 + +let t_simple () = + let d = D.create ~dummy () in + assert (D.steal d = None); + assert (D.pop d = None); + assert (D.push d 1); + assert (D.push d 2); + assert (D.pop d = Some 2); + assert (D.steal d = Some 1); + assert (D.steal d = None); + assert (D.pop d = None); + assert (D.push d 3); + assert (D.pop d = Some 3); + assert (D.push d 4); + assert (D.push d 5); + assert (D.push d 6); + assert (D.steal d = Some 4); + assert (D.steal d = Some 5); + assert (D.pop d = Some 6); + assert (D.pop d = None); + + Printf.printf "basic tests passed\n"; + () + +(* big heavy test *) +let t_heavy () = + let sum = A.make 0 in + let ref_sum = ref 0 in + + let[@inline] add_to_sum x = ignore (A.fetch_and_add sum x : int) in + + let active = A.make true in + + let d = D.create ~dummy () in + + let stealer_loop () = + Trace.set_thread_name "stealer"; + let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "stealer" in + while A.get active do + match D.steal d with + | None -> Thread.yield () + | Some x -> add_to_sum x + done + in + + let main_loop () = + Trace.set_thread_name "producer"; + for _i = 1 to 100_000 do + let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "main.outer" in + + (* NOTE: we make sure to push less than 256 elements at once *) + for j = 1 to 100 do + ref_sum := !ref_sum + j; + assert (D.push d j); + ref_sum := !ref_sum + j; + assert (D.push d j); + + Option.iter (fun x -> add_to_sum x) (D.pop d); + Option.iter (fun x -> add_to_sum x) (D.pop d) + done; + + (* now compete with stealers to pop *) + let continue = ref true in + while !continue do + match D.pop d with + | Some x -> add_to_sum x + | None -> continue := false + done + done + in + + let ts = + Array.init 6 (fun _ -> Moonpool.start_thread_on_some_domain stealer_loop ()) + in + let t = Moonpool.start_thread_on_some_domain main_loop () in + + (* stop *) + A.set active false; + + Trace.message "joining t"; + Thread.join t; + Trace.message "joining stealers"; + Array.iter Thread.join ts; + Trace.message "done"; + + let ref_sum = !ref_sum in + let sum = A.get sum in + + Printf.printf "ref sum = %d, sum = %d\n%!" ref_sum sum; + assert (ref_sum = sum); + () + +let () = + let@ () = Trace_tef.with_setup () in + t_simple (); + t_heavy (); + () diff --git a/test/t_ws_pool_confusion.ml b/test/t_ws_pool_confusion.ml new file mode 100644 index 00000000..20488b65 --- /dev/null +++ b/test/t_ws_pool_confusion.ml @@ -0,0 +1,28 @@ +open Moonpool + +let delay () = Thread.delay 0.001 + +let run ~p_main:_ ~p_sub () = + let f1 = + Fut.spawn ~on:p_sub (fun () -> + delay (); + 1) + in + let f2 = + Fut.spawn ~on:p_sub (fun () -> + delay (); + 2) + in + Fut.wait_block_exn f1 + Fut.wait_block_exn f2 + +let () = + let p_main = Ws_pool.create ~num_threads:2 () in + let p_sub = Ws_pool.create ~num_threads:10 () in + + let futs = List.init 8 (fun _ -> Fut.spawn ~on:p_main (run ~p_main ~p_sub)) in + + let l = List.map Fut.wait_block_exn futs in + assert (l = List.init 8 (fun _ -> 3)); + + print_endline "ok"; + ()