diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index dbe2ae40..e54322e6 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -31,8 +31,10 @@ jobs: dune-cache: true - run: opam install -t moonpool --deps-only - - - run: opam exec -- dune build '@install' - + - run: opam exec -- dune build @install - run: opam exec -- dune runtest + - run: opam install domain-local-await + if: matrix.ocaml-compiler == '5.0' + - run: opam exec -- dune build @install @runtest + if: matrix.ocaml-compiler == '5.0' diff --git a/README.md b/README.md index 322ddb22..dbcc9867 100644 --- a/README.md +++ b/README.md @@ -19,10 +19,55 @@ The user can create several thread pools. These pools use regular posix threads, but the threads are spread across multiple domains (on OCaml 5), which enables parallelism. +The function `Pool.run pool task` runs `task()` on one of the workers +of `pool`, as soon as one is available. No result is returned. + ```ocaml +# #require "threads";; # let pool = Moonpool.Pool.create ~min:4 ();; val pool : Moonpool.Pool.t = +# begin + Moonpool.Pool.run pool + (fun () -> + Thread.delay 0.1; + print_endline "running from the pool"); + print_endline "running from the caller"; + Thread.delay 0.3; (* wait for task to run before returning *) + end ;; +running from the caller +running from the pool +- : unit = () +``` + +The function `Fut.spawn ~on f` schedules `f ()` on the pool `on`, and immediately +returns a _future_ which will eventually hold the result (or an exception). + +The function `Fut.peek` will return the current value, or `None` if the future is +still not completed. +The functions `Fut.wait_block` and `Fut.wait_block_exn` will +block the current thread and wait for the future to complete. +There are some deadlock risks associated with careless use of these, so +be sure to consult the documentation of the `Fut` module. + +```ocaml +# let fut = Moonpool.Fut.spawn ~on:pool + (fun () -> + Thread.delay 0.5; + 1+1);; +val fut : int Moonpool.Fut.t = + +# Moonpool.Fut.peek fut; +- : int Moonpool.Fut.or_error option = None + +# Moonpool.Fut.wait_block_exn fut;; +- : int = 2 +``` + +Some combinators on futures are also provided, e.g. to wait for all futures in +an array to complete: + +```ocaml # let rec fib x = if x <= 1 then 1 else fib (x-1) + fib (x-2);; val fib : int -> int = @@ -46,6 +91,42 @@ Ok 514229; 832040; 1346269; 2178309; 3524578; 5702887; 9227465|] ``` +### Support for `await` + +On OCaml 5, effect handlers can be used to implement `Fut.await : 'a Fut.t -> 'a`. + +The expression `Fut.await some_fut`, when run from inside some thread pool, +suspends its caller task; the suspended task is then parked, and will +be resumed when the future is completed. +The pool worker that was executing this expression, in the mean time, moves +on to another task. +This means that `await` is free of the deadlock risks associated with +`Fut.wait_block`. + +In the following example, we bypass the need for `Fut.join_array` by simply +using regular array functions along with `Fut.await`. + +```ocaml +# let main_fut = + let open Moonpool.Fut in + spawn ~on:pool @@ fun () -> + (* array of sub-futures *) + let tasks: _ Moonpool.Fut.t array = Array.init 100 (fun i -> + spawn ~on:pool (fun () -> + Thread.delay 0.01; + i+1)) + in + Array.fold_left (fun n fut -> n + await fut) 0 tasks + ;; +val main_fut : int Moonpool.Fut.t = + +# let expected_sum = Array.init 100 (fun i->i+1) |> Array.fold_left (+) 0;; +val expected_sum : int = 5050 + +# assert (expected_sum = Moonpool.Fut.wait_block_exn main_fut);; +- : unit = () +``` + ### More intuition To quote [gasche](https://discuss.ocaml.org/t/ann-moonpool-0-1/12387/15): diff --git a/dune b/dune index 0ac574ef..6ea92b3a 100644 --- a/dune +++ b/dune @@ -2,4 +2,5 @@ (env (_ (flags :standard -strict-sequence -warn-error -a+8 -w +a-4-40-42-70))) -(mdx (libraries moonpool)) +(mdx (libraries moonpool) + (enabled_if (>= %{ocaml_version} 5.0))) diff --git a/src/dune b/src/dune index ef116a05..d9b0d84a 100644 --- a/src/dune +++ b/src/dune @@ -9,10 +9,15 @@ (action (with-stdout-to %{targets} (run ./gen/gen.exe --ocaml %{ocaml_version} --atomic)))) - + (rule (targets domain_.ml) (action (with-stdout-to %{targets} (run ./gen/gen.exe --ocaml %{ocaml_version} --domain)))) - + +(rule + (targets suspend_.ml) + (action + (with-stdout-to %{targets} + (run ./gen/gen.exe --ocaml %{ocaml_version} --suspend)))) diff --git a/src/fut.ml b/src/fut.ml index 846b78da..db2d24d7 100644 --- a/src/fut.ml +++ b/src/fut.ml @@ -354,6 +354,28 @@ let wait_block_exn self = | Ok x -> x | Error (e, bt) -> Printexc.raise_with_backtrace e bt +let await (fut : 'a t) : 'a = + match peek fut with + | Some res -> + (* fast path: peek *) + (match res with + | Ok x -> x + | Error (exn, bt) -> Printexc.raise_with_backtrace exn bt) + | None -> + (* suspend until the future is resolved *) + Suspend_.suspend + { + Suspend_types_.handle = + (fun ~run k -> + on_result fut (function + | Ok _ -> run (fun () -> k (Ok ())) + | Error (exn, bt) -> + (* fail continuation immediately *) + k (Error (exn, bt)))); + }; + (* un-suspended: we should have a result! *) + get_or_fail_exn fut + module type INFIX = sig val ( >|= ) : 'a t -> ('a -> 'b) -> 'b t val ( >>= ) : 'a t -> ('a -> 'b t) -> 'b t diff --git a/src/fut.mli b/src/fut.mli index 99dcecb3..b14893f4 100644 --- a/src/fut.mli +++ b/src/fut.mli @@ -145,6 +145,17 @@ val for_list : on:Pool.t -> 'a list -> ('a -> unit) -> unit t (** [for_list ~on l f] is like [for_array ~on (Array.of_list l) f]. @since 0.2 *) +(** {2 Await} + +This is only available on OCaml 5. *) + +val await : 'a t -> 'a +(** [await fut] suspends the current tasks until [fut] is fulfilled, then + resumes the task on this same pool. + This must only be run from inside the pool itself. + @since 0.3 + {b NOTE}: only on OCaml 5 *) + (** {2 Blocking} *) val wait_block : 'a t -> 'a or_error diff --git a/src/gen/gen.ml b/src/gen/gen.ml index e7a85e89..3128577f 100644 --- a/src/gen/gen.ml +++ b/src/gen/gen.ml @@ -72,16 +72,56 @@ let spawn : _ -> t = Domain.spawn let relax = Domain.cpu_relax |} +let suspend_pre_5 = + {| +open Suspend_types_ +let suspend _ = failwith "Thread suspension is only available on OCaml >= 5.0" +let with_suspend ~run:_ f : unit = f() +|} + +let suspend_post_5 = + {| +open Suspend_types_ + +type _ Effect.t += + | Suspend : suspension_handler -> unit Effect.t + +let[@inline] suspend h = Effect.perform (Suspend h) + +let with_suspend ~(run:task -> unit) (f: unit -> unit) : unit = + let module E = Effect.Deep in + + (* effect handler *) + let effc + : type e. e Effect.t -> ((e, _) E.continuation -> _) option + = function + | Suspend h -> + Some (fun k -> + let k': suspension = function + | Ok () -> E.continue k () + | Error (exn, bt) -> + E.discontinue_with_backtrace k exn bt + in + h.handle ~run k' + ) + | _ -> None + in + + E.try_with f () {E.effc} +|} + let p_version s = Scanf.sscanf s "%d.%d" (fun x y -> x, y) let () = let atomic = ref false in let domain = ref false in + let suspend = ref false in let ocaml = ref Sys.ocaml_version in Arg.parse [ "--atomic", Arg.Set atomic, " atomic"; "--domain", Arg.Set domain, " domain"; + "--suspend", Arg.Set suspend, " suspend"; "--ocaml", Arg.Set_string ocaml, " set ocaml version"; ] ignore ""; @@ -104,4 +144,12 @@ let () = domain_post_5 in print_endline code + ) else if !suspend then ( + let code = + if (major, minor) < (5, 0) then + suspend_pre_5 + else + suspend_post_5 + in + print_endline code ) diff --git a/src/pool.ml b/src/pool.ml index eb2e9366..81aefce1 100644 --- a/src/pool.ml +++ b/src/pool.ml @@ -26,12 +26,13 @@ let add_global_thread_loop_wrapper f : unit = exception Shutdown -let run (self : t) (f : task) : unit = +(** Run [task] as is, on the pool. *) +let run_direct_ (self : t) (task : task) : unit = let n_qs = Array.length self.qs in let offset = A.fetch_and_add self.cur_q 1 in (* blocking push, last resort *) - let push_wait () = + let[@inline] push_wait f = let q_idx = offset mod Array.length self.qs in let q = self.qs.(q_idx) in Bb_queue.push q f @@ -43,14 +44,23 @@ let run (self : t) (f : task) : unit = for i = 0 to n_qs - 1 do let q_idx = (i + offset) mod Array.length self.qs in let q = self.qs.(q_idx) in - if Bb_queue.try_push q f then raise_notrace Exit + if Bb_queue.try_push q task then raise_notrace Exit done done; - push_wait () + push_wait task with | Exit -> () | Bb_queue.Closed -> raise Shutdown +(** Run [task]. It will be wrapped with an effect handler to + support {!Fut.await}. *) +let run (self : t) (task : task) : unit = + let task' () = + (* run [f()] and handle [suspend] in it *) + Suspend_.with_suspend task ~run:(run_direct_ self) + in + run_direct_ self task' + let[@inline] size self = Array.length self.threads let num_tasks (self : t) : int = diff --git a/src/suspend_.mli b/src/suspend_.mli new file mode 100644 index 00000000..5247f597 --- /dev/null +++ b/src/suspend_.mli @@ -0,0 +1,18 @@ +(** (Private) suspending tasks using Effects. + + This module is an implementation detail of Moonpool and should + not be used outside of it. *) + +open Suspend_types_ + +val suspend : suspension_handler -> unit +(** [suspend h] jumps back to the nearest {!with_suspend} + and calls [h.handle] with the current continuation [k] + and a task runner function. +*) + +val with_suspend : run:(task -> unit) -> (unit -> unit) -> unit +(** [with_suspend ~run f] runs [f()] in an environment where [suspend] + will work. If [f()] suspends with suspension handler [h], + this calls [h ~run k] where [k] is the suspension. +*) diff --git a/src/suspend_types_.ml b/src/suspend_types_.ml new file mode 100644 index 00000000..22fb9eff --- /dev/null +++ b/src/suspend_types_.ml @@ -0,0 +1,13 @@ +(** (Private) types for {!Suspend_}. + + This module is an implementation detail of Moonpool and should + not be used outside of it. *) + +type suspension = (unit, exn * Printexc.raw_backtrace) result -> unit +(** A suspended computation *) + +type task = unit -> unit + +type suspension_handler = { handle: run:(task -> unit) -> suspension -> unit } +[@@unboxed] +(** The handler that knows what to do with the suspended computation *) diff --git a/test/await/dune b/test/await/dune new file mode 100644 index 00000000..affbb54e --- /dev/null +++ b/test/await/dune @@ -0,0 +1,6 @@ + +(tests + (names t_fib1 t_futs1 t_many) + (enabled_if (>= %{ocaml_version} 5.0)) + (libraries moonpool trace ;tracy-client.trace + )) diff --git a/test/await/t_fib1.ml b/test/await/t_fib1.ml new file mode 100644 index 00000000..08a88384 --- /dev/null +++ b/test/await/t_fib1.ml @@ -0,0 +1,50 @@ +open Moonpool + +let rec fib_direct x = + if x <= 1 then + 1 + else + fib_direct (x - 1) + fib_direct (x - 2) + +let fib ~on x : int Fut.t = + let rec fib_rec x : int = + if x <= 18 then + fib_direct x + else ( + let t1 = Fut.spawn ~on (fun () -> fib_rec (x - 1)) + and t2 = Fut.spawn ~on (fun () -> fib_rec (x - 2)) in + Fut.await t1 + Fut.await t2 + ) + in + Fut.spawn ~on (fun () -> fib_rec x) + +(* NOTE: for tracy support + let () = Tracy_client_trace.setup () +*) +let () = assert (List.init 10 fib_direct = [ 1; 1; 2; 3; 5; 8; 13; 21; 34; 55 ]) + +let fib_40 : int = + let pool = Pool.create ~min:8 () in + fib ~on:pool 40 |> Fut.wait_block_exn + +let () = Printf.printf "fib 40 = %d\n%!" fib_40 + +let run_test () = + let pool = Pool.create ~min:8 () in + + assert ( + List.init 10 (fib ~on:pool) + |> Fut.join_list |> Fut.wait_block_exn + = [ 1; 1; 2; 3; 5; 8; 13; 21; 34; 55 ]); + + let fibs = Array.init 3 (fun _ -> fib ~on:pool 40) in + + let res = Fut.join_array fibs |> Fut.wait_block in + Pool.shutdown pool; + + assert (res = Ok (Array.make 3 fib_40)) + +let () = + (* now make sure we can do this with multiple pools in parallel *) + let jobs = Array.init 2 (fun _ -> Thread.create run_test ()) in + Array.iter Thread.join jobs diff --git a/test/await/t_futs1.ml b/test/await/t_futs1.ml new file mode 100644 index 00000000..aa974f45 --- /dev/null +++ b/test/await/t_futs1.ml @@ -0,0 +1,53 @@ +open! Moonpool + +let pool = Pool.create ~min:4 () + +let () = + let fut = Array.init 10 (fun i -> Fut.spawn ~on:pool (fun () -> i)) in + let fut2 = Fut.spawn ~on:pool (fun () -> Array.map Fut.await fut) in + assert (Fut.wait_block fut2 = Ok (Array.init 10 (fun x -> x))) + +let () = + let fut = + Array.init 10 (fun i -> + Fut.spawn ~on:pool (fun () -> + if i < 9 then + i + else + raise Not_found)) + in + let fut2 = Fut.spawn ~on:pool (fun () -> Array.map Fut.await fut) in + (* must fail *) + assert (Fut.wait_block fut2 |> Result.is_error) + +let mk_ret_delay ?(on = pool) n x = + Fut.spawn ~on (fun () -> + Thread.delay n; + x) + +let () = + let f1 = mk_ret_delay 0.01 1 in + let f2 = mk_ret_delay 0.01 2 in + let fut = Fut.spawn ~on:pool (fun () -> Fut.await f1, Fut.await f2) in + assert (Fut.wait_block_exn fut = (1, 2)) + +let () = + let f1 = + let f = + Fut.spawn ~on:pool (fun () -> + Thread.delay 0.01; + 1) + in + Fut.spawn ~on:pool (fun () -> Fut.await f + 1) + and f2 = + let f = + Fut.spawn ~on:pool (fun () -> + Thread.delay 0.01; + 10) + in + Fut.spawn ~on:pool (fun () -> + Thread.delay 0.01; + Fut.await f * 2) + in + let fut = Fut.both f1 f2 in + assert (Fut.wait_block fut = Ok (2, 20)) diff --git a/test/await/t_many.ml b/test/await/t_many.ml new file mode 100644 index 00000000..7b29ae16 --- /dev/null +++ b/test/await/t_many.ml @@ -0,0 +1,29 @@ +open Moonpool + +let pool = Pool.create ~min:4 () + +let run () = + let t1 = Unix.gettimeofday () in + + let n = 1_000_000 in + let n_tasks = 3 in + let task () = + let l = List.init n (fun x -> Fut.spawn ~on:pool (fun () -> x)) in + Fut.spawn ~on:pool (fun () -> + List.fold_left + (fun n x -> + let _res = Fut.await x in + n + 1) + 0 l) + in + + let futs = + List.init n_tasks (fun _ -> Fut.spawn ~on:pool task |> Fut.join ~on:pool) + in + + let lens = List.map Fut.wait_block_exn futs in + Printf.printf "awaited %d items (%d times)\n%!" (List.hd lens) n_tasks; + Printf.printf "in %.4fs\n%!" (Unix.gettimeofday () -. t1); + assert (List.for_all (fun s -> s = n) lens) + +let () = run () diff --git a/test/dune b/test/dune index 5b210cad..50a73317 100644 --- a/test/dune +++ b/test/dune @@ -1,5 +1,5 @@ (tests - (names t_fib t_bench1 t_fib_rec t_futs1 t_tree_futs t_props) + (names t_fib t_bench1 t_fib_rec t_futs1 t_tree_futs t_props t_chan_train) (libraries moonpool qcheck-core qcheck-core.runner ;tracy-client.trace trace)) diff --git a/test/t_chan_train.ml b/test/t_chan_train.ml new file mode 100644 index 00000000..5d1c40ef --- /dev/null +++ b/test/t_chan_train.ml @@ -0,0 +1,95 @@ +open Moonpool + +(* large pool, some of our tasks below are long lived *) +let pool = Pool.create ~min:30 () + +open (val Fut.infix pool) + +type event = + | E_int of int + | E_close + +let mk_chan (ic : event Chan.t) : event Chan.t = + let out = Chan.create () in + + let rec loop () = + let* ev = Chan.pop ic in + Chan.push out ev; + match ev with + | E_close -> Fut.return () + | E_int _x -> loop () + in + + ignore (Fut.spawn ~on:pool loop : _ Fut.t); + out + +(* a train of channels connected to one another, with a + loop pushing events from the input to the output *) +let rec mk_train n ic : _ Chan.t = + if n = 0 then + ic + else ( + let c = mk_chan ic in + mk_train (n - 1) c + ) + +let run () = + let start = Unix.gettimeofday () in + + let n_trains = 4 in + let len_train = 80 in + let n_events = 1_000 in + let range = 5 in + + (* start trains *) + let trains = + List.init n_trains (fun _ -> + let c = Chan.create () in + let out = mk_train len_train c in + c, out) + in + + let pushers = + List.map + (fun (ic, _oc) -> + Fut.spawn ~on:pool (fun () -> + for i = 1 to n_events do + Chan.push ic (E_int (i mod range)) + done; + Chan.push ic E_close)) + trains + in + + let gatherers = + List.map + (fun (_ic, oc) -> + let sum = ref 0 in + try + while true do + match Chan.pop_block_exn oc with + | E_close -> raise Exit + | E_int x -> sum := !sum + x + done; + assert false + with Exit -> !sum) + trains + in + + Fut.wait_block_exn (Fut.wait_list pushers); + + let expected_sum = + let sum = ref 0 in + for i = 1 to n_events do + sum := !sum + (i mod range) + done; + !sum + in + + Printf.printf "got %d events in %d trains (len=%d) in %.2fs\n%!" n_events + n_trains len_train + (Unix.gettimeofday () -. start); + + assert (gatherers = List.init n_trains (fun _ -> expected_sum)); + () + +let () = run ()