diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 91d18141..b7f5dbb7 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -35,10 +35,8 @@ jobs: - run: opam install -t moonpool --deps-only - run: opam exec -- dune build @install - run: opam exec -- dune runtest - - run: opam install domain-local-await - if: matrix.ocaml-compiler == '5.0' + - run: opam install thread-local-storage trace - run: opam exec -- dune build @install @runtest - if: matrix.ocaml-compiler == '5.0' - run: opam install trace thread-local-storage - run: opam exec -- dune build @install diff --git a/README.md b/README.md index 4353a757..c51361df 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,8 @@ In addition, some concurrency and parallelism primitives are provided: On OCaml 5 (meaning there's actual domains and effects, not just threads), a `Fut.await` primitive is provided. It's simpler and more powerful than the monadic combinators. -- `Moonpool.Fork_join` provides the fork-join parallelism primitives +- `Moonpool_forkjoin`, in the library `moonpool.forkjoin` + provides the fork-join parallelism primitives to use within tasks running in the pool. ## Usage @@ -166,7 +167,8 @@ val expected_sum : int = 5050 ### Fork-join -On OCaml 5, again using effect handlers, the module `Fork_join` +On OCaml 5, again using effect handlers, the sublibrary `moonpool.forkjoin` +provides a module `Moonpool_forkjoin` implements the [fork-join model](https://en.wikipedia.org/wiki/Fork%E2%80%93join_model). It must run on a pool (using `Runner.run_async` or inside a future via `Fut.spawn`). @@ -220,7 +222,7 @@ And a parallel quicksort for larger slices: done; (* sort lower half and upper half in parallel *) - Moonpool.Fork_join.both_ignore + Moonpool_forkjoin.both_ignore (fun () -> quicksort arr i (!low - i)) (fun () -> quicksort arr !low (len - (!low - i))) );; diff --git a/benchs/dune b/benchs/dune index ff0f878b..14393230 100644 --- a/benchs/dune +++ b/benchs/dune @@ -1,6 +1,6 @@ - (executables (names fib_rec pi) - (preprocess (action + (preprocess + (action (run %{project_root}/src/cpp/cpp.exe %{input-file}))) - (libraries moonpool unix trace trace-tef domainslib)) + (libraries moonpool moonpool.forkjoin unix trace trace-tef domainslib)) diff --git a/benchs/fib_rec.ml b/benchs/fib_rec.ml index 25291e8c..66eded93 100644 --- a/benchs/fib_rec.ml +++ b/benchs/fib_rec.ml @@ -1,5 +1,6 @@ open Moonpool module Trace = Trace_core +module FJ = Moonpool_forkjoin let ( let@ ) = ( @@ ) @@ -25,7 +26,7 @@ let fib_fj ~on x : int Fut.t = fib_direct x else ( let n1, n2 = - Fork_join.both (fun () -> fib_rec (x - 1)) (fun () -> fib_rec (x - 2)) + FJ.both (fun () -> fib_rec (x - 1)) (fun () -> fib_rec (x - 2)) in n1 + n2 ) diff --git a/benchs/pi.ml b/benchs/pi.ml index 7e0dfd91..4eae7eb0 100644 --- a/benchs/pi.ml +++ b/benchs/pi.ml @@ -1,6 +1,7 @@ (* compute Pi *) open Moonpool +module FJ = Moonpool_forkjoin let ( let@ ) = ( @@ ) let j = ref 0 @@ -76,7 +77,7 @@ let run_fork_join ~kind num_steps : float = let global_sum = Lock.create 0. in Ws_pool.run_wait_block ~name:"pi.fj" pool (fun () -> - Fork_join.for_ + FJ.for_ ~chunk_size:(3 + (num_steps / num_tasks)) num_steps (fun low high -> diff --git a/dune b/dune index 00264a6c..32ba6647 100644 --- a/dune +++ b/dune @@ -1,6 +1,8 @@ - (env - (_ (flags :standard -strict-sequence -warn-error -a+8 -w +a-4-40-42-70))) + (_ + (flags :standard -strict-sequence -warn-error -a+8 -w +a-4-40-42-70))) -(mdx (libraries moonpool threads) - (enabled_if (>= %{ocaml_version} 5.0))) +(mdx + (libraries moonpool moonpool.forkjoin threads) + (enabled_if + (>= %{ocaml_version} 5.0))) diff --git a/dune-project b/dune-project index b4e69e84..55cb93f1 100644 --- a/dune-project +++ b/dune-project @@ -29,8 +29,7 @@ :with-test))) (depopts (trace (>= 0.6)) - thread-local-storage - (domain-local-await (>= 0.2))) + thread-local-storage) (tags (thread pool domain futures fork-join))) diff --git a/moonpool.opam b/moonpool.opam index 6ff7c8b0..c8afba80 100644 --- a/moonpool.opam +++ b/moonpool.opam @@ -21,7 +21,6 @@ depends: [ depopts: [ "trace" {>= "0.6"} "thread-local-storage" - "domain-local-await" {>= "0.2"} ] build: [ ["dune" "subst"] {dev} diff --git a/src/bb_queue.ml b/src/core/bb_queue.ml similarity index 100% rename from src/bb_queue.ml rename to src/core/bb_queue.ml diff --git a/src/bb_queue.mli b/src/core/bb_queue.mli similarity index 100% rename from src/bb_queue.mli rename to src/core/bb_queue.mli diff --git a/src/bounded_queue.ml b/src/core/bounded_queue.ml similarity index 100% rename from src/bounded_queue.ml rename to src/core/bounded_queue.ml diff --git a/src/bounded_queue.mli b/src/core/bounded_queue.mli similarity index 100% rename from src/bounded_queue.mli rename to src/core/bounded_queue.mli diff --git a/src/chan.ml b/src/core/chan.ml similarity index 100% rename from src/chan.ml rename to src/core/chan.ml diff --git a/src/chan.mli b/src/core/chan.mli similarity index 100% rename from src/chan.mli rename to src/core/chan.mli diff --git a/src/d_pool_.ml b/src/core/domain_pool_.ml similarity index 99% rename from src/d_pool_.ml rename to src/core/domain_pool_.ml index d12a4f6a..31f11d26 100644 --- a/src/d_pool_.ml +++ b/src/core/domain_pool_.ml @@ -33,8 +33,6 @@ let domains_ : (worker_state option * Domain_.t option) Lock.t array = in a tight loop), and if nothing happens it tries to stop to free resources. *) let work_ idx (st : worker_state) : unit = - Dla_.setup_domain (); - let main_loop () = let continue = ref true in while !continue do diff --git a/src/d_pool_.mli b/src/core/domain_pool_.mli similarity index 100% rename from src/d_pool_.mli rename to src/core/domain_pool_.mli diff --git a/src/core/dune b/src/core/dune new file mode 100644 index 00000000..ff084a49 --- /dev/null +++ b/src/core/dune @@ -0,0 +1,9 @@ +(library + (public_name moonpool) + (name moonpool) + (libraries moonpool.private) + (flags :standard -open Moonpool_private) + (private_modules types_ domain_pool_ util_pool_) + (preprocess + (action + (run %{project_root}/src/cpp/cpp.exe %{input-file})))) diff --git a/src/core/exn_bt.ml b/src/core/exn_bt.ml new file mode 100644 index 00000000..b69f6614 --- /dev/null +++ b/src/core/exn_bt.ml @@ -0,0 +1,18 @@ +type t = exn * Printexc.raw_backtrace + +let[@inline] make exn bt : t = exn, bt +let[@inline] exn (e, _) = e +let[@inline] bt (_, bt) = bt + +let[@inline] get exn = + let bt = Printexc.get_raw_backtrace () in + make exn bt + +let[@inline] get_callstack n exn = + let bt = Printexc.get_callstack n in + make exn bt + +let show self = Printexc.to_string (fst self) +let[@inline] raise self = Printexc.raise_with_backtrace (exn self) (bt self) + +type nonrec 'a result = ('a, t) result diff --git a/src/core/exn_bt.mli b/src/core/exn_bt.mli new file mode 100644 index 00000000..becfbf3b --- /dev/null +++ b/src/core/exn_bt.mli @@ -0,0 +1,25 @@ +(** Exception with backtrace. + + @since NEXT_RELEASE *) + +type t = exn * Printexc.raw_backtrace +(** An exception bundled with a backtrace *) + +val exn : t -> exn +val bt : t -> Printexc.raw_backtrace + +val make : exn -> Printexc.raw_backtrace -> t +(** Trivial builder *) + +val get : exn -> t +(** [get exn] is [make exn (get_raw_backtrace ())] *) + +val get_callstack : int -> exn -> t + +val raise : t -> 'a +(** Raise the exception with its save backtrace *) + +val show : t -> string +(** Simple printing *) + +type nonrec 'a result = ('a, t) result diff --git a/src/fifo_pool.ml b/src/core/fifo_pool.ml similarity index 78% rename from src/fifo_pool.ml rename to src/core/fifo_pool.ml index a4f03116..d2757324 100644 --- a/src/fifo_pool.ml +++ b/src/core/fifo_pool.ml @@ -1,16 +1,18 @@ -module TLS = Thread_local_storage_ +open Types_ include Runner let ( let@ ) = ( @@ ) +let k_storage = Task_local_storage.Private_.Storage.k_storage -type task_with_name = { +type task_full = { f: unit -> unit; name: string; + ls: Task_local_storage.storage; } type state = { threads: Thread.t array; - q: task_with_name Bb_queue.t; (** Queue for tasks. *) + q: task_full Bb_queue.t; (** Queue for tasks. *) } (** internal state *) @@ -18,13 +20,16 @@ let[@inline] size_ (self : state) = Array.length self.threads let[@inline] num_tasks_ (self : state) : int = Bb_queue.size self.q (** Run [task] as is, on the pool. *) -let schedule_ (self : state) (task : task_with_name) : unit = +let schedule_ (self : state) (task : task_full) : unit = try Bb_queue.push self.q task with Bb_queue.Closed -> raise Shutdown type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit = + let cur_ls : Task_local_storage.storage ref = ref Task_local_storage.Private_.Storage.dummy in + TLS.set k_storage (Some cur_ls); TLS.get Runner.For_runner_implementors.k_cur_runner := Some runner; + let (AT_pair (before_task, after_task)) = around_task in let cur_span = ref Tracing_.dummy_span in @@ -34,20 +39,42 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit = cur_span := Tracing_.dummy_span in - let run_another_task ~name task' = schedule_ self { f = task'; name } in + let on_suspend () = + exit_span_ (); + !cur_ls + in - let run_task (task : task_with_name) : unit = + let run_another_task ls ~name task' = + let ls' = Task_local_storage.Private_.Storage.copy ls in + schedule_ self { f = task'; name; ls = ls' } + in + + let run_task (task : task_full) : unit = + cur_ls := task.ls; let _ctx = before_task runner in cur_span := Tracing_.enter_span task.name; - (* run the task now, catching errors *) + + let resume ls k res = + schedule_ self { f = (fun () -> k res); name = task.name; ls } + in + + (* run the task now, catching errors, handling effects *) (try - Suspend_.with_suspend task.f ~name:task.name ~run:run_another_task - ~on_suspend:exit_span_ +[@@@ifge 5.0] + Suspend_.with_suspend (WSH { + run=run_another_task; + resume; + on_suspend; + }) task.f +[@@@else_] + task.f() +[@@@endif] with e -> let bt = Printexc.get_raw_backtrace () in on_exn e bt); exit_span_ (); - after_task runner _ctx + after_task runner _ctx; + cur_ls := Task_local_storage.Private_.Storage.dummy in let main_loop () = @@ -91,7 +118,7 @@ let create ?(on_init_thread = default_thread_init_exit_) | None -> AT_pair (ignore, fun _ _ -> ()) in - let num_domains = D_pool_.n_domains () in + let num_domains = Domain_pool_.n_domains () in (* number of threads to run *) let num_threads = Util_pool_.num_threads ?num_threads () in @@ -104,7 +131,7 @@ let create ?(on_init_thread = default_thread_init_exit_) { threads = Array.make num_threads dummy; q = Bb_queue.create () } in - let run_async ~name f = schedule_ pool { f; name } in + let run_async ~name ~ls f = schedule_ pool { f; name; ls } in let runner = Runner.For_runner_implementors.create @@ -140,7 +167,7 @@ let create ?(on_init_thread = default_thread_init_exit_) (* now run the main loop *) Fun.protect run ~finally:(fun () -> (* on termination, decrease refcount of underlying domain *) - D_pool_.decr_on dom_idx); + Domain_pool_.decr_on dom_idx); on_exit_thread ~dom_id:dom_idx ~t_id () in @@ -152,7 +179,7 @@ let create ?(on_init_thread = default_thread_init_exit_) Bb_queue.push receive_threads (i, thread) in - D_pool_.run_on dom_idx create_thread_in_domain + Domain_pool_.run_on dom_idx create_thread_in_domain in (* start all threads, placing them on the domains diff --git a/src/fifo_pool.mli b/src/core/fifo_pool.mli similarity index 100% rename from src/fifo_pool.mli rename to src/core/fifo_pool.mli diff --git a/src/fut.ml b/src/core/fut.ml similarity index 95% rename from src/fut.ml rename to src/core/fut.ml index 7fed5894..2c7d6896 100644 --- a/src/fut.ml +++ b/src/core/fut.ml @@ -1,6 +1,6 @@ module A = Atomic_ -type 'a or_error = ('a, exn * Printexc.raw_backtrace) result +type 'a or_error = ('a, Exn_bt.t) result type 'a waiter = 'a or_error -> unit type 'a state = @@ -25,6 +25,7 @@ let make ?(name = "") () = let[@inline] of_result x : _ t = { st = A.make (Done x) } let[@inline] return x : _ t = of_result (Ok x) let[@inline] fail e bt : _ t = of_result (Error (e, bt)) +let[@inline] fail_exn_bt ebt = of_result (Error ebt) let[@inline] is_resolved self : bool = match A.get self.st with @@ -41,6 +42,16 @@ let[@inline] is_done self : bool = | Done _ -> true | Waiting _ -> false +let[@inline] is_success self = + match A.get self.st with + | Done (Ok _) -> true + | _ -> false + +let[@inline] is_failed self = + match A.get self.st with + | Done (Error _) -> true + | _ -> false + exception Not_ready let[@inline] get_or_fail self = @@ -94,7 +105,7 @@ let[@inline] fulfill_idempotent self r = (* ### combinators ### *) -let spawn ?name ~on f : _ t = +let spawn ?name ?ls ~on f : _ t = let fut, promise = make () in let task () = @@ -107,13 +118,13 @@ let spawn ?name ~on f : _ t = fulfill promise res in - Runner.run_async ?name on task; + Runner.run_async ?name ?ls on task; fut -let spawn_on_current_runner ?name f : _ t = +let spawn_on_current_runner ?name ?ls f : _ t = match Runner.get_current_runner () with | None -> failwith "Fut.spawn_on_current_runner: not running on a runner" - | Some on -> spawn ?name ~on f + | Some on -> spawn ?name ?ls ~on f let reify_error (f : 'a t) : 'a or_error t = match peek f with @@ -426,11 +437,11 @@ let await (fut : 'a t) : 'a = Suspend_.suspend { Suspend_.handle = - (fun ~name ~run k -> + (fun ~run:_ ~resume k -> on_result fut (function | Ok _ -> (* schedule continuation with the same name *) - run ~name (fun () -> k (Ok ())) + resume k (Ok ()) | Error (exn, bt) -> (* fail continuation immediately *) k (Error (exn, bt)))); @@ -451,3 +462,7 @@ end include Infix module Infix_local = Infix [@@deprecated "use Infix"] + +module Private_ = struct + let[@inline] unsafe_promise_of_fut x = x +end diff --git a/src/fut.mli b/src/core/fut.mli similarity index 92% rename from src/fut.mli rename to src/core/fut.mli index 9b10d420..a82975f3 100644 --- a/src/fut.mli +++ b/src/core/fut.mli @@ -17,7 +17,7 @@ the runner [pool] (once [fut] resolves successfully with a value). *) -type 'a or_error = ('a, exn * Printexc.raw_backtrace) result +type 'a or_error = ('a, Exn_bt.t) result type 'a t (** A future with a result of type ['a]. *) @@ -51,6 +51,10 @@ val return : 'a -> 'a t val fail : exn -> Printexc.raw_backtrace -> _ t (** Already settled future, with a failure *) +val fail_exn_bt : Exn_bt.t -> _ t +(** Fail from a bundle of exception and backtrace + @since NEXT_RELEASE *) + val of_result : 'a or_error -> 'a t val is_resolved : _ t -> bool @@ -80,13 +84,27 @@ val is_done : _ t -> bool (** Is the future resolved? This is the same as [peek fut |> Option.is_some]. @since 0.2 *) +val is_success : _ t -> bool +(** Checks if the future is resolved with [Ok _] as a result. + @since NEXT_RELEASE *) + +val is_failed : _ t -> bool +(** Checks if the future is resolved with [Error _] as a result. + @since NEXT_RELEASE *) + (** {2 Combinators} *) -val spawn : ?name:string -> on:Runner.t -> (unit -> 'a) -> 'a t +val spawn : + ?name:string -> + ?ls:Task_local_storage.storage -> + on:Runner.t -> + (unit -> 'a) -> + 'a t (** [spaw ~on f] runs [f()] on the given runner [on], and return a future that will hold its result. *) -val spawn_on_current_runner : ?name:string -> (unit -> 'a) -> 'a t +val spawn_on_current_runner : + ?name:string -> ?ls:Task_local_storage.storage -> (unit -> 'a) -> 'a t (** This must be run from inside a runner, and schedules the new task on it as well. @@ -204,7 +222,8 @@ val for_list : on:Runner.t -> 'a list -> ('a -> unit) -> unit t val await : 'a t -> 'a (** [await fut] suspends the current tasks until [fut] is fulfilled, then - resumes the task on this same runner. + resumes the task on this same runner (but possibly on a different + thread/domain). @since 0.3 @@ -263,3 +282,12 @@ include module type of Infix module Infix_local = Infix [@@deprecated "Use Infix"] (** @deprecated use Infix instead *) + +(**/**) + +module Private_ : sig + val unsafe_promise_of_fut : 'a t -> 'a promise + (** please do not use *) +end + +(**/**) diff --git a/src/immediate_runner.ml b/src/core/immediate_runner.ml similarity index 63% rename from src/immediate_runner.ml rename to src/core/immediate_runner.ml index db9725f5..c260f439 100644 --- a/src/immediate_runner.ml +++ b/src/core/immediate_runner.ml @@ -1,14 +1,23 @@ +open Types_ include Runner -let run_async_ ~name f = +(* convenient alias *) +let k_ls = Task_local_storage.Private_.Storage.k_storage + +let run_async_ ~name ~ls f = + let cur_ls = ref ls in + TLS.set k_ls (Some cur_ls); + cur_ls := ls; let sp = Tracing_.enter_span name in try let x = f () in Tracing_.exit_span sp; + TLS.set k_ls None; x with e -> let bt = Printexc.get_raw_backtrace () in Tracing_.exit_span sp; + TLS.set k_ls None; Printexc.raise_with_backtrace e bt let runner : t = diff --git a/src/immediate_runner.mli b/src/core/immediate_runner.mli similarity index 85% rename from src/immediate_runner.mli rename to src/core/immediate_runner.mli index 8917d8b5..0a07d42a 100644 --- a/src/immediate_runner.mli +++ b/src/core/immediate_runner.mli @@ -11,6 +11,9 @@ Another situation is when threads cannot be used at all (e.g. because you plan to call [Unix.fork] later). + {b NOTE}: this does not handle the [Suspend] effect, so [await], fork-join, + etc. will {b NOT} work on this runner. + @since 0.5 *) diff --git a/src/lock.ml b/src/core/lock.ml similarity index 100% rename from src/lock.ml rename to src/core/lock.ml diff --git a/src/lock.mli b/src/core/lock.mli similarity index 62% rename from src/lock.mli rename to src/core/lock.mli index 41ff47c6..f85f3d49 100644 --- a/src/lock.mli +++ b/src/core/lock.mli @@ -1,5 +1,28 @@ (** Mutex-protected resource. + This lock is a synchronous concurrency primitive, as a thin wrapper + around {!Mutex} that encourages proper management of the critical + section in RAII style: + + {[ + let (let@) = (@@) + + + … + let compute_foo = + (* enter critical section *) + let@ x = Lock.with_ protected_resource in + use_x; + return_foo () + (* exit critical section *) + in + … + ]} + + This lock does not work well with {!Fut.await}. A critical section + that contains a call to [await] might cause deadlocks, or lock starvation, + because it will hold onto the lock while it goes to sleep. + @since 0.3 *) type 'a t diff --git a/src/moonpool.ml b/src/core/moonpool.ml similarity index 66% rename from src/moonpool.ml rename to src/core/moonpool.ml index f2cf0174..c69b5581 100644 --- a/src/moonpool.ml +++ b/src/core/moonpool.ml @@ -1,8 +1,11 @@ +exception Shutdown = Runner.Shutdown + let start_thread_on_some_domain f x = - let did = Random.int (D_pool_.n_domains ()) in - D_pool_.run_on_and_wait did (fun () -> Thread.create f x) + let did = Random.int (Domain_pool_.n_domains ()) in + Domain_pool_.run_on_and_wait did (fun () -> Thread.create f x) let run_async = Runner.run_async +let run_wait_block = Runner.run_wait_block let recommended_thread_count () = Domain_.recommended_number () let spawn = Fut.spawn let spawn_on_current_runner = Fut.spawn_on_current_runner @@ -17,17 +20,20 @@ module Atomic = Atomic_ module Blocking_queue = Bb_queue module Bounded_queue = Bounded_queue module Chan = Chan +module Exn_bt = Exn_bt module Fifo_pool = Fifo_pool -module Fork_join = Fork_join module Fut = Fut module Lock = Lock module Immediate_runner = Immediate_runner -module Pool = Fifo_pool module Runner = Runner +module Task_local_storage = Task_local_storage module Thread_local_storage = Thread_local_storage_ module Ws_pool = Ws_pool module Private = struct module Ws_deque_ = Ws_deque_ module Suspend_ = Suspend_ + module Domain_ = Domain_ + + let num_domains = Domain_pool_.n_domains end diff --git a/src/moonpool.mli b/src/core/moonpool.mli similarity index 82% rename from src/moonpool.mli rename to src/core/moonpool.mli index 5001e178..23ee52d8 100644 --- a/src/moonpool.mli +++ b/src/core/moonpool.mli @@ -13,17 +13,24 @@ module Ws_pool = Ws_pool module Fifo_pool = Fifo_pool module Runner = Runner module Immediate_runner = Immediate_runner +module Exn_bt = Exn_bt -module Pool = Fifo_pool -[@@deprecated "use Fifo_pool or Ws_pool to be more explicit"] -(** Default pool. Please explicitly pick an implementation instead. *) +exception Shutdown +(** Exception raised when trying to run tasks on + runners that have been shut down. + @since NEXT_RELEASE *) val start_thread_on_some_domain : ('a -> unit) -> 'a -> Thread.t (** Similar to {!Thread.create}, but it picks a background domain at random to run the thread. This ensures that we don't always pick the same domain to run all the various threads needed in an application (timers, event loops, etc.) *) -val run_async : ?name:string -> Runner.t -> (unit -> unit) -> unit +val run_async : + ?name:string -> + ?ls:Task_local_storage.storage -> + Runner.t -> + (unit -> unit) -> + unit (** [run_async runner task] schedules the task to run on the given runner. This means [task()] will be executed at some point in the future, possibly in another thread. @@ -32,20 +39,43 @@ val run_async : ?name:string -> Runner.t -> (unit -> unit) -> unit (since NEXT_RELEASE) @since 0.5 *) +val run_wait_block : + ?name:string -> + ?ls:Task_local_storage.storage -> + Runner.t -> + (unit -> 'a) -> + 'a +(** [run_wait_block runner f] schedules [f] for later execution + on the runner, like {!run_async}. + It then blocks the current thread until [f()] is done executing, + and returns its result. If [f()] raises an exception, then [run_wait_block pool f] + will raise it as well. + + {b NOTE} be careful with deadlocks (see notes in {!Fut.wait_block} + about the required discipline to avoid deadlocks). + @raise Shutdown if the runner was already shut down + @since NEXT_RELEASE *) + val recommended_thread_count : unit -> int (** Number of threads recommended to saturate the CPU. For IO pools this makes little sense (you might want more threads than this because many of them will be blocked most of the time). @since 0.5 *) -val spawn : ?name:string -> on:Runner.t -> (unit -> 'a) -> 'a Fut.t +val spawn : + ?name:string -> + ?ls:Task_local_storage.storage -> + on:Runner.t -> + (unit -> 'a) -> + 'a Fut.t (** [spawn ~on f] runs [f()] on the runner (a thread pool typically) and returns a future result for it. See {!Fut.spawn}. @param name if provided and [Trace] is present in dependencies, a span will be created for the future. (since 0.6) @since 0.5 *) -val spawn_on_current_runner : ?name:string -> (unit -> 'a) -> 'a Fut.t +val spawn_on_current_runner : + ?name:string -> ?ls:Task_local_storage.storage -> (unit -> 'a) -> 'a Fut.t (** See {!Fut.spawn_on_current_runner}. @param name see {!spawn}. since 0.6. @since 0.5 *) @@ -62,7 +92,7 @@ val await : 'a Fut.t -> 'a module Lock = Lock module Fut = Fut module Chan = Chan -module Fork_join = Fork_join +module Task_local_storage = Task_local_storage module Thread_local_storage = Thread_local_storage_ (** A simple blocking queue. @@ -191,8 +221,10 @@ module Atomic = Atomic_ (**/**) +(** Private internals, with no stability guarantees *) module Private : sig module Ws_deque_ = Ws_deque_ + (** A deque for work stealing, fixed size. *) (** {2 Suspensions} *) @@ -204,4 +236,10 @@ module Private : sig This is only going to work on OCaml 5.x. {b NOTE}: this is not stable for now. *) + + module Domain_ = Domain_ + (** Utils for domains *) + + val num_domains : unit -> int + (** Number of domains in the backing domain pool *) end diff --git a/src/runner.ml b/src/core/runner.ml similarity index 77% rename from src/runner.ml rename to src/core/runner.ml index 437e24c4..207ea56d 100644 --- a/src/runner.ml +++ b/src/core/runner.ml @@ -3,7 +3,7 @@ module TLS = Thread_local_storage_ type task = unit -> unit type t = { - run_async: name:string -> task -> unit; + run_async: name:string -> ls:Task_local_storage.storage -> task -> unit; shutdown: wait:bool -> unit -> unit; size: unit -> int; num_tasks: unit -> int; @@ -11,7 +11,10 @@ type t = { exception Shutdown -let[@inline] run_async ?(name = "") (self : t) f : unit = self.run_async ~name f +let[@inline] run_async ?(name = "") + ?(ls = Task_local_storage.Private_.Storage.create ()) (self : t) f : unit = + self.run_async ~name ~ls f + let[@inline] shutdown (self : t) : unit = self.shutdown ~wait:true () let[@inline] shutdown_without_waiting (self : t) : unit = @@ -20,9 +23,9 @@ let[@inline] shutdown_without_waiting (self : t) : unit = let[@inline] num_tasks (self : t) : int = self.num_tasks () let[@inline] size (self : t) : int = self.size () -let run_wait_block ?name self (f : unit -> 'a) : 'a = +let run_wait_block ?name ?ls self (f : unit -> 'a) : 'a = let q = Bb_queue.create () in - run_async ?name self (fun () -> + run_async ?name ?ls self (fun () -> try let x = f () in Bb_queue.push q (Ok x) diff --git a/src/runner.mli b/src/core/runner.mli similarity index 89% rename from src/runner.mli rename to src/core/runner.mli index 3b959496..5b937c09 100644 --- a/src/runner.mli +++ b/src/core/runner.mli @@ -33,16 +33,19 @@ val shutdown_without_waiting : t -> unit exception Shutdown -val run_async : ?name:string -> t -> task -> unit +val run_async : + ?name:string -> ?ls:Task_local_storage.storage -> t -> task -> unit (** [run_async pool f] schedules [f] for later execution on the runner in one of the threads. [f()] will run on one of the runner's worker threads/domains. @param name if provided and [Trace] is present in dependencies, a span will be created when the task starts, and will stop when the task is over. (since NEXT_RELEASE) + @param ls if provided, run the task with this initial local storage @raise Shutdown if the runner was shut down before [run_async] was called. *) -val run_wait_block : ?name:string -> t -> (unit -> 'a) -> 'a +val run_wait_block : + ?name:string -> ?ls:Task_local_storage.storage -> t -> (unit -> 'a) -> 'a (** [run_wait_block pool f] schedules [f] for later execution on the pool, like {!run_async}. It then blocks the current thread until [f()] is done executing, @@ -62,7 +65,7 @@ module For_runner_implementors : sig size:(unit -> int) -> num_tasks:(unit -> int) -> shutdown:(wait:bool -> unit -> unit) -> - run_async:(name:string -> task -> unit) -> + run_async:(name:string -> ls:Task_local_storage.storage -> task -> unit) -> unit -> t (** Create a new runner. diff --git a/src/core/suspend_.ml b/src/core/suspend_.ml new file mode 100644 index 00000000..4d15ac77 --- /dev/null +++ b/src/core/suspend_.ml @@ -0,0 +1,89 @@ +module A = Atomic_ + +type suspension = unit Exn_bt.result -> unit +type task = unit -> unit + +[@@@ifge 5.0] + +type suspension_handler = { + handle: + run:(name:string -> task -> unit) -> + resume:(suspension -> unit Exn_bt.result -> unit) -> + suspension -> + unit; +} +[@@unboxed] + +[@@@ocaml.alert "-unstable"] + +type _ Effect.t += + | Suspend : suspension_handler -> unit Effect.t + | Yield : unit Effect.t + +let[@inline] yield () = Effect.perform Yield +let[@inline] suspend h = Effect.perform (Suspend h) + +type with_suspend_handler = + | WSH : { + on_suspend: unit -> 'state; + (** on_suspend called when [f()] suspends itself. *) + run: 'state -> name:string -> task -> unit; + (** run used to schedule new tasks *) + resume: 'state -> suspension -> unit Exn_bt.result -> unit; + (** resume run the suspension. Must be called exactly once. *) + } + -> with_suspend_handler + +let with_suspend (WSH { on_suspend; run; resume }) (f : unit -> unit) : unit = + let module E = Effect.Deep in + (* effect handler *) + let effc : type e. e Effect.t -> ((e, _) E.continuation -> _) option = + function + | Suspend h -> + (* TODO: discontinue [k] if current fiber (if any) is cancelled? *) + Some + (fun k -> + let state = on_suspend () in + let k' : suspension = function + | Ok () -> E.continue k () + | Error (exn, bt) -> E.discontinue_with_backtrace k exn bt + in + h.handle ~run:(run state) ~resume:(resume state) k') + | Yield -> + (* TODO: discontinue [k] if current fiber (if any) is cancelled? *) + Some + (fun k -> + let state = on_suspend () in + let k' : suspension = function + | Ok () -> E.continue k () + | Error (exn, bt) -> E.discontinue_with_backtrace k exn bt + in + resume state k' @@ Ok ()) + | _ -> None + in + + E.try_with f () { E.effc } + +(* DLA interop *) +let prepare_for_await () : Dla_.t = + (* current state *) + let st : (_ * suspension) option A.t = A.make None in + + let release () : unit = + match A.exchange st None with + | None -> () + | Some (resume, k) -> resume k @@ Ok () + and await () : unit = + suspend { handle = (fun ~run:_ ~resume k -> A.set st (Some (resume, k))) } + in + + let t = { Dla_.release; await } in + t + +[@@@ocaml.alert "+unstable"] +[@@@else_] + +let[@inline] with_suspend ~on_suspend:_ ~run:_ ~resume:_ f = f () +let[@inline] prepare_for_await () = { Dla_.release = ignore; await = ignore } + +[@@@endif] diff --git a/src/suspend_.mli b/src/core/suspend_.mli similarity index 54% rename from src/suspend_.mli rename to src/core/suspend_.mli index 0334225f..1fff43ac 100644 --- a/src/suspend_.mli +++ b/src/core/suspend_.mli @@ -3,13 +3,21 @@ This module is an implementation detail of Moonpool and should not be used outside of it, except by experts to implement {!Runner}. *) -type suspension = (unit, exn * Printexc.raw_backtrace) result -> unit +open Types_ + +type suspension = unit Exn_bt.result -> unit (** A suspended computation *) +[@@@ifge 5.0] + type task = unit -> unit type suspension_handler = { - handle: name:string -> run:(name:string -> task -> unit) -> suspension -> unit; + handle: + run:(name:string -> task -> unit) -> + resume:(suspension -> unit Exn_bt.result -> unit) -> + suspension -> + unit; } [@@unboxed] (** The handler that knows what to do with the suspended computation. @@ -21,6 +29,8 @@ type suspension_handler = { eventually); - a [run] function that can be used to start tasks to perform some computation. + - a [resume] function to resume the suspended computation. This + must be called exactly once, in all situations. This means that a fork-join primitive, for example, can use a single call to {!suspend} to: @@ -30,9 +40,9 @@ type suspension_handler = { runs in parallel with the other calls. The calls must coordinate so that, once they are all done, the suspended caller is resumed with the aggregated result of the computation. + - use [resume] exactly *) -[@@@ifge 5.0] [@@@ocaml.alert "-unstable"] type _ Effect.t += @@ -40,32 +50,45 @@ type _ Effect.t += (** The effect used to suspend the current thread and pass it, suspended, to the handler. The handler will ensure that the suspension is resumed later once some computation has been done. *) + | Yield : unit Effect.t + (** The effect used to interrupt the current computation and immediately re-schedule + it on the same runner. *) [@@@ocaml.alert "+unstable"] +val yield : unit -> unit +(** Interrupt current computation, and re-schedule it at the end of the + runner's job queue. *) + val suspend : suspension_handler -> unit (** [suspend h] jumps back to the nearest {!with_suspend} and calls [h.handle] with the current continuation [k] and a task runner function. *) +type with_suspend_handler = + | WSH : { + on_suspend: unit -> 'state; + (** on_suspend called when [f()] suspends itself. *) + run: 'state -> name:string -> task -> unit; + (** run used to schedule new tasks *) + resume: 'state -> suspension -> unit Exn_bt.result -> unit; + (** resume run the suspension. Must be called exactly once. *) + } + -> with_suspend_handler + +val with_suspend : with_suspend_handler -> (unit -> unit) -> unit +(** [with_suspend wsh f] + runs [f()] in an environment where [suspend] will work. + + If [f()] suspends with suspension handler [h], + this calls [wsh.on_suspend()] to capture the current state [st]. + Then [h.handle ~st ~run ~resume k] is called, where [k] is the suspension. + The suspension should always be passed exactly once to + [resume]. [run] should be used to start other tasks. +*) + [@@@endif] val prepare_for_await : unit -> Dla_.t (** Our stub for DLA. Unstable. *) - -val with_suspend : - name:string -> - on_suspend:(unit -> unit) -> - run:(name:string -> task -> unit) -> - (unit -> unit) -> - unit -(** [with_suspend ~run f] runs [f()] in an environment where [suspend] - will work. If [f()] suspends with suspension handler [h], - this calls [h ~run k] where [k] is the suspension. - The suspension should always run in a new task, via [run]. - - @param on_suspend called when [f()] suspends itself. - - This will not do anything on OCaml 4.x. -*) diff --git a/src/core/task_local_storage.ml b/src/core/task_local_storage.ml new file mode 100644 index 00000000..87ca1424 --- /dev/null +++ b/src/core/task_local_storage.ml @@ -0,0 +1,70 @@ +open Types_ +module A = Atomic_ + +type 'a key = 'a ls_key + +let key_count_ = A.make 0 + +type storage = task_ls + +let new_key (type t) ~init () : t key = + let offset = A.fetch_and_add key_count_ 1 in + (module struct + type nonrec t = t + type ls_value += V of t + + let offset = offset + let init = init + end : LS_KEY + with type t = t) + +type ls_value += Dummy + +(** Resize array of TLS values *) +let[@inline never] resize_ (cur : ls_value array ref) n = + if n > Sys.max_array_length then failwith "too many task local storage keys"; + let len = Array.length !cur in + let new_ls = + Array.make (min Sys.max_array_length (max n ((len * 2) + 2))) Dummy + in + Array.blit !cur 0 new_ls 0 len; + cur := new_ls + +let[@inline] get_cur_ () : ls_value array ref = + match TLS.get k_ls_values with + | Some r -> r + | None -> failwith "Task local storage must be accessed from within a runner." + +let get (type a) ((module K) : a key) : a = + let cur = get_cur_ () in + if K.offset >= Array.length !cur then resize_ cur (K.offset + 1); + match !cur.(K.offset) with + | K.V x -> (* common case first *) x + | Dummy -> + (* first time we access this *) + let v = K.init () in + !cur.(K.offset) <- K.V v; + v + | _ -> assert false + +let set (type a) ((module K) : a key) (v : a) : unit = + let cur = get_cur_ () in + if K.offset >= Array.length !cur then resize_ cur (K.offset + 1); + !cur.(K.offset) <- K.V v; + () + +let with_value key x f = + let old = get key in + set key x; + Fun.protect ~finally:(fun () -> set key old) f + +module Private_ = struct + module Storage = struct + type t = storage + + let k_storage = k_ls_values + let[@inline] create () = [||] + let copy = Array.copy + let dummy = [||] + end +end diff --git a/src/core/task_local_storage.mli b/src/core/task_local_storage.mli new file mode 100644 index 00000000..c2ce778a --- /dev/null +++ b/src/core/task_local_storage.mli @@ -0,0 +1,61 @@ +(** Task-local storage. + + This storage is associated to the current task, + just like thread-local storage is associated with + the current thread. The storage is carried along in case + the current task is suspended. + + @since NEXT_RELEASE +*) + +type storage +(** Underlying storage for a task *) + +type 'a key +(** A key used to access a particular (typed) storage slot on every task. *) + +val new_key : init:(unit -> 'a) -> unit -> 'a key +(** [new_key ~init ()] makes a new key. Keys are expensive and + should never be allocated dynamically or in a loop. + The correct pattern is, at toplevel: + + {[ + let k_foo : foo Task_ocal_storage.key = + Task_local_storage.new_key ~init:(fun () -> make_foo ()) () + + (* … *) + + (* use it: *) + let … = Task_local_storage.get k_foo + ]} +*) + +val get : 'a key -> 'a +(** [get k] gets the value for the current task for key [k]. + Must be run from inside a task running on a runner. + @raise Failure otherwise *) + +val set : 'a key -> 'a -> unit +(** [set k v] sets the storage for [k] to [v]. + Must be run from inside a task running on a runner. + @raise Failure otherwise *) + +val with_value : 'a key -> 'a -> (unit -> 'b) -> 'b +(** [with_value k v f] sets [k] to [v] for the duration of the call + to [f()]. When [f()] returns (or fails), [k] is restored + to its old value. *) + +(**/**) + +module Private_ : sig + module Storage : sig + type t = storage + + val k_storage : t ref option Thread_local_storage_.key + val create : unit -> t + val copy : t -> t + val dummy : t + end +end + +(**/**) diff --git a/src/core/types_.ml b/src/core/types_.ml new file mode 100644 index 00000000..00ffbe23 --- /dev/null +++ b/src/core/types_.ml @@ -0,0 +1,26 @@ +module TLS = Thread_local_storage_ + +type ls_value = .. + +(** Key for task local storage *) +module type LS_KEY = sig + type t + type ls_value += V of t + + val offset : int + (** Unique offset *) + + val init : unit -> t +end + +type 'a ls_key = (module LS_KEY with type t = 'a) +(** A LS key (task local storage) *) + +type task_ls = ls_value array + +(** Store the current LS values for the current thread. + + A worker thread is going to cycle through many tasks, each of which + has its own storage. This key allows tasks running on the worker + to access their own storage *) +let k_ls_values : task_ls ref option TLS.key = TLS.new_key (fun () -> None) diff --git a/src/util_pool_.ml b/src/core/util_pool_.ml similarity index 80% rename from src/util_pool_.ml rename to src/core/util_pool_.ml index 8207062a..666472b4 100644 --- a/src/util_pool_.ml +++ b/src/core/util_pool_.ml @@ -1,5 +1,5 @@ let num_threads ?num_threads () : int = - let n_domains = D_pool_.n_domains () in + let n_domains = Domain_pool_.n_domains () in (* number of threads to run *) let num_threads = diff --git a/src/util_pool_.mli b/src/core/util_pool_.mli similarity index 100% rename from src/util_pool_.mli rename to src/core/util_pool_.mli diff --git a/src/ws_pool.ml b/src/core/ws_pool.ml similarity index 85% rename from src/ws_pool.ml rename to src/core/ws_pool.ml index 364aaa81..d1fd7cf3 100644 --- a/src/ws_pool.ml +++ b/src/core/ws_pool.ml @@ -4,6 +4,7 @@ module TLS = Thread_local_storage_ include Runner let ( let@ ) = ( @@ ) +let k_storage = Task_local_storage.Private_.Storage.k_storage module Id = struct type t = unit ref @@ -13,29 +14,31 @@ module Id = struct let equal : t -> t -> bool = ( == ) end -type task_with_name = { +type task_full = { f: task; name: string; + ls: Task_local_storage.storage; } +type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task + type worker_state = { pool_id_: Id.t; (** Unique per pool *) mutable thread: Thread.t; - q: task_with_name WSQ.t; (** Work stealing queue *) + q: task_full WSQ.t; (** Work stealing queue *) mutable cur_span: int64; + cur_ls: Task_local_storage.storage ref; (** Task storage *) rng: Random.State.t; } (** State for a given worker. Only this worker is allowed to push into the queue, but other workers can come and steal from it if they're idle. *) -type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task - type state = { id_: Id.t; active: bool A.t; (** Becomes [false] when the pool is shutdown. *) workers: worker_state array; (** Fixed set of workers. *) - main_q: task_with_name Queue.t; + main_q: task_full Queue.t; (** Main queue for tasks coming from the outside *) mutable n_waiting: int; (* protected by mutex *) mutable n_waiting_nonzero: bool; (** [n_waiting > 0] *) @@ -72,10 +75,10 @@ let[@inline] try_wake_someone_ (self : state) : unit = ) (** Run [task] as is, on the pool. *) -let schedule_task_ (self : state) ~name (w : worker_state option) (f : task) : - unit = +let schedule_task_ (self : state) ~name ~ls (w : worker_state option) (f : task) + : unit = (* Printf.printf "schedule task now (%d)\n%!" (Thread.id @@ Thread.self ()); *) - let task = { f; name } in + let task = { f; name; ls } in match w with | Some w when Id.equal self.id_ w.pool_id_ -> (* we're on this same pool, schedule in the worker's state. Otherwise @@ -104,9 +107,11 @@ let schedule_task_ (self : state) ~name (w : worker_state option) (f : task) : raise Shutdown (** Run this task, now. Must be called from a worker. *) -let run_task_now_ (self : state) ~runner (w : worker_state) ~name task : unit = +let run_task_now_ (self : state) ~runner (w : worker_state) ~name ~ls task : + unit = (* Printf.printf "run task now (%d)\n%!" (Thread.id @@ Thread.self ()); *) let (AT_pair (before_task, after_task)) = self.around_task in + w.cur_ls := ls; let _ctx = before_task runner in w.cur_span <- Tracing_.enter_span name; @@ -115,25 +120,45 @@ let run_task_now_ (self : state) ~runner (w : worker_state) ~name task : unit = w.cur_span <- Tracing_.dummy_span in - let run_another_task ~name task' = + let on_suspend () = + exit_span_ (); + !(w.cur_ls) + in + + let run_another_task ls ~name task' = let w = find_current_worker_ () in - schedule_task_ self w ~name task' + let ls' = Task_local_storage.Private_.Storage.copy ls in + schedule_task_ self w ~name ~ls:ls' task' + in + + let resume ls k r = + let w = find_current_worker_ () in + schedule_task_ self w ~name ~ls (fun () -> k r) in (* run the task now, catching errors *) (try (* run [task()] and handle [suspend] in it *) - Suspend_.with_suspend task ~name ~run:run_another_task - ~on_suspend:exit_span_ +[@@@ifge 5.0] + Suspend_.with_suspend (WSH { + on_suspend; + run=run_another_task; + resume; + }) task +[@@@else_] + task () +[@@@endif] with e -> let bt = Printexc.get_raw_backtrace () in self.on_exn e bt); - exit_span_ (); - after_task runner _ctx -let[@inline] run_async_ (self : state) ~name (f : task) : unit = + exit_span_ (); + after_task runner _ctx; + w.cur_ls := Task_local_storage.Private_.Storage.dummy + +let[@inline] run_async_ (self : state) ~name ~ls (f : task) : unit = let w = find_current_worker_ () in - schedule_task_ self w ~name f + schedule_task_ self w ~name ~ls f (* TODO: function to schedule many tasks from the outside. - build a queue @@ -150,11 +175,11 @@ let[@inline] wait_ (self : state) : unit = self.n_waiting <- self.n_waiting - 1; if self.n_waiting = 0 then self.n_waiting_nonzero <- false -exception Got_task of task_with_name +exception Got_task of task_full (** Try to steal a task *) -let try_to_steal_work_once_ (self : state) (w : worker_state) : - task_with_name option = +let try_to_steal_work_once_ (self : state) (w : worker_state) : task_full option + = let init = Random.State.int w.rng (Array.length self.workers) in try @@ -179,7 +204,7 @@ let worker_run_self_tasks_ (self : state) ~runner w : unit = match WSQ.pop w.q with | Some task -> try_wake_someone_ self; - run_task_now_ self ~runner w ~name:task.name task.f + run_task_now_ self ~runner w ~name:task.name ~ls:task.ls task.f | None -> continue := false done @@ -192,7 +217,7 @@ let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit = worker_run_self_tasks_ self ~runner w; try_steal () and run_task task : unit = - run_task_now_ self ~runner w ~name:task.name task.f; + run_task_now_ self ~runner w ~name:task.name ~ls:task.ls task.f; main () and try_steal () = match try_to_steal_work_once_ self w with @@ -251,7 +276,7 @@ type ('a, 'b) create_args = 'a (** Arguments used in {!create}. See {!create} for explanations. *) -let dummy_task_ = { f = ignore; name = "DUMMY_TASK" } +let dummy_task_ = { f = ignore; ls = Task_local_storage.Private_.Storage.dummy ; name = "DUMMY_TASK" } let create ?(on_init_thread = default_thread_init_exit_) ?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ()) @@ -264,7 +289,7 @@ let create ?(on_init_thread = default_thread_init_exit_) | None -> AT_pair (ignore, fun _ _ -> ()) in - let num_domains = D_pool_.n_domains () in + let num_domains = Domain_pool_.n_domains () in let num_threads = Util_pool_.num_threads ?num_threads () in (* make sure we don't bias towards the first domain(s) in {!D_pool_} *) @@ -279,6 +304,7 @@ let create ?(on_init_thread = default_thread_init_exit_) cur_span = Tracing_.dummy_span; q = WSQ.create ~dummy:dummy_task_ (); rng = Random.State.make [| i |]; + cur_ls = ref Task_local_storage.Private_.Storage.dummy; }) in @@ -300,7 +326,7 @@ let create ?(on_init_thread = default_thread_init_exit_) let runner = Runner.For_runner_implementors.create ~shutdown:(fun ~wait () -> shutdown_ pool ~wait) - ~run_async:(fun ~name f -> run_async_ pool ~name f) + ~run_async:(fun ~name ~ls f -> run_async_ pool ~name ~ls f) ~size:(fun () -> size_ pool) ~num_tasks:(fun () -> num_tasks_ pool) () @@ -320,6 +346,7 @@ let create ?(on_init_thread = default_thread_init_exit_) let thread = Thread.self () in let t_id = Thread.id thread in on_init_thread ~dom_id:dom_idx ~t_id (); + TLS.set k_storage (Some w.cur_ls); (* set thread name *) Option.iter @@ -332,7 +359,7 @@ let create ?(on_init_thread = default_thread_init_exit_) (* now run the main loop *) Fun.protect run ~finally:(fun () -> (* on termination, decrease refcount of underlying domain *) - D_pool_.decr_on dom_idx); + Domain_pool_.decr_on dom_idx); on_exit_thread ~dom_id:dom_idx ~t_id () in @@ -344,7 +371,7 @@ let create ?(on_init_thread = default_thread_init_exit_) Bb_queue.push receive_threads (i, thread) in - D_pool_.run_on dom_idx create_thread_in_domain + Domain_pool_.run_on dom_idx create_thread_in_domain in (* start all threads, placing them on the domains diff --git a/src/ws_pool.mli b/src/core/ws_pool.mli similarity index 100% rename from src/ws_pool.mli rename to src/core/ws_pool.mli diff --git a/src/cpp/dune b/src/cpp/dune index 6ec12a60..c4c75e8b 100644 --- a/src/cpp/dune +++ b/src/cpp/dune @@ -2,4 +2,5 @@ (executable (name cpp) - (modes (best exe))) + (modes + (best exe))) diff --git a/src/dune b/src/dune deleted file mode 100644 index db4763df..00000000 --- a/src/dune +++ /dev/null @@ -1,17 +0,0 @@ -(library - (public_name moonpool) - (name moonpool) - (private_modules d_pool_ dla_ tracing_) - (preprocess - (action - (run %{project_root}/src/cpp/cpp.exe %{input-file}))) - (libraries threads either - (select thread_local_storage_.ml from - (thread-local-storage -> thread_local_storage_.stub.ml) - (-> thread_local_storage_.real.ml)) - (select tracing_.ml from - (trace.core -> tracing_.real.ml) - (-> tracing_.dummy.ml)) - (select dla_.ml from - (domain-local-await -> dla_.real.ml) - ( -> dla_.dummy.ml)))) diff --git a/src/fib/dune b/src/fib/dune new file mode 100644 index 00000000..17a2f48c --- /dev/null +++ b/src/fib/dune @@ -0,0 +1,13 @@ + +(library + (name moonpool_fib) + (public_name moonpool.fib) + (synopsis "Fibers and structured concurrency for Moonpool") + (libraries moonpool) + (enabled_if + (>= %{ocaml_version} 5.0)) + (flags :standard -open Moonpool_private -open Moonpool) + (optional) + (preprocess + (action + (run %{project_root}/src/cpp/cpp.exe %{input-file})))) diff --git a/src/fib/fiber.ml b/src/fib/fiber.ml new file mode 100644 index 00000000..25a4485e --- /dev/null +++ b/src/fib/fiber.ml @@ -0,0 +1,235 @@ +module A = Atomic +module FM = Handle.Map + +type 'a callback = 'a Exn_bt.result -> unit +(** Callbacks that are called when a fiber is done. *) + +type cancel_callback = Exn_bt.t -> unit + +let prom_of_fut : 'a Fut.t -> 'a Fut.promise = + Fut.Private_.unsafe_promise_of_fut + +type 'a t = { + id: Handle.t; (** unique identifier for this fiber *) + state: 'a state A.t; (** Current state in the lifetime of the fiber *) + res: 'a Fut.t; + runner: Runner.t; +} + +and 'a state = + | Alive of { + children: children; + on_cancel: cancel_callback list; + } + | Terminating_or_done of 'a Exn_bt.result A.t + +and children = any FM.t +and any = Any : _ t -> any [@@unboxed] + +let[@inline] res self = self.res +let[@inline] peek self = Fut.peek self.res +let[@inline] is_done self = Fut.is_done self.res +let[@inline] is_success self = Fut.is_success self.res +let[@inline] is_cancelled self = Fut.is_failed self.res +let[@inline] on_result (self : _ t) f = Fut.on_result self.res f + +(** Resolve [promise] once [children] are all done *) +let resolve_once_children_are_done_ ~children ~promise + (res : 'a Exn_bt.result A.t) : unit = + let n_children = FM.cardinal children in + if n_children > 0 then ( + (* wait for all children to be done *) + let n_waiting = A.make (FM.cardinal children) in + let on_child_finish (r : _ result) = + (* make sure the parent fails if any child fails *) + (match r with + | Ok _ -> () + | Error ebt -> A.set res (Error ebt)); + + (* if we're the last to finish, resolve the parent fiber's [res] *) + if A.fetch_and_add n_waiting (-1) = 1 then ( + let res = A.get res in + Fut.fulfill promise res + ) + in + FM.iter (fun _ (Any f) -> Fut.on_result f.res on_child_finish) children + ) else + Fut.fulfill promise @@ A.get res + +let rec resolve_as_failed_ : type a. a t -> Exn_bt.t -> unit = + fun self ebt -> + let promise = prom_of_fut self.res in + while + match A.get self.state with + | Alive { children; on_cancel } as old -> + let new_st = Terminating_or_done (A.make @@ Error ebt) in + if A.compare_and_set self.state old new_st then ( + (* here, unlike in {!resolve_fiber}, we immediately cancel children *) + cancel_children_ ~children ebt; + List.iter (fun cb -> cb ebt) on_cancel; + resolve_once_children_are_done_ ~children ~promise (A.make @@ Error ebt); + false + ) else + true + | Terminating_or_done _ -> false + do + () + done + +(** Cancel eagerly all children *) +and cancel_children_ ebt ~children : unit = + FM.iter (fun _ (Any f) -> resolve_as_failed_ f ebt) children + +(** Successfully resolve the fiber *) +let resolve_ok_ (self : 'a t) (r : 'a) : unit = + let r = A.make @@ Ok r in + let promise = prom_of_fut self.res in + while + match A.get self.state with + | Alive { children; on_cancel = _ } as old -> + let new_st = Terminating_or_done r in + if A.compare_and_set self.state old new_st then ( + resolve_once_children_are_done_ ~children ~promise r; + false + ) else + true + | Terminating_or_done _ -> false + do + () + done + +let remove_child_ (self : _ t) (child : _ t) = + while + match A.get self.state with + | Alive { children; on_cancel } as old -> + let new_st = + Alive { children = FM.remove child.id children; on_cancel } + in + not (A.compare_and_set self.state old new_st) + | _ -> false + do + () + done + +(** Add a child to [self]. + @param protected if true, the child's failure will not affect [self]. *) +let add_child_ ~protect (self : _ t) (child : _ t) = + while + match A.get self.state with + | Alive { children; on_cancel } as old -> + let new_st = + Alive { children = FM.add child.id (Any child) children; on_cancel } + in + + if A.compare_and_set self.state old new_st then ( + (* make sure to remove [child] from [self.children] once it's done; + fail [self] is [child] failed and [protect=false] *) + Fut.on_result child.res (function + | Ok _ -> remove_child_ self child + | Error ebt -> + (* child failed, we must fail too *) + remove_child_ self child; + if not protect then resolve_as_failed_ self ebt); + false + ) else + true + | Terminating_or_done r -> + (match A.get r with + | Error ebt -> + (* cancel child immediately *) + resolve_as_failed_ child ebt + | Ok _ -> ()); + false + do + () + done + +(** Key to access the current fiber. *) +let k_current_fiber : any option Task_local_storage.key = + Task_local_storage.new_key ~init:(fun () -> None) () + +let spawn_ ?name ~on (f : _ -> 'a) : 'a t = + let id = Handle.generate_fresh () in + let res, _promise = Fut.make ?name () in + let fib = + { + state = A.make @@ Alive { children = FM.empty; on_cancel = [] }; + id; + res; + runner = on; + } + in + + let run () = + (* make sure the fiber is accessible from inside itself *) + Task_local_storage.set k_current_fiber (Some (Any fib)); + try + let res = f () in + resolve_ok_ fib res + with exn -> + let bt = Printexc.get_raw_backtrace () in + let ebt = Exn_bt.make exn bt in + resolve_as_failed_ fib ebt + in + + Runner.run_async on ?name run; + + fib + +let[@inline] spawn_top ?name ~on f : _ t = spawn_ ?name ~on f + +let spawn_link ?name ~protect f : _ t = + match Task_local_storage.get k_current_fiber with + | None -> failwith "Fiber.spawn_link: must be run from inside a fiber." + | Some (Any parent) -> + let child = spawn_ ?name ~on:parent.runner f in + add_child_ ~protect parent child; + child + +let add_cancel_cb_ (self : _ t) cb = + while + match A.get self.state with + | Alive { children; on_cancel } as old -> + let new_st = Alive { children; on_cancel = cb :: on_cancel } in + not (A.compare_and_set self.state old new_st) + | Terminating_or_done r -> + (match A.get r with + | Error ebt -> cb ebt + | Ok _ -> ()); + false + do + () + done + +let remove_top_cancel_cb_ (self : _ t) = + while + match A.get self.state with + | Alive { on_cancel = []; _ } -> assert false + | Alive { children; on_cancel = _ :: tl } as old -> + let new_st = Alive { children; on_cancel = tl } in + not (A.compare_and_set self.state old new_st) + | Terminating_or_done _ -> false + do + () + done + +let with_cancel_callback (self : _ t) cb (k : unit -> 'a) : 'a = + add_cancel_cb_ self cb; + Fun.protect k ~finally:(fun () -> remove_top_cancel_cb_ self) + +let[@inline] await self = Fut.await self.res + +module Suspend_ = Moonpool.Private.Suspend_ + +let check_if_cancelled () = + match Task_local_storage.get k_current_fiber with + | None -> + failwith "Fiber.check_if_cancelled: must be run from inside a fiber." + | Some (Any self) -> + (match peek self with + | Some (Error ebt) -> Exn_bt.raise ebt + | _ -> ()) + +let[@inline] yield () : unit = + check_if_cancelled (); + Suspend_.yield () diff --git a/src/fib/fiber.mli b/src/fib/fiber.mli new file mode 100644 index 00000000..dc60b001 --- /dev/null +++ b/src/fib/fiber.mli @@ -0,0 +1,75 @@ +(** Fibers. + + A fiber is a lightweight computation that runs cooperatively + alongside other fibers. In the context of moonpool, fibers + have additional properties: + + - they run in a moonpool runner + - they form a simple supervision tree, enabling a limited form + of structured concurrency +*) + +type 'a t +(** A fiber returning a value of type ['a]. *) + +val res : 'a t -> 'a Fut.t +(** Future result of the fiber. *) + +type 'a callback = 'a Exn_bt.result -> unit +(** Callbacks that are called when a fiber is done. *) + +type cancel_callback = Exn_bt.t -> unit + +val peek : 'a t -> 'a Fut.or_error option +(** Peek inside the future result *) + +val is_done : _ t -> bool +(** Has the fiber completed? *) + +val is_cancelled : _ t -> bool +(** Has the fiber completed with a failure? *) + +val is_success : _ t -> bool +(** Has the fiber completed with a value? *) + +val await : 'a t -> 'a +(** [await fib] is like [Fut.await (res fib)] *) + +val check_if_cancelled : unit -> unit +(** Check if the current fiber is cancelled, in which case this raises. + Must be run from inside a fiber. + @raise Failure if not. *) + +val yield : unit -> unit +(** Yield control to the scheduler from the current fiber. + @raise Failure if not run from inside a fiber. *) + +val with_cancel_callback : _ t -> cancel_callback -> (unit -> 'a) -> 'a +(** [with_cancel_callback fib cb (fun () -> )] evaluates [e] + in a scope in which, if the fiber [fib] is cancelled, + [cb()] is called. If [e] returns without the fiber being cancelled, + this callback is removed. *) + +val on_result : 'a t -> 'a callback -> unit +(** Wait for fiber to be done and call the callback + with the result. If the fiber is done already then the + callback is invoked immediately with its result. *) + +val spawn_top : ?name:string -> on:Runner.t -> (unit -> 'a) -> 'a t +(** [spawn_top ~on f] spawns a new (toplevel) fiber onto the given runner. + This fiber is not the child of any other fiber: its lifetime + is only determined by the lifetime of [f()]. *) + +val spawn_link : ?name:string -> protect:bool -> (unit -> 'a) -> 'a t +(** [spawn_link ~protect f] spawns a sub-fiber [f_child] + from a running fiber [parent]. + The sub-fiber [f_child] is attached to the current fiber and fails + if the current fiber [parent] fails. + + @param protect if true, when [f_child] fails, it does not + affect [parent]. If false, [f_child] failing also + causes [parent] to fail (and therefore all other children + of [parent]). + + Must be run from inside a fiber. + @raise Failure if not run from inside a fiber. *) diff --git a/src/fib/fls.ml b/src/fib/fls.ml new file mode 100644 index 00000000..ed2162c4 --- /dev/null +++ b/src/fib/fls.ml @@ -0,0 +1 @@ +include Task_local_storage diff --git a/src/fib/fls.mli b/src/fib/fls.mli new file mode 100644 index 00000000..ccd0d2ee --- /dev/null +++ b/src/fib/fls.mli @@ -0,0 +1,10 @@ +(** Fiber-local storage. + + This storage is associated to the current fiber, + just like thread-local storage is associated with + the current thread. +*) + +include module type of struct + include Task_local_storage +end diff --git a/src/fib/handle.ml b/src/fib/handle.ml new file mode 100644 index 00000000..f73ed58d --- /dev/null +++ b/src/fib/handle.ml @@ -0,0 +1,14 @@ +module A = Atomic + +type t = int + +let counter_ = A.make 0 +let equal : t -> t -> bool = ( = ) +let compare : t -> t -> int = Stdlib.compare +let[@inline] generate_fresh () = A.fetch_and_add counter_ 1 + +(* TODO: better hash *) +let[@inline] hash x = x land max_int + +module Set = Set.Make (Int) +module Map = Map.Make (Int) diff --git a/src/fib/handle.mli b/src/fib/handle.mli new file mode 100644 index 00000000..1fc5b106 --- /dev/null +++ b/src/fib/handle.mli @@ -0,0 +1,14 @@ +(** The unique name of a fiber *) + +type t = private int +(** Unique, opaque identifier for a fiber. *) + +val equal : t -> t -> bool +val compare : t -> t -> int +val hash : t -> int + +val generate_fresh : unit -> t +(** Generate a fresh, unique identifier *) + +module Set : Set.S with type elt = t +module Map : Map.S with type key = t diff --git a/src/forkjoin/dune b/src/forkjoin/dune new file mode 100644 index 00000000..334a6d8b --- /dev/null +++ b/src/forkjoin/dune @@ -0,0 +1,12 @@ + + +(library + (name moonpool_forkjoin) + (public_name moonpool.forkjoin) + (synopsis "Fork-join parallelism for moonpool") + (flags :standard -open Moonpool) + (optional) + (enabled_if + (>= %{ocaml_version} 5.0)) + (libraries + moonpool moonpool.private)) diff --git a/src/fork_join.ml b/src/forkjoin/moonpool_forkjoin.ml similarity index 90% rename from src/fork_join.ml rename to src/forkjoin/moonpool_forkjoin.ml index 8a4b1fc3..27aa1984 100644 --- a/src/fork_join.ml +++ b/src/forkjoin/moonpool_forkjoin.ml @@ -1,6 +1,6 @@ -[@@@ifge 5.0] - -module A = Atomic_ +module A = Moonpool.Atomic +module Suspend_ = Moonpool.Private.Suspend_ +module Domain_ = Moonpool_private.Domain_ module State_ = struct type error = exn * Printexc.raw_backtrace @@ -48,7 +48,7 @@ module State_ = struct Suspend_.suspend { Suspend_.handle = - (fun ~name:_ ~run:_ suspension -> + (fun ~run:_ ~resume suspension -> while let old_st = A.get self in match old_st with @@ -59,7 +59,7 @@ module State_ = struct | Left_solved left -> (* other thread is done, no risk of race condition *) A.set self (Both_solved (left, right)); - suspension (Ok ()); + resume suspension (Ok ()); false | Right_solved _ | Both_solved _ -> assert false do @@ -110,22 +110,22 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit = | Some cs -> max 1 (min n cs) | None -> (* guess: try to have roughly one task per core *) - max 1 (1 + (n / D_pool_.n_domains ())) + max 1 (1 + (n / Moonpool.Private.num_domains ())) in - let start_tasks ~name ~run (suspension : Suspend_.suspension) = + let start_tasks ~run ~resume (suspension : Suspend_.suspension) = let task_for ~offset ~len_range = match f offset (offset + len_range - 1) with | () -> if A.fetch_and_add missing (-len_range) = len_range then (* all tasks done successfully *) - run ~name (fun () -> suspension (Ok ())) + resume suspension (Ok ()) | exception exn -> let bt = Printexc.get_raw_backtrace () in if not (A.exchange has_failed true) then (* first one to fail, and [missing] must be >= 2 because we're not decreasing it. *) - run ~name (fun () -> suspension (Error (exn, bt))) + resume suspension (Error (exn, bt)) in let i = ref 0 in @@ -135,7 +135,7 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit = let len_range = min chunk_size (n - offset) in assert (offset + len_range <= n); - run ~name (fun () -> task_for ~offset ~len_range); + run ~name:"" (fun () -> task_for ~offset ~len_range); i := !i + len_range done in @@ -143,9 +143,9 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit = Suspend_.suspend { Suspend_.handle = - (fun ~name ~run suspension -> + (fun ~run ~resume suspension -> (* run tasks, then we'll resume [suspension] *) - start_tasks ~run ~name suspension); + start_tasks ~run ~resume suspension); } ) @@ -216,5 +216,3 @@ let map_list ?chunk_size f (l : _ list) : _ list = match res.(i) with | None -> assert false | Some x -> x) - -[@@@endif] diff --git a/src/fork_join.mli b/src/forkjoin/moonpool_forkjoin.mli similarity index 99% rename from src/fork_join.mli rename to src/forkjoin/moonpool_forkjoin.mli index 3ffa537d..ba3b80f0 100644 --- a/src/fork_join.mli +++ b/src/forkjoin/moonpool_forkjoin.mli @@ -4,8 +4,6 @@ @since 0.3 *) -[@@@ifge 5.0] - val both : (unit -> 'a) -> (unit -> 'b) -> 'a * 'b (** [both f g] runs [f()] and [g()], potentially in parallel, and returns their result when both are done. @@ -105,5 +103,3 @@ val map_list : ?chunk_size:int -> ('a -> 'b) -> 'a list -> 'b list (** [map_list f l] is like [List.map f l], but runs in parallel. @since 0.3 {b NOTE} this is only available on OCaml 5. *) - -[@@@endif] diff --git a/src/atomic_.ml b/src/private/atomic_.ml similarity index 100% rename from src/atomic_.ml rename to src/private/atomic_.ml diff --git a/src/dla_.dummy.ml b/src/private/dla_.dummy.ml similarity index 100% rename from src/dla_.dummy.ml rename to src/private/dla_.dummy.ml diff --git a/src/dla_.real.ml b/src/private/dla_.real.ml similarity index 99% rename from src/dla_.real.ml rename to src/private/dla_.real.ml index 5f99d714..16901ba2 100644 --- a/src/dla_.real.ml +++ b/src/private/dla_.real.ml @@ -7,3 +7,4 @@ let using : prepare_for_await:(unit -> t) -> while_running:(unit -> 'a) -> 'a = Domain_local_await.using let setup_domain () = Domain_local_await.per_thread (module Thread) + diff --git a/src/domain_.ml b/src/private/domain_.ml similarity index 100% rename from src/domain_.ml rename to src/private/domain_.ml diff --git a/src/private/dune b/src/private/dune new file mode 100644 index 00000000..2d52b3ef --- /dev/null +++ b/src/private/dune @@ -0,0 +1,25 @@ +(library + (name moonpool_private) + (public_name moonpool.private) + (synopsis "Private internal utils for Moonpool") + (preprocess + (action + (run %{project_root}/src/cpp/cpp.exe %{input-file}))) + (libraries + threads + either + (select + thread_local_storage_.ml + from + (thread-local-storage -> thread_local_storage_.stub.ml) + (-> thread_local_storage_.real.ml)) + (select + dla_.ml + from + (domain-local-await -> dla_.real.ml) + (-> dla_.dummy.ml)) + (select + tracing_.ml + from + (trace.core -> tracing_.real.ml) + (-> tracing_.dummy.ml)))) diff --git a/src/thread_local_storage_.mli b/src/private/thread_local_storage_.mli similarity index 100% rename from src/thread_local_storage_.mli rename to src/private/thread_local_storage_.mli diff --git a/src/thread_local_storage_.real.ml b/src/private/thread_local_storage_.real.ml similarity index 100% rename from src/thread_local_storage_.real.ml rename to src/private/thread_local_storage_.real.ml diff --git a/src/thread_local_storage_.stub.ml b/src/private/thread_local_storage_.stub.ml similarity index 98% rename from src/thread_local_storage_.stub.ml rename to src/private/thread_local_storage_.stub.ml index 88712b6d..82d3ff6d 100644 --- a/src/thread_local_storage_.stub.ml +++ b/src/private/thread_local_storage_.stub.ml @@ -1,3 +1,2 @@ - (* just defer to library *) include Thread_local_storage diff --git a/src/tracing_.dummy.ml b/src/private/tracing_.dummy.ml similarity index 100% rename from src/tracing_.dummy.ml rename to src/private/tracing_.dummy.ml diff --git a/src/tracing_.mli b/src/private/tracing_.mli similarity index 100% rename from src/tracing_.mli rename to src/private/tracing_.mli diff --git a/src/tracing_.real.ml b/src/private/tracing_.real.ml similarity index 100% rename from src/tracing_.real.ml rename to src/private/tracing_.real.ml diff --git a/src/ws_deque_.ml b/src/private/ws_deque_.ml similarity index 100% rename from src/ws_deque_.ml rename to src/private/ws_deque_.ml diff --git a/src/ws_deque_.mli b/src/private/ws_deque_.mli similarity index 100% rename from src/ws_deque_.mli rename to src/private/ws_deque_.mli diff --git a/src/suspend_.ml b/src/suspend_.ml deleted file mode 100644 index 7824d917..00000000 --- a/src/suspend_.ml +++ /dev/null @@ -1,62 +0,0 @@ -type suspension = (unit, exn * Printexc.raw_backtrace) result -> unit -type task = unit -> unit - -type suspension_handler = { - handle: name:string -> run:(name:string -> task -> unit) -> suspension -> unit; -} -[@@unboxed] - -[@@@ifge 5.0] -[@@@ocaml.alert "-unstable"] - -module A = Atomic_ - -type _ Effect.t += Suspend : suspension_handler -> unit Effect.t - -let[@inline] suspend h = Effect.perform (Suspend h) - -let with_suspend ~name ~on_suspend ~(run : name:string -> task -> unit) - (f : unit -> unit) : unit = - let module E = Effect.Deep in - (* effect handler *) - let effc : type e. e Effect.t -> ((e, _) E.continuation -> _) option = - function - | Suspend h -> - Some - (fun k -> - on_suspend (); - let k' : suspension = function - | Ok () -> E.continue k () - | Error (exn, bt) -> E.discontinue_with_backtrace k exn bt - in - h.handle ~name ~run k') - | _ -> None - in - - E.try_with f () { E.effc } - -(* DLA interop *) -let prepare_for_await () : Dla_.t = - (* current state *) - let st : (string * (name:string -> task -> unit) * suspension) option A.t = - A.make None - in - - let release () : unit = - match A.exchange st None with - | None -> () - | Some (name, run, k) -> run ~name (fun () -> k (Ok ())) - and await () : unit = - suspend { handle = (fun ~name ~run k -> A.set st (Some (name, run, k))) } - in - - let t = { Dla_.release; await } in - t - -[@@@ocaml.alert "+unstable"] -[@@@else_] - -let[@inline] with_suspend ~name:_ ~on_suspend:_ ~run:_ f = f () -let[@inline] prepare_for_await () = { Dla_.release = ignore; await = ignore } - -[@@@endif] diff --git a/test/effect-based/dune b/test/effect-based/dune index 9989823f..4b654519 100644 --- a/test/effect-based/dune +++ b/test/effect-based/dune @@ -1,11 +1,24 @@ - (tests - (names t_fib1 t_futs1 t_many t_fib_fork_join - t_fib_fork_join_all t_sort t_fork_join t_fork_join_heavy) - (preprocess (action + (names + t_fib1 + t_futs1 + t_many + t_fib_fork_join + t_fib_fork_join_all + t_sort + t_fork_join + t_fork_join_heavy) + (preprocess + (action (run %{project_root}/src/cpp/cpp.exe %{input-file}))) - (enabled_if (>= %{ocaml_version} 5.0)) - (libraries moonpool trace trace-tef - qcheck-core qcheck-core.runner - ;tracy-client.trace - )) + (enabled_if + (>= %{ocaml_version} 5.0)) + (libraries + moonpool + moonpool.forkjoin + trace + trace-tef + qcheck-core + qcheck-core.runner + ;tracy-client.trace + )) diff --git a/test/effect-based/t_fib1.ml b/test/effect-based/t_fib1.ml index a7c8ebee..5a9d66e6 100644 --- a/test/effect-based/t_fib1.ml +++ b/test/effect-based/t_fib1.ml @@ -2,6 +2,8 @@ open Moonpool +let ( let@ ) = ( @@ ) + let rec fib_direct x = if x <= 1 then 1 @@ -18,7 +20,7 @@ let fib ~on x : int Fut.t = Fut.await t1 + Fut.await t2 ) in - Fut.spawn ~on (fun () -> fib_rec x) + Fut.spawn ~name:"fib" ~on (fun () -> fib_rec x) (* NOTE: for tracy support let () = Tracy_client_trace.setup () @@ -46,9 +48,13 @@ let run_test () = assert (res = Ok (Array.make 3 fib_40)) -let () = +let main () = (* now make sure we can do this with multiple pools in parallel *) let jobs = Array.init 2 (fun _ -> Thread.create run_test ()) in Array.iter Thread.join jobs +let () = + let@ () = Trace_tef.with_setup () in + main () + [@@@endif] diff --git a/test/effect-based/t_fib_fork_join.ml b/test/effect-based/t_fib_fork_join.ml index 4e6639b2..25e7d49d 100644 --- a/test/effect-based/t_fib_fork_join.ml +++ b/test/effect-based/t_fib_fork_join.ml @@ -1,6 +1,7 @@ [@@@ifge 5.0] open Moonpool +module FJ = Moonpool_forkjoin let rec fib_direct x = if x <= 1 then @@ -14,7 +15,7 @@ let fib ~on x : int Fut.t = fib_direct x else ( let n1, n2 = - Fork_join.both (fun () -> fib_rec (x - 1)) (fun () -> fib_rec (x - 2)) + FJ.both (fun () -> fib_rec (x - 1)) (fun () -> fib_rec (x - 2)) in n1 + n2 ) diff --git a/test/effect-based/t_fib_fork_join_all.ml b/test/effect-based/t_fib_fork_join_all.ml index 3caee9b9..f80670ca 100644 --- a/test/effect-based/t_fib_fork_join_all.ml +++ b/test/effect-based/t_fib_fork_join_all.ml @@ -3,6 +3,7 @@ let ( let@ ) = ( @@ ) open Moonpool +module FJ = Moonpool_forkjoin let rec fib_direct x = if x <= 1 then @@ -15,9 +16,7 @@ let rec fib x : int = if x <= 18 then fib_direct x else ( - let n1, n2 = - Fork_join.both (fun () -> fib (x - 1)) (fun () -> fib (x - 2)) - in + let n1, n2 = FJ.both (fun () -> fib (x - 1)) (fun () -> fib (x - 2)) in n1 + n2 ) @@ -32,7 +31,7 @@ let run_test () = let fut = Fut.spawn ~on:pool (fun () -> - let fibs = Fork_join.all_init 3 (fun _ -> fib 40) in + let fibs = FJ.all_init 3 (fun _ -> fib 40) in fibs) in diff --git a/test/effect-based/t_fork_join.ml b/test/effect-based/t_fork_join.ml index 5c7134ca..83c291ab 100644 --- a/test/effect-based/t_fork_join.ml +++ b/test/effect-based/t_fork_join.ml @@ -4,6 +4,7 @@ let spf = Printf.sprintf let ( let@ ) = ( @@ ) open! Moonpool +module FJ = Moonpool_forkjoin let pool = Ws_pool.create ~num_threads:4 () @@ -11,7 +12,7 @@ let () = let x = Ws_pool.run_wait_block pool (fun () -> let x, y = - Fork_join.both + FJ.both (fun () -> Thread.delay 0.005; 1) @@ -26,7 +27,7 @@ let () = let () = try Ws_pool.run_wait_block pool (fun () -> - Fork_join.both_ignore + FJ.both_ignore (fun () -> Thread.delay 0.005) (fun () -> Thread.delay 0.02; @@ -37,21 +38,20 @@ let () = let () = let par_sum = Ws_pool.run_wait_block pool (fun () -> - Fork_join.all_init 42 (fun i -> i * i) |> List.fold_left ( + ) 0) + FJ.all_init 42 (fun i -> i * i) |> List.fold_left ( + ) 0) in let exp_sum = List.init 42 (fun x -> x * x) |> List.fold_left ( + ) 0 in assert (par_sum = exp_sum) let () = - Ws_pool.run_wait_block pool (fun () -> - Fork_join.for_ 0 (fun _ _ -> assert false)); + Ws_pool.run_wait_block pool (fun () -> FJ.for_ 0 (fun _ _ -> assert false)); () let () = let total_sum = Atomic.make 0 in Ws_pool.run_wait_block pool (fun () -> - Fork_join.for_ ~chunk_size:5 100 (fun low high -> + FJ.for_ ~chunk_size:5 100 (fun low high -> (* iterate on the range sequentially. The range should have 5 items or less. *) let local_sum = ref 0 in for i = low to high do @@ -64,7 +64,7 @@ let () = let total_sum = Atomic.make 0 in Ws_pool.run_wait_block pool (fun () -> - Fork_join.for_ ~chunk_size:1 100 (fun low high -> + FJ.for_ ~chunk_size:1 100 (fun low high -> assert (low = high); ignore (Atomic.fetch_and_add total_sum low : int))); assert (Atomic.get total_sum = 4950) @@ -82,7 +82,7 @@ let rec fib_fork_join n = fib_direct n else ( let a, b = - Fork_join.both + FJ.both (fun () -> fib_fork_join (n - 1)) (fun () -> fib_fork_join (n - 2)) in @@ -254,13 +254,13 @@ module Evaluator = struct | Ret x -> x | Comp_fib n -> fib_fork_join n | Add (a, b) -> - let a, b = Fork_join.both (fun () -> eval a) (fun () -> eval b) in + let a, b = FJ.both (fun () -> eval a) (fun () -> eval b) in a + b | Pipe (a, f) -> eval a |> apply_fun_seq f | Map_arr (chunk_size, f, a, r) -> let tasks = List.map (fun x () -> eval x) a in - Fork_join.all_list ~chunk_size tasks - |> Fork_join.map_list ~chunk_size (apply_fun_seq f) + FJ.all_list ~chunk_size tasks + |> FJ.map_list ~chunk_size (apply_fun_seq f) |> eval_reducer r in @@ -290,12 +290,8 @@ let t_for_nested ~min ~chunk_size () = let l1, l2 = let@ pool = Ws_pool.with_ ~num_threads:min () in let@ () = Ws_pool.run_wait_block pool in - let l1 = - Fork_join.map_list ~chunk_size (Fork_join.map_list ~chunk_size neg) l - in - let l2 = - Fork_join.map_list ~chunk_size (Fork_join.map_list ~chunk_size neg) l1 - in + let l1 = FJ.map_list ~chunk_size (FJ.map_list ~chunk_size neg) l in + let l2 = FJ.map_list ~chunk_size (FJ.map_list ~chunk_size neg) l1 in l1, l2 in @@ -313,12 +309,8 @@ let t_map ~chunk_size () = let@ pool = Ws_pool.with_ ~num_threads:4 () in let@ () = Ws_pool.run_wait_block pool in - let a1 = - Fork_join.map_list ~chunk_size string_of_int l |> Array.of_list - in - let a2 = - Fork_join.map_array ~chunk_size string_of_int @@ Array.of_list l - in + let a1 = FJ.map_list ~chunk_size string_of_int l |> Array.of_list in + let a2 = FJ.map_array ~chunk_size string_of_int @@ Array.of_list l in if a1 <> a2 then Q.Test.fail_reportf "a1=%s, a2=%s" (ppa a1) (ppa a2); true) diff --git a/test/effect-based/t_fork_join_heavy.ml b/test/effect-based/t_fork_join_heavy.ml index bacb1d18..7fac119c 100644 --- a/test/effect-based/t_fork_join_heavy.ml +++ b/test/effect-based/t_fork_join_heavy.ml @@ -7,6 +7,7 @@ let ( let@ ) = ( @@ ) let ppl = Q.Print.(list @@ list int) open! Moonpool +module FJ = Moonpool_forkjoin let run ~min () = let@ _sp = @@ -31,17 +32,13 @@ let run ~min () = let@ () = Ws_pool.run_wait_block pool in let l1, l2 = - Fork_join.both + FJ.both (fun () -> let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "fj.left" in - Fork_join.map_list ~chunk_size - (Fork_join.map_list ~chunk_size neg) - l) + FJ.map_list ~chunk_size (FJ.map_list ~chunk_size neg) l) (fun () -> let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "fj.right" in - Fork_join.map_list ~chunk_size - (Fork_join.map_list ~chunk_size neg) - ref_l1) + FJ.map_list ~chunk_size (FJ.map_list ~chunk_size neg) ref_l1) in l1, l2 in diff --git a/test/effect-based/t_sort.ml b/test/effect-based/t_sort.ml index 8ccc372f..f0da71b8 100644 --- a/test/effect-based/t_sort.ml +++ b/test/effect-based/t_sort.ml @@ -1,6 +1,7 @@ [@@@ifge 5.0] open Moonpool +module FJ = Moonpool_forkjoin let rec select_sort arr i len = if len >= 2 then ( @@ -54,7 +55,7 @@ let rec quicksort arr i len : unit = ) done; - Fork_join.both_ignore + FJ.both_ignore (fun () -> quicksort arr i (!low - i)) (fun () -> quicksort arr !low (len - (!low - i))) )