diff --git a/.github/workflows/gh-pages.yml b/.github/workflows/gh-pages.yml index 20fd66df..20d74435 100644 --- a/.github/workflows/gh-pages.yml +++ b/.github/workflows/gh-pages.yml @@ -19,11 +19,12 @@ jobs: dune-cache: true allow-prerelease-opam: true - - name: Deps - run: opam install odig moonpool moonpool-lwt + # temporary until it's in a release + - run: opam pin https://github.com/ocaml-multicore/picos.git -y -n - - name: Build - run: opam exec -- odig odoc --cache-dir=_doc/ moonpool moonpool-lwt + - run: opam install odig moonpool moonpool-lwt moonpool-io + + - run: opam exec -- odig odoc --cache-dir=_doc/ moonpool moonpool-lwt - name: Deploy uses: peaceiris/actions-gh-pages@v3 diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 5863e343..a16fe420 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -16,7 +16,6 @@ jobs: os: - ubuntu-latest ocaml-compiler: - - '4.08' - '4.14' - '5.2' @@ -30,14 +29,17 @@ jobs: dune-cache: true allow-prerelease-opam: true - - run: opam install -t moonpool moonpool-lwt --deps-only + # temporary until it's in a release + - run: opam pin https://github.com/ocaml-multicore/picos.git -y -n + + - run: opam install -t moonpool moonpool-lwt moonpool-io --deps-only if: matrix.ocaml-compiler == '5.2' - run: opam install -t moonpool --deps-only if: matrix.ocaml-compiler != '5.2' - run: opam exec -- dune build @install # install some depopts - - run: opam install thread-local-storage trace domain-local-await + - run: opam install thread-local-storage trace hmap if: matrix.ocaml-compiler == '5.2' - run: opam exec -- dune build --profile=release --force @install @runtest @@ -63,7 +65,10 @@ jobs: dune-cache: true allow-prerelease-opam: true - - run: opam install -t moonpool moonpool-lwt --deps-only + # temporary until it's in a release + - run: opam pin https://github.com/ocaml-multicore/picos.git -y -n + + - run: opam install -t moonpool moonpool-lwt moonpool-io --deps-only - run: opam exec -- dune build @install # install some depopts - run: opam install thread-local-storage trace domain-local-await diff --git a/.gitignore b/.gitignore index 76301f0c..ea2c377b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ _build _opam +*.tmp diff --git a/Makefile b/Makefile index 7b4e63d1..8ca1e86a 100644 --- a/Makefile +++ b/Makefile @@ -9,6 +9,9 @@ clean: test: @dune runtest $(DUNE_OPTS) +test-autopromote: + @dune runtest $(DUNE_OPTS) --auto-promote + doc: @dune build $(DUNE_OPTS) @doc diff --git a/dune-project b/dune-project index 2fbaf0d8..36a3d94e 100644 --- a/dune-project +++ b/dune-project @@ -16,21 +16,24 @@ (name moonpool) (synopsis "Pools of threads supported by a pool of domains") (depends - (ocaml (>= 4.08)) + (ocaml (>= 4.14)) dune (either (>= 1.0)) (trace :with-test) (trace-tef :with-test) (qcheck-core (and :with-test (>= 0.19))) + (thread-local-storage (and (>= 0.2) (< 0.3))) (odoc :with-doc) + (hmap :with-test) + (picos (and (>= 0.5) (< 0.6))) + (picos_std (and (>= 0.5) (< 0.6))) (mdx (and (>= 1.9.0) :with-test))) (depopts - (trace (>= 0.6)) - thread-local-storage) - (conflicts (thread-local-storage (< 0.2))) + hmap + (trace (>= 0.6))) (tags (thread pool domain futures fork-join))) @@ -47,4 +50,17 @@ (trace-tef :with-test) (odoc :with-doc))) +(package + (name moonpool-io) + (synopsis "Async IO for moonpool, relying on picos") + (allow_empty) ; on < 5.0 + (depends + (moonpool (= :version)) + (picos_io (and (>= 0.5) (< 0.6))) + (ocaml (>= 5.0)) + (trace :with-test) + (trace-tef :with-test) + (odoc :with-doc))) + + ; See the complete stanza docs at https://dune.readthedocs.io/en/stable/dune-files.html#dune-project diff --git a/moonpool-io.opam b/moonpool-io.opam new file mode 100644 index 00000000..76a6ab82 --- /dev/null +++ b/moonpool-io.opam @@ -0,0 +1,33 @@ +# This file is generated by dune, edit dune-project instead +opam-version: "2.0" +version: "0.6" +synopsis: "Async IO for moonpool, relying on picos" +maintainer: ["Simon Cruanes"] +authors: ["Simon Cruanes"] +license: "MIT" +homepage: "https://github.com/c-cube/moonpool" +bug-reports: "https://github.com/c-cube/moonpool/issues" +depends: [ + "dune" {>= "3.0"} + "moonpool" {= version} + "picos_io" {>= "0.5" & < "0.6"} + "ocaml" {>= "5.0"} + "trace" {with-test} + "trace-tef" {with-test} + "odoc" {with-doc} +] +build: [ + ["dune" "subst"] {dev} + [ + "dune" + "build" + "-p" + name + "-j" + jobs + "@install" + "@runtest" {with-test} + "@doc" {with-doc} + ] +] +dev-repo: "git+https://github.com/c-cube/moonpool.git" diff --git a/moonpool.opam b/moonpool.opam index 338cddc1..f3c884ac 100644 --- a/moonpool.opam +++ b/moonpool.opam @@ -9,21 +9,22 @@ tags: ["thread" "pool" "domain" "futures" "fork-join"] homepage: "https://github.com/c-cube/moonpool" bug-reports: "https://github.com/c-cube/moonpool/issues" depends: [ - "ocaml" {>= "4.08"} + "ocaml" {>= "4.14"} "dune" {>= "3.0"} "either" {>= "1.0"} "trace" {with-test} "trace-tef" {with-test} "qcheck-core" {with-test & >= "0.19"} + "thread-local-storage" {>= "0.2" & < "0.3"} "odoc" {with-doc} + "hmap" {with-test} + "picos" {>= "0.5" & < "0.6"} + "picos_std" {>= "0.5" & < "0.6"} "mdx" {>= "1.9.0" & with-test} ] depopts: [ + "hmap" "trace" {>= "0.6"} - "thread-local-storage" -] -conflicts: [ - "thread-local-storage" {< "0.2"} ] build: [ ["dune" "subst"] {dev} diff --git a/src/core/dune b/src/core/dune index 1c39c97b..015e9ce6 100644 --- a/src/core/dune +++ b/src/core/dune @@ -1,9 +1,18 @@ (library (public_name moonpool) (name moonpool) - (libraries moonpool.private moonpool.dpool) + (libraries + moonpool.private + (re_export thread-local-storage) + (select + hmap_ls_.ml + from + (hmap -> hmap_ls_.real.ml) + (-> hmap_ls_.dummy.ml)) + moonpool.dpool + (re_export picos)) (flags :standard -open Moonpool_private) - (private_modules types_ util_pool_) + (private_modules util_pool_) (preprocess (action (run %{project_root}/src/cpp/cpp.exe %{input-file})))) diff --git a/src/core/exn_bt.ml b/src/core/exn_bt.ml index dc1fab0f..cfed6421 100644 --- a/src/core/exn_bt.ml +++ b/src/core/exn_bt.ml @@ -3,6 +3,9 @@ type t = exn * Printexc.raw_backtrace let[@inline] make exn bt : t = exn, bt let[@inline] exn (e, _) = e let[@inline] bt (_, bt) = bt +let show self = Printexc.to_string (exn self) +let pp out self = Format.pp_print_string out (show self) +let[@inline] raise (e, bt) = Printexc.raise_with_backtrace e bt let[@inline] get exn = let bt = Printexc.get_raw_backtrace () in @@ -12,8 +15,8 @@ let[@inline] get_callstack n exn = let bt = Printexc.get_callstack n in make exn bt -let show self = Printexc.to_string (fst self) -let pp out self = Format.pp_print_string out (show self) -let[@inline] raise self = Printexc.raise_with_backtrace (exn self) (bt self) - type nonrec 'a result = ('a, t) result + +let[@inline] unwrap = function + | Ok x -> x + | Error ebt -> raise ebt diff --git a/src/core/exn_bt.mli b/src/core/exn_bt.mli index eb8f1b02..665acfdb 100644 --- a/src/core/exn_bt.mli +++ b/src/core/exn_bt.mli @@ -1,27 +1,29 @@ (** Exception with backtrace. + Type changed @since NEXT_RELEASE + @since 0.6 *) -type t = exn * Printexc.raw_backtrace (** An exception bundled with a backtrace *) +type t = exn * Printexc.raw_backtrace + val exn : t -> exn val bt : t -> Printexc.raw_backtrace +val raise : t -> 'a +val get : exn -> t +val get_callstack : int -> exn -> t val make : exn -> Printexc.raw_backtrace -> t (** Trivial builder *) -val get : exn -> t -(** [get exn] is [make exn (get_raw_backtrace ())] *) - -val get_callstack : int -> exn -> t - -val raise : t -> 'a -(** Raise the exception with its save backtrace *) - val show : t -> string (** Simple printing *) val pp : Format.formatter -> t -> unit type nonrec 'a result = ('a, t) result + +val unwrap : 'a result -> 'a +(** [unwrap (Ok x)] is [x], [unwrap (Error ebt)] re-raises [ebt]. + @since NEXT_RELEASE *) diff --git a/src/core/fifo_pool.ml b/src/core/fifo_pool.ml index a16d5b08..5c6c13e3 100644 --- a/src/core/fifo_pool.ml +++ b/src/core/fifo_pool.ml @@ -1,95 +1,38 @@ open Types_ include Runner +module WL = Worker_loop_ + +type fiber = Picos.Fiber.t +type task_full = WL.task_full let ( let@ ) = ( @@ ) -type task_full = - | T_start of { - ls: Task_local_storage.t; - f: task; - } - | T_resume : { - ls: Task_local_storage.t; - k: 'a -> unit; - x: 'a; - } - -> task_full - type state = { threads: Thread.t array; q: task_full Bb_queue.t; (** Queue for tasks. *) + around_task: WL.around_task; + mutable as_runner: t; + (* init options *) + name: string option; + on_init_thread: dom_id:int -> t_id:int -> unit -> unit; + on_exit_thread: dom_id:int -> t_id:int -> unit -> unit; + on_exn: exn -> Printexc.raw_backtrace -> unit; } (** internal state *) +type worker_state = { + idx: int; + dom_idx: int; + st: state; +} + let[@inline] size_ (self : state) = Array.length self.threads let[@inline] num_tasks_ (self : state) : int = Bb_queue.size self.q - -(** Run [task] as is, on the pool. *) -let schedule_ (self : state) (task : task_full) : unit = - try Bb_queue.push self.q task with Bb_queue.Closed -> raise Shutdown - -type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task -type worker_state = { mutable cur_ls: Task_local_storage.t option } - let k_worker_state : worker_state TLS.t = TLS.create () -let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit = - let w = { cur_ls = None } in - TLS.set k_worker_state w; - TLS.set Runner.For_runner_implementors.k_cur_runner runner; - - let (AT_pair (before_task, after_task)) = around_task in - - let on_suspend () = - match TLS.get_opt k_worker_state with - | Some { cur_ls = Some ls; _ } -> ls - | _ -> assert false - in - let run_another_task ls task' = schedule_ self @@ T_start { f = task'; ls } in - let resume ls k res = schedule_ self @@ T_resume { ls; k; x = res } in - - let run_task (task : task_full) : unit = - let ls = - match task with - | T_start { ls; _ } | T_resume { ls; _ } -> ls - in - w.cur_ls <- Some ls; - TLS.set k_cur_storage ls; - let _ctx = before_task runner in - - (* run the task now, catching errors, handling effects *) - (try - match task with - | T_start { f = task; _ } -> - (* run [task()] and handle [suspend] in it *) - Suspend_.with_suspend - (WSH { on_suspend; run = run_another_task; resume }) - task - | T_resume { k; x; _ } -> - (* this is already in an effect handler *) - k x - with e -> - let bt = Printexc.get_raw_backtrace () in - on_exn e bt); - after_task runner _ctx; - w.cur_ls <- None; - TLS.set k_cur_storage _dummy_ls - in - - let main_loop () = - let continue = ref true in - while !continue do - match Bb_queue.pop self.q with - | task -> run_task task - | exception Bb_queue.Closed -> continue := false - done - in - - try - (* handle domain-local await *) - Dla_.using ~prepare_for_await:Suspend_.prepare_for_await - ~while_running:main_loop - with Bb_queue.Closed -> () +(* +get_thread_state = TLS.get_opt k_worker_state + *) let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = () @@ -106,10 +49,14 @@ type ('a, 'b) create_args = ?name:string -> 'a -let default_around_task_ : around_task = AT_pair (ignore, fun _ _ -> ()) +let default_around_task_ : WL.around_task = AT_pair (ignore, fun _ _ -> ()) + +(** Run [task] as is, on the pool. *) +let schedule_ (self : state) (task : task_full) : unit = + try Bb_queue.push self.q task with Bb_queue.Closed -> raise Shutdown let runner_of_state (pool : state) : t = - let run_async ~ls f = schedule_ pool @@ T_start { f; ls } in + let run_async ~fiber f = schedule_ pool @@ T_start { f; fiber } in Runner.For_runner_implementors.create ~shutdown:(fun ~wait () -> shutdown_ pool ~wait) ~run_async @@ -117,16 +64,79 @@ let runner_of_state (pool : state) : t = ~num_tasks:(fun () -> num_tasks_ pool) () -let create ?(on_init_thread = default_thread_init_exit_) +(** Run [task] as is, on the pool. *) +let schedule_w (self : worker_state) (task : task_full) : unit = + try Bb_queue.push self.st.q task with Bb_queue.Closed -> raise Shutdown + +let get_next_task (self : worker_state) = + try Bb_queue.pop self.st.q with Bb_queue.Closed -> raise WL.No_more_tasks + +let get_thread_state () = + match TLS.get_exn k_worker_state with + | st -> st + | exception TLS.Not_set -> + failwith "Moonpool: get_thread_state called from outside a runner." + +let before_start (self : worker_state) = + let t_id = Thread.id @@ Thread.self () in + self.st.on_init_thread ~dom_id:self.dom_idx ~t_id (); + + (* set thread name *) + Option.iter + (fun name -> + Tracing_.set_thread_name (Printf.sprintf "%s.worker.%d" name self.idx)) + self.st.name + +let cleanup (self : worker_state) : unit = + (* on termination, decrease refcount of underlying domain *) + Domain_pool_.decr_on self.dom_idx; + let t_id = Thread.id @@ Thread.self () in + self.st.on_exit_thread ~dom_id:self.dom_idx ~t_id () + +let worker_ops : worker_state WL.ops = + let runner (st : worker_state) = st.st.as_runner in + let around_task st = st.st.around_task in + let on_exn (st : worker_state) (ebt : Exn_bt.t) = + st.st.on_exn (Exn_bt.exn ebt) (Exn_bt.bt ebt) + in + { + WL.schedule = schedule_w; + runner; + get_next_task; + get_thread_state; + around_task; + on_exn; + before_start; + cleanup; + } + +let create_ ?(on_init_thread = default_thread_init_exit_) ?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ()) - ?around_task ?num_threads ?name () : t = + ?around_task ~threads ?name () : state = (* wrapper *) let around_task = match around_task with - | Some (f, g) -> AT_pair (f, g) + | Some (f, g) -> WL.AT_pair (f, g) | None -> default_around_task_ in + let self = + { + threads; + q = Bb_queue.create (); + around_task; + as_runner = Runner.dummy; + name; + on_init_thread; + on_exit_thread; + on_exn; + } + in + self.as_runner <- runner_of_state self; + self + +let create ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads + ?name () : t = let num_domains = Domain_pool_.max_number_of_domains () in (* number of threads to run *) @@ -137,9 +147,10 @@ let create ?(on_init_thread = default_thread_init_exit_) let pool = let dummy_thread = Thread.self () in - { threads = Array.make num_threads dummy_thread; q = Bb_queue.create () } + let threads = Array.make num_threads dummy_thread in + create_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ~threads ?name + () in - let runner = runner_of_state pool in (* temporary queue used to obtain thread handles from domains @@ -150,31 +161,11 @@ let create ?(on_init_thread = default_thread_init_exit_) let start_thread_with_idx i = let dom_idx = (offset + i) mod num_domains in - (* function run in the thread itself *) - let main_thread_fun () : unit = - let thread = Thread.self () in - let t_id = Thread.id thread in - on_init_thread ~dom_id:dom_idx ~t_id (); - - (* set thread name *) - Option.iter - (fun name -> - Tracing_.set_thread_name (Printf.sprintf "%s.worker.%d" name i)) - name; - - let run () = worker_thread_ pool runner ~on_exn ~around_task in - - (* now run the main loop *) - Fun.protect run ~finally:(fun () -> - (* on termination, decrease refcount of underlying domain *) - Domain_pool_.decr_on dom_idx); - on_exit_thread ~dom_id:dom_idx ~t_id () - in - (* function called in domain with index [i], to create the thread and push it into [receive_threads] *) let create_thread_in_domain () = - let thread = Thread.create main_thread_fun () in + let st = { idx = i; dom_idx; st = pool } in + let thread = Thread.create (WL.worker_loop ~ops:worker_ops) st in (* send the thread from the domain back to us *) Bb_queue.push receive_threads (i, thread) in @@ -206,11 +197,12 @@ let with_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads f pool module Private_ = struct - type nonrec state = state + type nonrec worker_state = worker_state - let create_state ~threads () : state = { threads; q = Bb_queue.create () } - let runner_of_state = runner_of_state + let worker_ops = worker_ops + let runner_of_state (self : worker_state) = worker_ops.runner self - let run_thread (st : state) (self : t) ~on_exn : unit = - worker_thread_ st self ~on_exn ~around_task:default_around_task_ + let create_single_threaded_state ~thread ?on_exn () : worker_state = + let st : state = create_ ?on_exn ~threads:[| thread |] () in + { idx = 0; dom_idx = 0; st } end diff --git a/src/core/fifo_pool.mli b/src/core/fifo_pool.mli index d7d103cf..11ba4ed5 100644 --- a/src/core/fifo_pool.mli +++ b/src/core/fifo_pool.mli @@ -48,13 +48,17 @@ val with_ : (unit -> (t -> 'a) -> 'a, _) create_args (**/**) module Private_ : sig - type state + type worker_state - val create_state : threads:Thread.t array -> unit -> state - val runner_of_state : state -> Runner.t + val worker_ops : worker_state Worker_loop_.ops - val run_thread : - state -> t -> on_exn:(exn -> Printexc.raw_backtrace -> unit) -> unit + val create_single_threaded_state : + thread:Thread.t -> + ?on_exn:(exn -> Printexc.raw_backtrace -> unit) -> + unit -> + worker_state + + val runner_of_state : worker_state -> Runner.t end (**/**) diff --git a/src/core/fut.ml b/src/core/fut.ml index 17afb908..58b6e443 100644 --- a/src/core/fut.ml +++ b/src/core/fut.ml @@ -1,118 +1,113 @@ module A = Atomic_ +module C = Picos.Computation type 'a or_error = ('a, Exn_bt.t) result type 'a waiter = 'a or_error -> unit - -type 'a state = - | Done of 'a or_error - | Waiting of { waiters: 'a waiter list } - -type 'a t = { st: 'a state A.t } [@@unboxed] +type 'a t = { st: 'a C.t } [@@unboxed] type 'a promise = 'a t +let[@inline] make_ () : _ t = + let fut = { st = C.create ~mode:`LIFO () } in + fut + let make () = - let fut = { st = A.make (Waiting { waiters = [] }) } in + let fut = make_ () in fut, fut -let[@inline] of_result x : _ t = { st = A.make (Done x) } -let[@inline] return x : _ t = of_result (Ok x) -let[@inline] fail e bt : _ t = of_result (Error (e, bt)) -let[@inline] fail_exn_bt ebt = of_result (Error ebt) +let[@inline] return x : _ t = { st = C.returned x } -let[@inline] is_resolved self : bool = - match A.get self.st with - | Done _ -> true - | Waiting _ -> false +let[@inline] fail exn bt : _ t = + let st = C.create () in + C.cancel st exn bt; + { st } -let[@inline] peek self : _ option = - match A.get self.st with - | Done x -> Some x - | Waiting _ -> None +let[@inline] fail_exn_bt ebt = fail (Exn_bt.exn ebt) (Exn_bt.bt ebt) -let[@inline] raise_if_failed self : unit = - match A.get self.st with - | Done (Error ebt) -> Exn_bt.raise ebt - | _ -> () +let[@inline] of_result = function + | Ok x -> return x + | Error ebt -> fail_exn_bt ebt -let[@inline] is_done self : bool = - match A.get self.st with - | Done _ -> true - | Waiting _ -> false +let[@inline] is_resolved self : bool = not (C.is_running self.st) +let is_done = is_resolved +let[@inline] peek self : _ option = C.peek self.st +let[@inline] raise_if_failed self : unit = C.check self.st let[@inline] is_success self = - match A.get self.st with - | Done (Ok _) -> true - | _ -> false + match C.peek_exn self.st with + | _ -> true + | exception _ -> false -let[@inline] is_failed self = - match A.get self.st with - | Done (Error _) -> true - | _ -> false +let[@inline] is_failed self = C.is_canceled self.st exception Not_ready let[@inline] get_or_fail self = - match A.get self.st with - | Done x -> x - | Waiting _ -> raise Not_ready + match C.peek self.st with + | Some x -> x + | None -> raise Not_ready let[@inline] get_or_fail_exn self = - match A.get self.st with - | Done (Ok x) -> x - | Done (Error (exn, bt)) -> Printexc.raise_with_backtrace exn bt - | Waiting _ -> raise Not_ready + match C.peek_exn self.st with + | x -> x + | exception C.Running -> raise Not_ready + +let[@inline] peek_or_assert_ (self : 'a t) : 'a = + match C.peek_exn self.st with + | x -> x + | exception C.Running -> assert false + +let on_result_cb_ _tr f self : unit = + match peek_or_assert_ self with + | x -> f (Ok x) + | exception exn -> + let ebt = Exn_bt.get exn in + f (Error ebt) let on_result (self : _ t) (f : _ waiter) : unit = - while - let st = A.get self.st in - match st with - | Done x -> - f x; - false - | Waiting { waiters = l } -> - not (A.compare_and_set self.st st (Waiting { waiters = f :: l })) - do - Domain_.relax () - done + let trigger = + (Trigger.from_action f self on_result_cb_ [@alert "-handler"]) + in + if not (C.try_attach self.st trigger) then on_result_cb_ () f self + +let on_result_ignore_cb_ _tr f (self : _ t) = + f (Picos.Computation.canceled self.st) + +let on_result_ignore (self : _ t) f : unit = + if Picos.Computation.is_running self.st then ( + let trigger = + (Trigger.from_action f self on_result_ignore_cb_ [@alert "-handler"]) + in + if not (C.try_attach self.st trigger) then on_result_ignore_cb_ () f self + ) else + on_result_ignore_cb_ () f self + +let[@inline] fulfill_idempotent self r = + match r with + | Ok x -> C.return self.st x + | Error ebt -> C.cancel self.st (Exn_bt.exn ebt) (Exn_bt.bt ebt) exception Already_fulfilled let fulfill (self : _ t) (r : _ result) : unit = - let fs = ref [] in - while - let st = A.get self.st in - match st with - | Done _ -> raise Already_fulfilled - | Waiting { waiters = l } -> - let did_swap = A.compare_and_set self.st st (Done r) in - if did_swap then ( - (* success, now call all the waiters *) - fs := l; - false - ) else - true - do - Domain_.relax () - done; - List.iter (fun f -> try f r with _ -> ()) !fs; - () - -let[@inline] fulfill_idempotent self r = - try fulfill self r with Already_fulfilled -> () + let ok = + match r with + | Ok x -> C.try_return self.st x + | Error ebt -> C.try_cancel self.st (Exn_bt.exn ebt) (Exn_bt.bt ebt) + in + if not ok then raise Already_fulfilled (* ### combinators ### *) let spawn ~on f : _ t = - let fut, promise = make () in + let fut = make_ () in let task () = - let res = - try Ok (f ()) - with e -> - let bt = Printexc.get_raw_backtrace () in - Error (e, bt) - in - fulfill promise res + try + let res = f () in + C.return fut.st res + with exn -> + let bt = Printexc.get_raw_backtrace () in + C.cancel fut.st exn bt in Runner.run_async on task; @@ -127,8 +122,8 @@ let reify_error (f : 'a t) : 'a or_error t = match peek f with | Some res -> return res | None -> - let fut, promise = make () in - on_result f (fun r -> fulfill promise (Ok r)); + let fut = make_ () in + on_result f (fun r -> fulfill fut (Ok r)); fut let[@inline] get_runner_ ?on () : Runner.t option = @@ -141,9 +136,9 @@ let map ?on ~f fut : _ t = match r with | Ok x -> (try Ok (f x) - with e -> + with exn -> let bt = Printexc.get_raw_backtrace () in - Error (e, bt)) + Error (Exn_bt.make exn bt)) | Error e_bt -> Error e_bt in @@ -167,7 +162,7 @@ let map ?on ~f fut : _ t = let join (fut : 'a t t) : 'a t = match peek fut with | Some (Ok f) -> f - | Some (Error (e, bt)) -> fail e bt + | Some (Error ebt) -> fail_exn_bt ebt | None -> let fut2, promise = make () in on_result fut (function @@ -183,7 +178,7 @@ let bind ?on ~f fut : _ t = with e -> let bt = Printexc.get_raw_backtrace () in fail e bt) - | Error (e, bt) -> fail e bt + | Error ebt -> fail_exn_bt ebt in let bind_and_fulfill (r : _ result) promise () : unit = @@ -226,7 +221,7 @@ let update_atomic_ (st : 'a A.t) f : 'a = let both a b : _ t = match peek a, peek b with | Some (Ok x), Some (Ok y) -> return (x, y) - | Some (Error (e, bt)), _ | _, Some (Error (e, bt)) -> fail e bt + | Some (Error ebt), _ | _, Some (Error ebt) -> fail_exn_bt ebt | _ -> let fut, promise = make () in @@ -259,7 +254,7 @@ let choose a b : _ t = match peek a, peek b with | Some (Ok x), _ -> return (Either.Left x) | _, Some (Ok y) -> return (Either.Right y) - | Some (Error (e, bt)), Some (Error _) -> fail e bt + | Some (Error ebt), Some (Error _) -> fail_exn_bt ebt | _ -> let fut, promise = make () in @@ -282,7 +277,7 @@ let choose_same a b : _ t = match peek a, peek b with | Some (Ok x), _ -> return x | _, Some (Ok y) -> return y - | Some (Error (e, bt)), Some (Error _) -> fail e bt + | Some (Error ebt), Some (Error _) -> fail_exn_bt ebt | _ -> let fut, promise = make () in @@ -299,11 +294,6 @@ let choose_same a b : _ t = | Ok y -> fulfill_idempotent promise (Ok y)); fut -let peek_ok_assert_ (self : 'a t) : 'a = - match A.get self.st with - | Done (Ok x) -> x - | _ -> assert false - let barrier_on_abstract_container_of_futures ~iter ~len ~aggregate_results cont : _ t = let n_items = len cont in @@ -317,14 +307,14 @@ let barrier_on_abstract_container_of_futures ~iter ~len ~aggregate_results cont (* callback called when a future in [a] is resolved *) let on_res = function - | Ok _ -> + | None -> let n = A.fetch_and_add missing (-1) in if n = 1 then ( (* last future, we know they all succeeded, so resolve [fut] *) - let res = aggregate_results peek_ok_assert_ cont in + let res = aggregate_results peek_or_assert_ cont in fulfill promise (Ok res) ) - | Error e_bt -> + | Some e_bt -> (* immediately cancel all other [on_res] *) let n = A.exchange missing 0 in if n > 0 then @@ -333,7 +323,7 @@ let barrier_on_abstract_container_of_futures ~iter ~len ~aggregate_results cont fulfill promise (Error e_bt) in - iter (fun fut -> on_result fut on_res) cont; + iter (fun fut -> on_result_ignore fut on_res) cont; fut ) @@ -387,61 +377,65 @@ let for_list ~on l f : unit t = (* ### blocking ### *) -let wait_block (self : 'a t) : 'a or_error = - match A.get self.st with - | Done x -> x (* fast path *) - | Waiting _ -> +let push_queue_ _tr q () = Bb_queue.push q () + +let wait_block_exn (self : 'a t) : 'a = + match C.peek_exn self.st with + | x -> x (* fast path *) + | exception C.Running -> let real_block () = (* use queue only once *) let q = Bb_queue.create () in - on_result self (fun r -> Bb_queue.push q r); - Bb_queue.pop q + + let trigger = Trigger.create () in + let attached = + (Trigger.on_signal trigger q () push_queue_ [@alert "-handler"]) + in + assert attached; + + (* blockingly wait for trigger if computation didn't complete in the mean time *) + if C.try_attach self.st trigger then Bb_queue.pop q; + + (* trigger was signaled! computation must be done*) + peek_or_assert_ self in + (* TODO: use backoff? *) (* a bit of spinning before we block *) let rec loop i = if i = 0 then real_block () else ( - match A.get self.st with - | Done x -> x - | Waiting _ -> + match C.peek_exn self.st with + | x -> x + | exception C.Running -> Domain_.relax (); (loop [@tailcall]) (i - 1) ) in loop 50 -let wait_block_exn self = - match wait_block self with - | Ok x -> x - | Error (e, bt) -> Printexc.raise_with_backtrace e bt +let wait_block self = + match wait_block_exn self with + | x -> Ok x + | exception exn -> + let bt = Printexc.get_raw_backtrace () in + Error (Exn_bt.make exn bt) [@@@ifge 5.0] -let await (fut : 'a t) : 'a = - match peek fut with - | Some res -> - (* fast path: peek *) - (match res with - | Ok x -> x - | Error (exn, bt) -> Printexc.raise_with_backtrace exn bt) - | None -> +let await (self : 'a t) : 'a = + (* fast path: peek *) + match C.peek_exn self.st with + | res -> res + | exception C.Running -> + let trigger = Trigger.create () in (* suspend until the future is resolved *) - Suspend_.suspend - { - Suspend_.handle = - (fun ~run:_ ~resume k -> - on_result fut (function - | Ok _ -> - (* schedule continuation with the same name *) - resume k (Ok ()) - | Error (exn, bt) -> - (* fail continuation immediately *) - resume k (Error (exn, bt)))); - }; + if C.try_attach self.st trigger then + Option.iter Exn_bt.raise @@ Trigger.await trigger; + (* un-suspended: we should have a result! *) - get_or_fail_exn fut + get_or_fail_exn self [@@@endif] @@ -459,4 +453,5 @@ module Infix_local = Infix [@@deprecated "use Infix"] module Private_ = struct let[@inline] unsafe_promise_of_fut x = x + let[@inline] as_computation self = self.st end diff --git a/src/core/fut.mli b/src/core/fut.mli index 243afad0..cc8f85ee 100644 --- a/src/core/fut.mli +++ b/src/core/fut.mli @@ -34,6 +34,13 @@ val on_result : 'a t -> ('a or_error -> unit) -> unit when [fut] is set ; or calls [f] immediately if [fut] is already set. *) +val on_result_ignore : _ t -> (Exn_bt.t option -> unit) -> unit +(** [on_result_ignore fut f] registers [f] to be called in the future + when [fut] is set; + or calls [f] immediately if [fut] is already set. + It does not pass the result, only a success/error signal. + @since NEXT_RELEASE *) + exception Already_fulfilled val fulfill : 'a promise -> 'a or_error -> unit @@ -285,6 +292,8 @@ module Infix_local = Infix module Private_ : sig val unsafe_promise_of_fut : 'a t -> 'a promise (** please do not use *) + + val as_computation : 'a t -> 'a Picos.Computation.t end (**/**) diff --git a/src/core/hmap_ls_.dummy.ml b/src/core/hmap_ls_.dummy.ml new file mode 100644 index 00000000..e4f0692c --- /dev/null +++ b/src/core/hmap_ls_.dummy.ml @@ -0,0 +1,7 @@ +(**/**) + +module Private_hmap_ls_ = struct + let copy_fls _ _ = () +end + +(**/**) diff --git a/src/core/hmap_ls_.real.ml b/src/core/hmap_ls_.real.ml new file mode 100644 index 00000000..7d79316b --- /dev/null +++ b/src/core/hmap_ls_.real.ml @@ -0,0 +1,65 @@ +open Types_ + +open struct + module FLS = Picos.Fiber.FLS +end + +(** A local hmap, inherited in children fibers *) +let k_local_hmap : Hmap.t FLS.t = FLS.create () + +(** Access the local [hmap], or an empty one if not set *) +let[@inline] get_local_hmap () : Hmap.t = + let fiber = get_current_fiber_exn () in + FLS.get fiber ~default:Hmap.empty k_local_hmap + +let[@inline] set_local_hmap (h : Hmap.t) : unit = + let fiber = get_current_fiber_exn () in + FLS.set fiber k_local_hmap h + +let[@inline] update_local_hmap (f : Hmap.t -> Hmap.t) : unit = + let fiber = get_current_fiber_exn () in + let h = FLS.get fiber ~default:Hmap.empty k_local_hmap in + let h = f h in + FLS.set fiber k_local_hmap h + +(** @raise Invalid_argument if not present *) +let get_in_local_hmap_exn (k : 'a Hmap.key) : 'a = + let h = get_local_hmap () in + Hmap.get k h + +let get_in_local_hmap_opt (k : 'a Hmap.key) : 'a option = + let h = get_local_hmap () in + Hmap.find k h + +(** Remove given key from the local hmap *) +let[@inline] remove_in_local_hmap (k : _ Hmap.key) : unit = + update_local_hmap (Hmap.rem k) + +let[@inline] set_in_local_hmap (k : 'a Hmap.key) (v : 'a) : unit = + update_local_hmap (Hmap.add k v) + +(** [with_in_local_hmap k v f] calls [f()] in a context + where [k] is bound to [v] in the local hmap. Then it restores the + previous binding for [k]. *) +let with_in_local_hmap (k : 'a Hmap.key) (v : 'a) f : unit = + let h = get_local_hmap () in + match Hmap.find k h with + | None -> + set_in_local_hmap k v; + Fun.protect ~finally:(fun () -> remove_in_local_hmap k) f + | Some old_v -> + set_in_local_hmap k v; + Fun.protect ~finally:(fun () -> set_in_local_hmap k old_v) f + +(**/**) + +(* private functions, to be used by the rest of moonpool *) +module Private_hmap_ls_ = struct + (** Copy the hmap from f1.fls to f2.fls *) + let copy_fls (f1 : Picos.Fiber.t) (f2 : Picos.Fiber.t) : unit = + match FLS.get_exn f1 k_local_hmap with + | exception FLS.Not_set -> () + | hmap -> FLS.set f2 k_local_hmap hmap +end + +(**/**) diff --git a/src/core/moonpool.ml b/src/core/moonpool.ml index cafac26c..47e5e5e3 100644 --- a/src/core/moonpool.ml +++ b/src/core/moonpool.ml @@ -30,14 +30,16 @@ module Lock = Lock module Immediate_runner = struct end module Runner = Runner module Task_local_storage = Task_local_storage -module Thread_local_storage = Thread_local_storage_ +module Thread_local_storage = Thread_local_storage +module Trigger = Trigger module Ws_pool = Ws_pool module Private = struct module Ws_deque_ = Ws_deque_ - module Suspend_ = Suspend_ + module Worker_loop_ = Worker_loop_ module Domain_ = Domain_ module Tracing_ = Tracing_ + module Types_ = Types_ let num_domains = Domain_pool_.max_number_of_domains end diff --git a/src/core/moonpool.mli b/src/core/moonpool.mli index c0d495c9..a992e8b8 100644 --- a/src/core/moonpool.mli +++ b/src/core/moonpool.mli @@ -13,6 +13,7 @@ module Ws_pool = Ws_pool module Fifo_pool = Fifo_pool module Background_thread = Background_thread module Runner = Runner +module Trigger = Trigger module Immediate_runner : sig end [@@deprecated "use Moonpool_fib.Main"] @@ -32,19 +33,22 @@ val start_thread_on_some_domain : ('a -> unit) -> 'a -> Thread.t to run the thread. This ensures that we don't always pick the same domain to run all the various threads needed in an application (timers, event loops, etc.) *) -val run_async : ?ls:Task_local_storage.t -> Runner.t -> (unit -> unit) -> unit +val run_async : ?fiber:Picos.Fiber.t -> Runner.t -> (unit -> unit) -> unit (** [run_async runner task] schedules the task to run on the given runner. This means [task()] will be executed at some point in the future, possibly in another thread. + @param fiber optional initial (picos) fiber state @since 0.5 *) -val run_wait_block : ?ls:Task_local_storage.t -> Runner.t -> (unit -> 'a) -> 'a +val run_wait_block : ?fiber:Picos.Fiber.t -> Runner.t -> (unit -> 'a) -> 'a (** [run_wait_block runner f] schedules [f] for later execution on the runner, like {!run_async}. It then blocks the current thread until [f()] is done executing, and returns its result. If [f()] raises an exception, then [run_wait_block pool f] will raise it as well. + See {!run_async} for more details. + {b NOTE} be careful with deadlocks (see notes in {!Fut.wait_block} about the required discipline to avoid deadlocks). @raise Shutdown if the runner was already shut down @@ -78,7 +82,7 @@ module Lock = Lock module Fut = Fut module Chan = Chan module Task_local_storage = Task_local_storage -module Thread_local_storage = Thread_local_storage_ +module Thread_local_storage = Thread_local_storage (** A simple blocking queue. @@ -211,21 +215,16 @@ module Private : sig module Ws_deque_ = Ws_deque_ (** A deque for work stealing, fixed size. *) - (** {2 Suspensions} *) - - module Suspend_ = Suspend_ - [@@alert - unstable "this module is an implementation detail of moonpool for now"] - (** Suspensions. - - This is only going to work on OCaml 5.x. - - {b NOTE}: this is not stable for now. *) + module Worker_loop_ = Worker_loop_ + (** Worker loop. This is useful to implement custom runners, it + should run on each thread of the runner. + @since NEXT_RELEASE *) module Domain_ = Domain_ (** Utils for domains *) module Tracing_ = Tracing_ + module Types_ = Types_ val num_domains : unit -> int (** Number of domains in the backing domain pool *) diff --git a/src/core/runner.ml b/src/core/runner.ml index 0bf7895c..a95de289 100644 --- a/src/core/runner.ml +++ b/src/core/runner.ml @@ -1,10 +1,10 @@ open Types_ -module TLS = Thread_local_storage_ +type fiber = Picos.Fiber.t type task = unit -> unit type t = runner = { - run_async: ls:local_storage -> task -> unit; + run_async: fiber:fiber -> task -> unit; shutdown: wait:bool -> unit -> unit; size: unit -> int; num_tasks: unit -> int; @@ -12,8 +12,15 @@ type t = runner = { exception Shutdown -let[@inline] run_async ?(ls = create_local_storage ()) (self : t) f : unit = - self.run_async ~ls f +let[@inline] run_async ?fiber (self : t) f : unit = + let fiber = + match fiber with + | Some f -> f + | None -> + let comp = Picos.Computation.create () in + Picos.Fiber.create ~forbid:false comp + in + self.run_async ~fiber f let[@inline] shutdown (self : t) : unit = self.shutdown ~wait:true () @@ -23,9 +30,9 @@ let[@inline] shutdown_without_waiting (self : t) : unit = let[@inline] num_tasks (self : t) : int = self.num_tasks () let[@inline] size (self : t) : int = self.size () -let run_wait_block ?ls self (f : unit -> 'a) : 'a = +let run_wait_block ?fiber self (f : unit -> 'a) : 'a = let q = Bb_queue.create () in - run_async ?ls self (fun () -> + run_async ?fiber self (fun () -> try let x = f () in Bb_queue.push q (Ok x) @@ -48,9 +55,9 @@ let dummy : t = ~size:(fun () -> 0) ~num_tasks:(fun () -> 0) ~shutdown:(fun ~wait:_ () -> ()) - ~run_async:(fun ~ls:_ _ -> + ~run_async:(fun ~fiber:_ _ -> failwith "Runner.dummy: cannot actually run tasks") () let get_current_runner = get_current_runner -let get_current_storage = get_current_storage +let get_current_fiber = get_current_fiber diff --git a/src/core/runner.mli b/src/core/runner.mli index f0b0d099..958c8598 100644 --- a/src/core/runner.mli +++ b/src/core/runner.mli @@ -5,6 +5,7 @@ @since 0.3 *) +type fiber = Picos.Fiber.t type task = unit -> unit type t @@ -33,14 +34,14 @@ val shutdown_without_waiting : t -> unit exception Shutdown -val run_async : ?ls:Task_local_storage.t -> t -> task -> unit +val run_async : ?fiber:fiber -> t -> task -> unit (** [run_async pool f] schedules [f] for later execution on the runner in one of the threads. [f()] will run on one of the runner's worker threads/domains. - @param ls if provided, run the task with this initial local storage + @param fiber if provided, run the task with this initial fiber data @raise Shutdown if the runner was shut down before [run_async] was called. *) -val run_wait_block : ?ls:Task_local_storage.t -> t -> (unit -> 'a) -> 'a +val run_wait_block : ?fiber:fiber -> t -> (unit -> 'a) -> 'a (** [run_wait_block pool f] schedules [f] for later execution on the pool, like {!run_async}. It then blocks the current thread until [f()] is done executing, @@ -65,7 +66,7 @@ module For_runner_implementors : sig size:(unit -> int) -> num_tasks:(unit -> int) -> shutdown:(wait:bool -> unit -> unit) -> - run_async:(ls:Task_local_storage.t -> task -> unit) -> + run_async:(fiber:fiber -> task -> unit) -> unit -> t (** Create a new runner. @@ -73,7 +74,7 @@ module For_runner_implementors : sig {b NOTE}: the runner should support DLA and {!Suspend_} on OCaml 5.x, so that {!Fork_join} and other 5.x features work properly. *) - val k_cur_runner : t Thread_local_storage_.t + val k_cur_runner : t Thread_local_storage.t (** Key that should be used by each runner to store itself in TLS on every thread it controls, so that tasks running on these threads can access the runner. This is necessary for {!get_current_runner} @@ -85,6 +86,6 @@ val get_current_runner : unit -> t option happens on a thread that belongs in a runner. @since 0.5 *) -val get_current_storage : unit -> Task_local_storage.t option +val get_current_fiber : unit -> fiber option (** [get_current_storage runner] gets the local storage for the currently running task. *) diff --git a/src/core/suspend_.ml b/src/core/suspend_.ml deleted file mode 100644 index fefbaff3..00000000 --- a/src/core/suspend_.ml +++ /dev/null @@ -1,87 +0,0 @@ -type suspension = unit Exn_bt.result -> unit -type task = unit -> unit - -type suspension_handler = { - handle: - run:(task -> unit) -> - resume:(suspension -> unit Exn_bt.result -> unit) -> - suspension -> - unit; -} -[@@unboxed] - -type with_suspend_handler = - | WSH : { - on_suspend: unit -> 'state; - (** on_suspend called when [f()] suspends itself. *) - run: 'state -> task -> unit; (** run used to schedule new tasks *) - resume: 'state -> suspension -> unit Exn_bt.result -> unit; - (** resume run the suspension. Must be called exactly once. *) - } - -> with_suspend_handler - -[@@@ifge 5.0] -[@@@ocaml.alert "-unstable"] - -module A = Atomic_ - -type _ Effect.t += - | Suspend : suspension_handler -> unit Effect.t - | Yield : unit Effect.t - -let[@inline] yield () = Effect.perform Yield -let[@inline] suspend h = Effect.perform (Suspend h) - -let with_suspend (WSH { on_suspend; run; resume }) (f : unit -> unit) : unit = - let module E = Effect.Deep in - (* effect handler *) - let effc : type e. e Effect.t -> ((e, _) E.continuation -> _) option = - function - | Suspend h -> - (* TODO: discontinue [k] if current fiber (if any) is cancelled? *) - Some - (fun k -> - let state = on_suspend () in - let k' : suspension = function - | Ok () -> E.continue k () - | Error (exn, bt) -> E.discontinue_with_backtrace k exn bt - in - h.handle ~run:(run state) ~resume:(resume state) k') - | Yield -> - (* TODO: discontinue [k] if current fiber (if any) is cancelled? *) - Some - (fun k -> - let state = on_suspend () in - let k' : suspension = function - | Ok () -> E.continue k () - | Error (exn, bt) -> E.discontinue_with_backtrace k exn bt - in - resume state k' @@ Ok ()) - | _ -> None - in - - E.try_with f () { E.effc } - -(* DLA interop *) -let prepare_for_await () : Dla_.t = - (* current state *) - let st : (_ * suspension) option A.t = A.make None in - - let release () : unit = - match A.exchange st None with - | None -> () - | Some (resume, k) -> resume k @@ Ok () - and await () : unit = - suspend { handle = (fun ~run:_ ~resume k -> A.set st (Some (resume, k))) } - in - - let t = { Dla_.release; await } in - t - -[@@@ocaml.alert "+unstable"] -[@@@else_] - -let[@inline] with_suspend (WSH _) f = f () -let[@inline] prepare_for_await () = { Dla_.release = ignore; await = ignore } - -[@@@endif] diff --git a/src/core/suspend_.mli b/src/core/suspend_.mli deleted file mode 100644 index de90e2d4..00000000 --- a/src/core/suspend_.mli +++ /dev/null @@ -1,89 +0,0 @@ -(** (Private) suspending tasks using Effects. - - This module is an implementation detail of Moonpool and should - not be used outside of it, except by experts to implement {!Runner}. *) - -type suspension = unit Exn_bt.result -> unit -(** A suspended computation *) - -type task = unit -> unit - -type suspension_handler = { - handle: - run:(task -> unit) -> - resume:(suspension -> unit Exn_bt.result -> unit) -> - suspension -> - unit; -} -[@@unboxed] -(** The handler that knows what to do with the suspended computation. - - The handler is given a few things: - - - the suspended computation (which can be resumed with a result - eventually); - - a [run] function that can be used to start tasks to perform some - computation. - - a [resume] function to resume the suspended computation. This - must be called exactly once, in all situations. - - This means that a fork-join primitive, for example, can use a single call - to {!suspend} to: - - suspend the caller until the fork-join is done - - use [run] to start all the tasks. Typically [run] is called multiple times, - which is where the "fork" part comes from. Each call to [run] potentially - runs in parallel with the other calls. The calls must coordinate so - that, once they are all done, the suspended caller is resumed with the - aggregated result of the computation. - - use [resume] exactly -*) - -[@@@ifge 5.0] -[@@@ocaml.alert "-unstable"] - -type _ Effect.t += - | Suspend : suspension_handler -> unit Effect.t - (** The effect used to suspend the current thread and pass it, suspended, - to the handler. The handler will ensure that the suspension is resumed later - once some computation has been done. *) - | Yield : unit Effect.t - (** The effect used to interrupt the current computation and immediately re-schedule - it on the same runner. *) - -[@@@ocaml.alert "+unstable"] - -val yield : unit -> unit -(** Interrupt current computation, and re-schedule it at the end of the - runner's job queue. *) - -val suspend : suspension_handler -> unit -(** [suspend h] jumps back to the nearest {!with_suspend} - and calls [h.handle] with the current continuation [k] - and a task runner function. -*) - -[@@@endif] - -type with_suspend_handler = - | WSH : { - on_suspend: unit -> 'state; - (** on_suspend called when [f()] suspends itself. *) - run: 'state -> task -> unit; (** run used to schedule new tasks *) - resume: 'state -> suspension -> unit Exn_bt.result -> unit; - (** resume run the suspension. Must be called exactly once. *) - } - -> with_suspend_handler - -val with_suspend : with_suspend_handler -> (unit -> unit) -> unit -(** [with_suspend wsh f] - runs [f()] in an environment where [suspend] will work. - - If [f()] suspends with suspension handler [h], - this calls [wsh.on_suspend()] to capture the current state [st]. - Then [h.handle ~st ~run ~resume k] is called, where [k] is the suspension. - The suspension should always be passed exactly once to - [resume]. [run] should be used to start other tasks. -*) - -val prepare_for_await : unit -> Dla_.t -(** Our stub for DLA. Unstable. *) diff --git a/src/core/task_local_storage.ml b/src/core/task_local_storage.ml index a1266304..4f5e5004 100644 --- a/src/core/task_local_storage.ml +++ b/src/core/task_local_storage.ml @@ -1,81 +1,44 @@ open Types_ -module A = Atomic_ +module PF = Picos.Fiber -type 'a key = 'a ls_key +type 'a t = 'a PF.FLS.t -let key_count_ = A.make 0 +exception Not_set = PF.FLS.Not_set -type t = local_storage -type ls_value += Dummy +let create = PF.FLS.create -let dummy : t = _dummy_ls +let[@inline] get_exn k = + let fiber = get_current_fiber_exn () in + PF.FLS.get_exn fiber k -(** Resize array of TLS values *) -let[@inline never] resize_ (cur : ls_value array ref) n = - if n > Sys.max_array_length then failwith "too many task local storage keys"; - let len = Array.length !cur in - let new_ls = - Array.make (min Sys.max_array_length (max n ((len * 2) + 2))) Dummy - in - Array.blit !cur 0 new_ls 0 len; - cur := new_ls - -module Direct = struct - type nonrec t = t - - let create = create_local_storage - let[@inline] copy (self : t) = ref (Array.copy !self) - - let get (type a) (self : t) ((module K) : a key) : a = - if K.offset >= Array.length !self then resize_ self (K.offset + 1); - match !self.(K.offset) with - | K.V x -> (* common case first *) x - | Dummy -> - (* first time we access this *) - let v = K.init () in - !self.(K.offset) <- K.V v; - v - | _ -> assert false - - let set (type a) (self : t) ((module K) : a key) (v : a) : unit = - assert (self != dummy); - if K.offset >= Array.length !self then resize_ self (K.offset + 1); - !self.(K.offset) <- K.V v; - () -end - -let new_key (type t) ~init () : t key = - let offset = A.fetch_and_add key_count_ 1 in - (module struct - type nonrec t = t - type ls_value += V of t - - let offset = offset - let init = init - end : LS_KEY - with type t = t) - -let[@inline] get_cur_ () : ls_value array ref = - match get_current_storage () with - | Some r when r != dummy -> r - | _ -> failwith "Task local storage must be accessed from within a runner." - -let[@inline] get (key : 'a key) : 'a = - let cur = get_cur_ () in - Direct.get cur key - -let[@inline] get_opt key = - match get_current_storage () with +let get_opt k = + match get_current_fiber () with | None -> None - | Some cur -> Some (Direct.get cur key) + | Some fiber -> + (match PF.FLS.get_exn fiber k with + | x -> Some x + | exception Not_set -> None) -let[@inline] set key v : unit = - let cur = get_cur_ () in - Direct.set cur key v +let[@inline] get k ~default = + match get_current_fiber () with + | None -> default + | Some fiber -> PF.FLS.get fiber ~default k -let with_value key x f = - let old = get key in - set key x; - Fun.protect ~finally:(fun () -> set key old) f +let[@inline] set k v : unit = + let fiber = get_current_fiber_exn () in + PF.FLS.set fiber k v -let get_current = get_current_storage +let with_value k v (f : _ -> 'b) : 'b = + let fiber = get_current_fiber_exn () in + + match PF.FLS.get_exn fiber k with + | exception Not_set -> + PF.FLS.set fiber k v; + (* nothing to restore back to, just call [f] *) + f () + | old_v -> + PF.FLS.set fiber k v; + let finally () = PF.FLS.set fiber k old_v in + Fun.protect f ~finally + +include Hmap_ls_ diff --git a/src/core/task_local_storage.mli b/src/core/task_local_storage.mli index a1da0b0f..71c7ffe6 100644 --- a/src/core/task_local_storage.mli +++ b/src/core/task_local_storage.mli @@ -8,60 +8,39 @@ @since 0.6 *) -type t = Types_.local_storage -(** Underlying storage for a task. This is mutable and - not thread-safe. *) +type 'a t = 'a Picos.Fiber.FLS.t -val dummy : t +val create : unit -> 'a t +(** [create ()] makes a new key. Keys are expensive and + should never be allocated dynamically or in a loop. *) -type 'a key -(** A key used to access a particular (typed) storage slot on every task. *) +exception Not_set -val new_key : init:(unit -> 'a) -> unit -> 'a key -(** [new_key ~init ()] makes a new key. Keys are expensive and - should never be allocated dynamically or in a loop. - The correct pattern is, at toplevel: - - {[ - let k_foo : foo Task_ocal_storage.key = - Task_local_storage.new_key ~init:(fun () -> make_foo ()) () - - (* … *) - - (* use it: *) - let … = Task_local_storage.get k_foo - ]} -*) - -val get : 'a key -> 'a +val get_exn : 'a t -> 'a (** [get k] gets the value for the current task for key [k]. Must be run from inside a task running on a runner. - @raise Failure otherwise *) + @raise Not_set otherwise *) -val get_opt : 'a key -> 'a option +val get_opt : 'a t -> 'a option (** [get_opt k] gets the current task's value for key [k], or [None] if not run from inside the task. *) -val set : 'a key -> 'a -> unit +val get : 'a t -> default:'a -> 'a + +val set : 'a t -> 'a -> unit (** [set k v] sets the storage for [k] to [v]. Must be run from inside a task running on a runner. @raise Failure otherwise *) -val with_value : 'a key -> 'a -> (unit -> 'b) -> 'b +val with_value : 'a t -> 'a -> (unit -> 'b) -> 'b (** [with_value k v f] sets [k] to [v] for the duration of the call to [f()]. When [f()] returns (or fails), [k] is restored to its old value. *) -val get_current : unit -> t option -(** Access the current storage, or [None] if not run from - within a task. *) +(** {2 Local [Hmap.t]} -(** Direct access to values from a storage handle *) -module Direct : sig - val get : t -> 'a key -> 'a - (** Access a key *) + This requires [hmap] to be installed. *) - val set : t -> 'a key -> 'a -> unit - val create : unit -> t - val copy : t -> t +include module type of struct + include Hmap_ls_ end diff --git a/src/core/trigger.ml b/src/core/trigger.ml new file mode 100644 index 00000000..f7fda452 --- /dev/null +++ b/src/core/trigger.ml @@ -0,0 +1,6 @@ +(** Triggers from picos + @since NEXT_RELEASE *) + +include Picos.Trigger + +let[@inline] await_exn (self : t) = await self |> Option.iter Exn_bt.raise diff --git a/src/core/types_.ml b/src/core/types_.ml index 08d2f09c..97209942 100644 --- a/src/core/types_.ml +++ b/src/core/types_.ml @@ -1,35 +1,38 @@ -module TLS = Thread_local_storage_ +module TLS = Thread_local_storage module Domain_pool_ = Moonpool_dpool -type ls_value = .. - -(** Key for task local storage *) -module type LS_KEY = sig - type t - type ls_value += V of t - - val offset : int - (** Unique offset *) - - val init : unit -> t -end - -type 'a ls_key = (module LS_KEY with type t = 'a) -(** A LS key (task local storage) *) - type task = unit -> unit -type local_storage = ls_value array ref +type fiber = Picos.Fiber.t type runner = { - run_async: ls:local_storage -> task -> unit; + run_async: fiber:fiber -> task -> unit; shutdown: wait:bool -> unit -> unit; size: unit -> int; num_tasks: unit -> int; } let k_cur_runner : runner TLS.t = TLS.create () -let k_cur_storage : local_storage TLS.t = TLS.create () -let _dummy_ls : local_storage = ref [||] +let k_cur_fiber : fiber TLS.t = TLS.create () + +let _dummy_computation : Picos.Computation.packed = + let c = Picos.Computation.create () in + Picos.Computation.cancel c (Failure "dummy fiber") (Printexc.get_callstack 0); + Picos.Computation.Packed c + +let _dummy_fiber = Picos.Fiber.create_packed ~forbid:true _dummy_computation let[@inline] get_current_runner () : _ option = TLS.get_opt k_cur_runner -let[@inline] get_current_storage () : _ option = TLS.get_opt k_cur_storage -let[@inline] create_local_storage () = ref [||] + +let[@inline] get_current_fiber () : fiber option = + match TLS.get_exn k_cur_fiber with + | f when f != _dummy_fiber -> Some f + | _ -> None + | exception TLS.Not_set -> None + +let error_get_current_fiber_ = + "Moonpool: get_current_fiber was called outside of a fiber." + +let[@inline] get_current_fiber_exn () : fiber = + match TLS.get_exn k_cur_fiber with + | f when f != _dummy_fiber -> f + | _ -> failwith error_get_current_fiber_ + | exception TLS.Not_set -> failwith error_get_current_fiber_ diff --git a/src/core/worker_loop_.ml b/src/core/worker_loop_.ml new file mode 100644 index 00000000..b28ef42b --- /dev/null +++ b/src/core/worker_loop_.ml @@ -0,0 +1,153 @@ +open Types_ + +type fiber = Picos.Fiber.t + +type task_full = + | T_start of { + fiber: fiber; + f: unit -> unit; + } + | T_resume : { + fiber: fiber; + k: unit -> unit; + } + -> task_full + +type around_task = + | AT_pair : (Runner.t -> 'a) * (Runner.t -> 'a -> unit) -> around_task + +exception No_more_tasks + +type 'st ops = { + schedule: 'st -> task_full -> unit; + get_next_task: 'st -> task_full; (** @raise No_more_tasks *) + get_thread_state: unit -> 'st; + (** Access current thread's worker state from any worker *) + around_task: 'st -> around_task; + on_exn: 'st -> Exn_bt.t -> unit; + runner: 'st -> Runner.t; + before_start: 'st -> unit; + cleanup: 'st -> unit; +} + +(** A dummy task. *) +let _dummy_task : task_full = T_start { f = ignore; fiber = _dummy_fiber } + +[@@@ifge 5.0] + +let[@inline] discontinue k exn = + let bt = Printexc.get_raw_backtrace () in + Effect.Deep.discontinue_with_backtrace k exn bt + +let with_handler (type st arg) ~(ops : st ops) (self : st) : + (unit -> unit) -> unit = + let current = + Some + (fun k -> + match get_current_fiber_exn () with + | fiber -> Effect.Deep.continue k fiber + | exception exn -> discontinue k exn) + and yield = + Some + (fun k -> + let fiber = get_current_fiber_exn () in + match + let k () = Effect.Deep.continue k () in + ops.schedule self @@ T_resume { fiber; k } + with + | () -> () + | exception exn -> discontinue k exn) + and reschedule trigger fiber k : unit = + ignore (Picos.Fiber.unsuspend fiber trigger : bool); + let k () = Picos.Fiber.resume fiber k in + let task = T_resume { fiber; k } in + ops.schedule self task + in + let effc (type a) : + a Effect.t -> ((a, _) Effect.Deep.continuation -> _) option = function + | Picos.Fiber.Current -> current + | Picos.Fiber.Yield -> yield + | Picos.Fiber.Spawn r -> + Some + (fun k -> + match + let f () = r.main r.fiber in + let task = T_start { fiber = r.fiber; f } in + ops.schedule self task + with + | unit -> Effect.Deep.continue k unit + | exception exn -> discontinue k exn) + | Picos.Trigger.Await trigger -> + Some + (fun k -> + let fiber = get_current_fiber_exn () in + (* when triggers is signaled, reschedule task *) + if not (Picos.Fiber.try_suspend fiber trigger fiber k reschedule) then + (* trigger was already signaled, run task now *) + Picos.Fiber.resume fiber k) + | Picos.Computation.Cancel_after _r -> + Some + (fun k -> + (* not implemented *) + let exn = Failure "Moonpool: cancel_after is not supported." in + discontinue k exn) + | _ -> None + in + let handler = Effect.Deep.{ retc = Fun.id; exnc = raise; effc } in + fun f -> Effect.Deep.match_with f () handler + +[@@@else_] + +let with_handler ~ops:_ self f = f () + +[@@@endif] + +let worker_loop (type st) ~(ops : st ops) (self : st) : unit = + let cur_fiber : fiber ref = ref _dummy_fiber in + let runner = ops.runner self in + TLS.set Runner.For_runner_implementors.k_cur_runner runner; + + let (AT_pair (before_task, after_task)) = ops.around_task self in + + let run_task (task : task_full) : unit = + let fiber = + match task with + | T_start { fiber; _ } | T_resume { fiber; _ } -> fiber + in + + cur_fiber := fiber; + TLS.set k_cur_fiber fiber; + let _ctx = before_task runner in + + (* run the task now, catching errors, handling effects *) + assert (task != _dummy_task); + (try + match task with + | T_start { fiber = _; f } -> with_handler ~ops self f + | T_resume { fiber = _; k } -> + (* this is already in an effect handler *) + k () + with e -> + let ebt = Exn_bt.get e in + ops.on_exn self ebt); + + after_task runner _ctx; + + cur_fiber := _dummy_fiber; + TLS.set k_cur_fiber _dummy_fiber + in + + ops.before_start self; + + let continue = ref true in + try + while !continue do + match ops.get_next_task self with + | task -> run_task task + | exception No_more_tasks -> continue := false + done; + ops.cleanup self + with exn -> + let bt = Printexc.get_raw_backtrace () in + ops.cleanup self; + Printexc.raise_with_backtrace exn bt diff --git a/src/core/ws_pool.ml b/src/core/ws_pool.ml index de4b44cc..1e421fe7 100644 --- a/src/core/ws_pool.ml +++ b/src/core/ws_pool.ml @@ -1,7 +1,7 @@ open Types_ -module WSQ = Ws_deque_ module A = Atomic_ -module TLS = Thread_local_storage_ +module WSQ = Ws_deque_ +module WL = Worker_loop_ include Runner let ( let@ ) = ( @@ ) @@ -14,46 +14,39 @@ module Id = struct let equal : t -> t -> bool = ( == ) end -type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task - -type task_full = - | T_start of { - ls: Task_local_storage.t; - f: task; - } - | T_resume : { - ls: Task_local_storage.t; - k: 'a -> unit; - x: 'a; - } - -> task_full - -type worker_state = { - pool_id_: Id.t; (** Unique per pool *) - mutable thread: Thread.t; - q: task_full WSQ.t; (** Work stealing queue *) - mutable cur_ls: Task_local_storage.t option; (** Task storage *) - rng: Random.State.t; -} -(** State for a given worker. Only this worker is - allowed to push into the queue, but other workers - can come and steal from it if they're idle. *) - type state = { id_: Id.t; + (** Unique to this pool. Used to make sure tasks stay within the same pool. *) active: bool A.t; (** Becomes [false] when the pool is shutdown. *) - workers: worker_state array; (** Fixed set of workers. *) - main_q: task_full Queue.t; + mutable workers: worker_state array; (** Fixed set of workers. *) + main_q: WL.task_full Queue.t; (** Main queue for tasks coming from the outside *) mutable n_waiting: int; (* protected by mutex *) mutable n_waiting_nonzero: bool; (** [n_waiting > 0] *) mutex: Mutex.t; cond: Condition.t; + mutable as_runner: t; + (* init options *) + around_task: WL.around_task; + name: string option; + on_init_thread: dom_id:int -> t_id:int -> unit -> unit; + on_exit_thread: dom_id:int -> t_id:int -> unit -> unit; on_exn: exn -> Printexc.raw_backtrace -> unit; - around_task: around_task; } (** internal state *) +and worker_state = { + mutable thread: Thread.t; + idx: int; + dom_id: int; + st: state; + q: WL.task_full WSQ.t; (** Work stealing queue *) + rng: Random.State.t; +} +(** State for a given worker. Only this worker is + allowed to push into the queue, but other workers + can come and steal from it if they're idle. *) + let[@inline] size_ (self : state) = Array.length self.workers let num_tasks_ (self : state) : int = @@ -67,9 +60,15 @@ let num_tasks_ (self : state) : int = sub-tasks. *) let k_worker_state : worker_state TLS.t = TLS.create () -let[@inline] find_current_worker_ () : worker_state option = +let[@inline] get_current_worker_ () : worker_state option = TLS.get_opt k_worker_state +let[@inline] get_current_worker_exn () : worker_state = + match TLS.get_exn k_worker_state with + | w -> w + | exception TLS.Not_set -> + failwith "Moonpool: get_current_runner was called from outside a pool." + (** Try to wake up a waiter, if there's any. *) let[@inline] try_wake_someone_ (self : state) : unit = if self.n_waiting_nonzero then ( @@ -78,194 +77,148 @@ let[@inline] try_wake_someone_ (self : state) : unit = Mutex.unlock self.mutex ) -(** Run [task] as is, on the pool. *) -let schedule_task_ (self : state) ~w (task : task_full) : unit = - (* Printf.printf "schedule task now (%d)\n%!" (Thread.id @@ Thread.self ()); *) - match w with - | Some w when Id.equal self.id_ w.pool_id_ -> - (* we're on this same pool, schedule in the worker's state. Otherwise - we might also be on pool A but asking to schedule on pool B, - so we have to check that identifiers match. *) - let pushed = WSQ.push w.q task in - if pushed then - try_wake_someone_ self - else ( - (* overflow into main queue *) - Mutex.lock self.mutex; - Queue.push task self.main_q; - if self.n_waiting_nonzero then Condition.signal self.cond; - Mutex.unlock self.mutex +(** Push into worker's local queue, open to work stealing. + precondition: this runs on the worker thread whose state is [self] *) +let schedule_on_current_worker (self : worker_state) task : unit = + (* we're on this same pool, schedule in the worker's state. Otherwise + we might also be on pool A but asking to schedule on pool B, + so we have to check that identifiers match. *) + let pushed = WSQ.push self.q task in + if pushed then + try_wake_someone_ self.st + else ( + (* overflow into main queue *) + Mutex.lock self.st.mutex; + Queue.push task self.st.main_q; + if self.st.n_waiting_nonzero then Condition.signal self.st.cond; + Mutex.unlock self.st.mutex + ) + +(** Push into the shared queue of this pool *) +let schedule_in_main_queue (self : state) task : unit = + if A.get self.active then ( + (* push into the main queue *) + Mutex.lock self.mutex; + Queue.push task self.main_q; + if self.n_waiting_nonzero then Condition.signal self.cond; + Mutex.unlock self.mutex + ) else + (* notify the caller that scheduling tasks is no + longer permitted *) + raise Shutdown + +let schedule_from_w (self : worker_state) (task : WL.task_full) : unit = + match get_current_worker_ () with + | Some w when Id.equal self.st.id_ w.st.id_ -> + (* use worker from the same pool *) + schedule_on_current_worker w task + | _ -> schedule_in_main_queue self.st task + +exception Got_task of WL.task_full + +(** Try to steal a task. + @raise Got_task if it finds one. *) +let try_to_steal_work_once_ (self : worker_state) : unit = + let init = Random.State.int self.rng (Array.length self.st.workers) in + for i = 0 to Array.length self.st.workers - 1 do + let w' = + Array.unsafe_get self.st.workers + ((i + init) mod Array.length self.st.workers) + in + + if self != w' then ( + match WSQ.steal w'.q with + | Some t -> raise_notrace (Got_task t) + | None -> () ) - | _ -> - if A.get self.active then ( - (* push into the main queue *) - Mutex.lock self.mutex; - Queue.push task self.main_q; - if self.n_waiting_nonzero then Condition.signal self.cond; - Mutex.unlock self.mutex - ) else - (* notify the caller that scheduling tasks is no - longer permitted *) - raise Shutdown - -(** Run this task, now. Must be called from a worker. *) -let run_task_now_ (self : state) ~runner ~(w : worker_state) (task : task_full) - : unit = - (* Printf.printf "run task now (%d)\n%!" (Thread.id @@ Thread.self ()); *) - let (AT_pair (before_task, after_task)) = self.around_task in - - let ls = - match task with - | T_start { ls; _ } | T_resume { ls; _ } -> ls - in - - w.cur_ls <- Some ls; - TLS.set k_cur_storage ls; - let _ctx = before_task runner in - - let[@inline] on_suspend () : _ ref = - match find_current_worker_ () with - | Some { cur_ls = Some w; _ } -> w - | _ -> assert false - in - - let run_another_task ls (task' : task) = - let w = - match find_current_worker_ () with - | Some w when Id.equal w.pool_id_ self.id_ -> Some w - | _ -> None - in - let ls' = Task_local_storage.Direct.copy ls in - schedule_task_ self ~w @@ T_start { ls = ls'; f = task' } - in - - let resume ls k x = - let w = - match find_current_worker_ () with - | Some w when Id.equal w.pool_id_ self.id_ -> Some w - | _ -> None - in - schedule_task_ self ~w @@ T_resume { ls; k; x } - in - - (* run the task now, catching errors *) - (try - match task with - | T_start { f = task; _ } -> - (* run [task()] and handle [suspend] in it *) - Suspend_.with_suspend - (WSH { on_suspend; run = run_another_task; resume }) - task - | T_resume { k; x; _ } -> - (* this is already in an effect handler *) - k x - with e -> - let bt = Printexc.get_raw_backtrace () in - self.on_exn e bt); - - after_task runner _ctx; - w.cur_ls <- None; - TLS.set k_cur_storage _dummy_ls - -let run_async_ (self : state) ~ls (f : task) : unit = - let w = find_current_worker_ () in - schedule_task_ self ~w @@ T_start { f; ls } - -(* TODO: function to schedule many tasks from the outside. - - build a queue - - lock - - queue transfer - - wakeup all (broadcast) - - unlock *) + done (** Wait on condition. Precondition: we hold the mutex. *) -let[@inline] wait_ (self : state) : unit = +let[@inline] wait_for_condition_ (self : state) : unit = self.n_waiting <- self.n_waiting + 1; if self.n_waiting = 1 then self.n_waiting_nonzero <- true; Condition.wait self.cond self.mutex; self.n_waiting <- self.n_waiting - 1; if self.n_waiting = 0 then self.n_waiting_nonzero <- false -exception Got_task of task_full +let rec get_next_task (self : worker_state) : WL.task_full = + (* see if we can empty the local queue *) + match WSQ.pop_exn self.q with + | task -> + try_wake_someone_ self.st; + task + | exception WSQ.Empty -> try_to_steal_from_other_workers_ self -(** Try to steal a task *) -let try_to_steal_work_once_ (self : state) (w : worker_state) : task_full option - = - let init = Random.State.int w.rng (Array.length self.workers) in +and try_to_steal_from_other_workers_ (self : worker_state) = + match try_to_steal_work_once_ self with + | exception Got_task task -> task + | () -> wait_on_main_queue self - try - for i = 0 to Array.length self.workers - 1 do - let w' = - Array.unsafe_get self.workers ((i + init) mod Array.length self.workers) - in +and wait_on_main_queue (self : worker_state) : WL.task_full = + Mutex.lock self.st.mutex; + match Queue.pop self.st.main_q with + | task -> + Mutex.unlock self.st.mutex; + task + | exception Queue.Empty -> + (* wait here *) + if A.get self.st.active then ( + wait_for_condition_ self.st; - if w != w' then ( - match WSQ.steal w'.q with - | Some t -> raise_notrace (Got_task t) - | None -> () - ) - done; - None - with Got_task t -> Some t + (* see if a task became available *) + match Queue.pop self.st.main_q with + | task -> + Mutex.unlock self.st.mutex; + task + | exception Queue.Empty -> + Mutex.unlock self.st.mutex; + try_to_steal_from_other_workers_ self + ) else ( + (* do nothing more: no task in main queue, and we are shutting + down so no new task should arrive. + The exception is if another task is creating subtasks + that overflow into the main queue, but we can ignore that at + the price of slightly decreased performance for the last few + tasks *) + Mutex.unlock self.st.mutex; + raise WL.No_more_tasks + ) -(** Worker runs tasks from its queue until none remains *) -let worker_run_self_tasks_ (self : state) ~runner w : unit = - let continue = ref true in - while !continue && A.get self.active do - match WSQ.pop w.q with - | Some task -> - try_wake_someone_ self; - run_task_now_ self ~runner ~w task - | None -> continue := false - done +let before_start (self : worker_state) : unit = + let t_id = Thread.id @@ Thread.self () in + self.st.on_init_thread ~dom_id:self.dom_id ~t_id (); + TLS.set k_cur_fiber _dummy_fiber; + TLS.set Runner.For_runner_implementors.k_cur_runner self.st.as_runner; + TLS.set k_worker_state self; -(** Main loop for a worker thread. *) -let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit = - TLS.set Runner.For_runner_implementors.k_cur_runner runner; - TLS.set k_worker_state w; + (* set thread name *) + Option.iter + (fun name -> + Tracing_.set_thread_name (Printf.sprintf "%s.worker.%d" name self.idx)) + self.st.name - let rec main () : unit = - worker_run_self_tasks_ self ~runner w; - try_steal () - and run_task task : unit = - run_task_now_ self ~runner ~w task; - main () - and try_steal () = - match try_to_steal_work_once_ self w with - | Some task -> run_task task - | None -> wait () - and wait () = - Mutex.lock self.mutex; - match Queue.pop self.main_q with - | task -> - Mutex.unlock self.mutex; - run_task task - | exception Queue.Empty -> - (* wait here *) - if A.get self.active then ( - wait_ self; +let cleanup (self : worker_state) : unit = + (* on termination, decrease refcount of underlying domain *) + Domain_pool_.decr_on self.dom_id; + let t_id = Thread.id @@ Thread.self () in + self.st.on_exit_thread ~dom_id:self.dom_id ~t_id () - (* see if a task became available *) - let task = - try Some (Queue.pop self.main_q) with Queue.Empty -> None - in - Mutex.unlock self.mutex; - - match task with - | Some t -> run_task t - | None -> try_steal () - ) else - (* do nothing more: no task in main queue, and we are shutting - down so no new task should arrive. - The exception is if another task is creating subtasks - that overflow into the main queue, but we can ignore that at - the price of slightly decreased performance for the last few - tasks *) - Mutex.unlock self.mutex +let worker_ops : worker_state WL.ops = + let runner (st : worker_state) = st.st.as_runner in + let around_task st = st.st.around_task in + let on_exn (st : worker_state) (ebt : Exn_bt.t) = + st.st.on_exn (Exn_bt.exn ebt) (Exn_bt.bt ebt) in - - (* handle domain-local await *) - Dla_.using ~prepare_for_await:Suspend_.prepare_for_await ~while_running:main + { + WL.schedule = schedule_from_w; + runner; + get_next_task; + get_thread_state = get_current_worker_exn; + around_task; + on_exn; + before_start; + cleanup; + } let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = () @@ -277,6 +230,15 @@ let shutdown_ ~wait (self : state) : unit = if wait then Array.iter (fun w -> Thread.join w.thread) self.workers ) +let as_runner_ (self : state) : t = + Runner.For_runner_implementors.create + ~shutdown:(fun ~wait () -> shutdown_ self ~wait) + ~run_async:(fun ~fiber f -> + schedule_in_main_queue self @@ T_start { fiber; f }) + ~size:(fun () -> size_ self) + ~num_tasks:(fun () -> num_tasks_ self) + () + type ('a, 'b) create_args = ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) -> ?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) -> @@ -287,9 +249,6 @@ type ('a, 'b) create_args = 'a (** Arguments used in {!create}. See {!create} for explanations. *) -let dummy_task_ : task_full = - T_start { f = ignore; ls = Task_local_storage.dummy } - let create ?(on_init_thread = default_thread_init_exit_) ?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ()) ?around_task ?num_threads ?name () : t = @@ -297,8 +256,8 @@ let create ?(on_init_thread = default_thread_init_exit_) (* wrapper *) let around_task = match around_task with - | Some (f, g) -> AT_pair (f, g) - | None -> AT_pair (ignore, fun _ _ -> ()) + | Some (f, g) -> WL.AT_pair (f, g) + | None -> WL.AT_pair (ignore, fun _ _ -> ()) in let num_domains = Domain_pool_.max_number_of_domains () in @@ -307,23 +266,11 @@ let create ?(on_init_thread = default_thread_init_exit_) (* make sure we don't bias towards the first domain(s) in {!D_pool_} *) let offset = Random.int num_domains in - let workers : worker_state array = - let dummy = Thread.self () in - Array.init num_threads (fun i -> - { - pool_id_; - thread = dummy; - q = WSQ.create ~dummy:dummy_task_ (); - rng = Random.State.make [| i |]; - cur_ls = None; - }) - in - let pool = { id_ = pool_id_; active = A.make true; - workers; + workers = [||]; main_q = Queue.create (); n_waiting = 0; n_waiting_nonzero = true; @@ -331,65 +278,48 @@ let create ?(on_init_thread = default_thread_init_exit_) cond = Condition.create (); around_task; on_exn; + on_init_thread; + on_exit_thread; + name; + as_runner = Runner.dummy; } in - - let runner = - Runner.For_runner_implementors.create - ~shutdown:(fun ~wait () -> shutdown_ pool ~wait) - ~run_async:(fun ~ls f -> run_async_ pool ~ls f) - ~size:(fun () -> size_ pool) - ~num_tasks:(fun () -> num_tasks_ pool) - () - in + pool.as_runner <- as_runner_ pool; (* temporary queue used to obtain thread handles from domains on which the thread are started. *) let receive_threads = Bb_queue.create () in (* start the thread with index [i] *) - let start_thread_with_idx i = - let w = pool.workers.(i) in - let dom_idx = (offset + i) mod num_domains in + let create_worker_state idx = + let dom_id = (offset + idx) mod num_domains in + { + st = pool; + thread = (* dummy *) Thread.self (); + q = WSQ.create ~dummy:WL._dummy_task (); + rng = Random.State.make [| idx |]; + dom_id; + idx; + } + in - (* function run in the thread itself *) - let main_thread_fun () : unit = - let thread = Thread.self () in - let t_id = Thread.id thread in - on_init_thread ~dom_id:dom_idx ~t_id (); - TLS.set k_cur_storage _dummy_ls; - - (* set thread name *) - Option.iter - (fun name -> - Tracing_.set_thread_name (Printf.sprintf "%s.worker.%d" name i)) - name; - - let run () = worker_thread_ pool ~runner w in - - (* now run the main loop *) - Fun.protect run ~finally:(fun () -> - (* on termination, decrease refcount of underlying domain *) - Domain_pool_.decr_on dom_idx); - on_exit_thread ~dom_id:dom_idx ~t_id () - in + pool.workers <- Array.init num_threads create_worker_state; + (* start the thread with index [i] *) + let start_thread_with_idx idx (st : worker_state) = (* function called in domain with index [i], to create the thread and push it into [receive_threads] *) let create_thread_in_domain () = - let thread = Thread.create main_thread_fun () in + let thread = Thread.create (WL.worker_loop ~ops:worker_ops) st in (* send the thread from the domain back to us *) - Bb_queue.push receive_threads (i, thread) + Bb_queue.push receive_threads (idx, thread) in - - Domain_pool_.run_on dom_idx create_thread_in_domain + Domain_pool_.run_on st.dom_id create_thread_in_domain in - (* start all threads, placing them on the domains + (* start all worker threads, placing them on the domains according to their index and [offset] in a round-robin fashion. *) - for i = 0 to num_threads - 1 do - start_thread_with_idx i - done; + Array.iteri start_thread_with_idx pool.workers; (* receive the newly created threads back from domains *) for _j = 1 to num_threads do @@ -398,7 +328,7 @@ let create ?(on_init_thread = default_thread_init_exit_) worker_state.thread <- th done; - runner + pool.as_runner let with_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads ?name () f = diff --git a/src/fib/dune b/src/fib/dune index 17ed239a..4c6594f8 100644 --- a/src/fib/dune +++ b/src/fib/dune @@ -2,7 +2,7 @@ (name moonpool_fib) (public_name moonpool.fib) (synopsis "Fibers and structured concurrency for Moonpool") - (libraries moonpool) + (libraries moonpool picos) (enabled_if (>= %{ocaml_version} 5.0)) (flags :standard -open Moonpool_private -open Moonpool) diff --git a/src/fib/fiber.ml b/src/fib/fiber.ml index c50325b1..bedfa0e7 100644 --- a/src/fib/fiber.ml +++ b/src/fib/fiber.ml @@ -1,6 +1,9 @@ +open Moonpool.Private.Types_ module A = Atomic module FM = Handle.Map module Int_map = Map.Make (Int) +module PF = Picos.Fiber +module FLS = Picos.Fiber.FLS type 'a callback = 'a Exn_bt.result -> unit (** Callbacks that are called when a fiber is done. *) @@ -10,13 +13,16 @@ type cancel_callback = Exn_bt.t -> unit let prom_of_fut : 'a Fut.t -> 'a Fut.promise = Fut.Private_.unsafe_promise_of_fut +(* TODO: replace with picos structured at some point? *) module Private_ = struct + type pfiber = PF.t + type 'a t = { id: Handle.t; (** unique identifier for this fiber *) state: 'a state A.t; (** Current state in the lifetime of the fiber *) res: 'a Fut.t; runner: Runner.t; - ls: Task_local_storage.t; + pfiber: pfiber; (** Associated picos fiber *) } and 'a state = @@ -30,11 +36,18 @@ module Private_ = struct and children = any FM.t and any = Any : _ t -> any [@@unboxed] - (** Key to access the current fiber. *) - let k_current_fiber : any option Task_local_storage.key = - Task_local_storage.new_key ~init:(fun () -> None) () + (** Key to access the current moonpool.fiber. *) + let k_current_fiber : any FLS.t = FLS.create () - let[@inline] get_cur () : any option = Task_local_storage.get k_current_fiber + exception Not_set = FLS.Not_set + + let[@inline] get_cur_from_exn (pfiber : pfiber) : any = + FLS.get_exn pfiber k_current_fiber + + let[@inline] get_cur_exn () : any = + get_cur_from_exn @@ get_current_fiber_exn () + + let[@inline] get_cur_opt () = try Some (get_cur_exn ()) with _ -> None let[@inline] is_closed (self : _ t) = match A.get self.state with @@ -44,9 +57,9 @@ end include Private_ -let create_ ~ls ~runner () : 'a t = +let create_ ~pfiber ~runner () : 'a t = let id = Handle.generate_fresh () in - let res, _promise = Fut.make () in + let res, _ = Fut.make () in { state = A.make @@ -54,7 +67,7 @@ let create_ ~ls ~runner () : 'a t = id; res; runner; - ls; + pfiber; } let create_done_ ~res () : _ t = @@ -66,7 +79,7 @@ let create_done_ ~res () : _ t = id; res; runner = Runner.dummy; - ls = Task_local_storage.dummy; + pfiber = Moonpool.Private.Types_._dummy_fiber; } let[@inline] return x = create_done_ ~res:(Fut.return x) () @@ -175,7 +188,8 @@ let with_on_cancel (self : _ t) cb (k : unit -> 'a) : 'a = let h = add_on_cancel self cb in Fun.protect k ~finally:(fun () -> remove_on_cancel self h) -(** Successfully resolve the fiber *) +(** Successfully resolve the fiber. This might still fail if + some children failed. *) let resolve_ok_ (self : 'a t) (r : 'a) : unit = let r = A.make @@ Ok r in let promise = prom_of_fut self.res in @@ -239,15 +253,23 @@ let add_child_ ~protect (self : _ t) (child : _ t) = () done -let spawn_ ~ls ~parent ~runner (f : unit -> 'a) : 'a t = +let spawn_ ~parent ~runner (f : unit -> 'a) : 'a t = + let comp = Picos.Computation.create () in + let pfiber = PF.create ~forbid:false comp in + + (* copy local hmap from parent, if present *) + Option.iter + (fun (p : _ t) -> Fls.Private_hmap_ls_.copy_fls p.pfiber pfiber) + parent; + (match parent with | Some p when is_closed p -> failwith "spawn: nursery is closed" | _ -> ()); - let fib = create_ ~ls ~runner () in + let fib = create_ ~pfiber ~runner () in let run () = (* make sure the fiber is accessible from inside itself *) - Task_local_storage.set k_current_fiber (Some (Any fib)); + FLS.set pfiber k_current_fiber (Any fib); try let res = f () in resolve_ok_ fib res @@ -257,63 +279,54 @@ let spawn_ ~ls ~parent ~runner (f : unit -> 'a) : 'a t = resolve_as_failed_ fib ebt in - Runner.run_async ~ls runner run; + Runner.run_async ~fiber:pfiber runner run; fib -let spawn_top ~on f : _ t = - let ls = Task_local_storage.Direct.create () in - spawn_ ~ls ~runner:on ~parent:None f +let spawn_top ~on f : _ t = spawn_ ~runner:on ~parent:None f let spawn ?on ?(protect = true) f : _ t = (* spawn [f()] with a copy of our local storage *) let (Any p) = - match get_cur () with - | None -> failwith "Fiber.spawn: must be run from within another fiber." - | Some p -> p + try get_cur_exn () + with Not_set -> + failwith "Fiber.spawn: must be run from within another fiber." in - let ls = Task_local_storage.Direct.copy p.ls in + let runner = match on with | Some r -> r | None -> p.runner in - let child = spawn_ ~ls ~parent:(Some p) ~runner f in + let child = spawn_ ~parent:(Some p) ~runner f in add_child_ ~protect p child; child let[@inline] spawn_ignore ?protect f : unit = ignore (spawn ?protect f : _ t) let[@inline] self () : any = - match Task_local_storage.get k_current_fiber with - | None -> failwith "Fiber.self: must be run from inside a fiber." - | Some f -> f + match get_cur_exn () with + | exception Not_set -> failwith "Fiber.self: must be run from inside a fiber." + | f -> f let with_on_self_cancel cb (k : unit -> 'a) : 'a = let (Any self) = self () in let h = add_on_cancel self cb in Fun.protect k ~finally:(fun () -> remove_on_cancel self h) -module Suspend_ = Moonpool.Private.Suspend_ - -let check_if_cancelled_ (self : _ t) = - match A.get self.state with - | Terminating_or_done r -> - (match A.get r with - | Error ebt -> Exn_bt.raise ebt - | _ -> ()) - | _ -> () +let[@inline] check_if_cancelled_ (self : _ t) = PF.check self.pfiber let check_if_cancelled () = - match Task_local_storage.get k_current_fiber with - | None -> + match get_cur_exn () with + | exception Not_set -> failwith "Fiber.check_if_cancelled: must be run from inside a fiber." - | Some (Any self) -> check_if_cancelled_ self + | Any self -> check_if_cancelled_ self let yield () : unit = - match Task_local_storage.get k_current_fiber with - | None -> failwith "Fiber.yield: must be run from inside a fiber." - | Some (Any self) -> + match get_cur_exn () with + | exception Not_set -> + failwith "Fiber.yield: must be run from inside a fiber." + | Any self -> check_if_cancelled_ self; - Suspend_.yield (); + PF.yield (); check_if_cancelled_ self diff --git a/src/fib/fiber.mli b/src/fib/fiber.mli index d02c4e56..0da300e7 100644 --- a/src/fib/fiber.mli +++ b/src/fib/fiber.mli @@ -17,20 +17,27 @@ type cancel_callback = Exn_bt.t -> unit (** Do not rely on this, it is internal implementation details. *) module Private_ : sig type 'a state + type pfiber type 'a t = private { id: Handle.t; (** unique identifier for this fiber *) state: 'a state Atomic.t; (** Current state in the lifetime of the fiber *) res: 'a Fut.t; runner: Runner.t; - ls: Task_local_storage.t; + pfiber: pfiber; } (** Type definition, exposed so that {!any} can be unboxed. Please do not rely on that. *) type any = Any : _ t -> any [@@unboxed] - val get_cur : unit -> any option + exception Not_set + + val get_cur_exn : unit -> any + (** [get_cur_exn ()] either returns the current fiber, or + @raise Not_set if run outside a fiber. *) + + val get_cur_opt : unit -> any option end (**/**) diff --git a/src/fib/main.ml b/src/fib/main.ml index 26112015..0ec22be8 100644 --- a/src/fib/main.ml +++ b/src/fib/main.ml @@ -1,14 +1,20 @@ exception Oh_no of Exn_bt.t let main (f : Runner.t -> 'a) : 'a = - let st = Fifo_pool.Private_.create_state ~threads:[| Thread.self () |] () in - let runner = Fifo_pool.Private_.runner_of_state st in + let worker_st = + Fifo_pool.Private_.create_single_threaded_state ~thread:(Thread.self ()) + ~on_exn:(fun e bt -> raise (Oh_no (Exn_bt.make e bt))) + () + in + let runner = Fifo_pool.Private_.runner_of_state worker_st in try let fiber = Fiber.spawn_top ~on:runner (fun () -> f runner) in Fiber.on_result fiber (fun _ -> Runner.shutdown_without_waiting runner); + (* run the main thread *) - Fifo_pool.Private_.run_thread st runner ~on_exn:(fun e bt -> - raise (Oh_no (Exn_bt.make e bt))); + Moonpool.Private.Worker_loop_.worker_loop worker_st + ~ops:Fifo_pool.Private_.worker_ops; + match Fiber.peek fiber with | Some (Ok x) -> x | Some (Error ebt) -> Exn_bt.raise ebt diff --git a/src/forkjoin/dune b/src/forkjoin/dune index b17a163d..84849c9b 100644 --- a/src/forkjoin/dune +++ b/src/forkjoin/dune @@ -6,4 +6,4 @@ (optional) (enabled_if (>= %{ocaml_version} 5.0)) - (libraries moonpool moonpool.private)) + (libraries moonpool moonpool.private picos)) diff --git a/src/forkjoin/moonpool_forkjoin.ml b/src/forkjoin/moonpool_forkjoin.ml index 052ca7f2..2619c4ab 100644 --- a/src/forkjoin/moonpool_forkjoin.ml +++ b/src/forkjoin/moonpool_forkjoin.ml @@ -1,5 +1,4 @@ module A = Moonpool.Atomic -module Suspend_ = Moonpool.Private.Suspend_ module Domain_ = Moonpool_private.Domain_ module State_ = struct @@ -9,7 +8,7 @@ module State_ = struct type ('a, 'b) t = | Init | Left_solved of 'a or_error - | Right_solved of 'b or_error * Suspend_.suspension + | Right_solved of 'b or_error * Trigger.t | Both_solved of 'a or_error * 'b or_error let get_exn_ (self : _ t A.t) = @@ -28,13 +27,13 @@ module State_ = struct Domain_.relax (); set_left_ self left ) - | Right_solved (right, cont) -> + | Right_solved (right, tr) -> let new_st = Both_solved (left, right) in if not (A.compare_and_set self old_st new_st) then ( Domain_.relax (); set_left_ self left ) else - cont (Ok ()) + Trigger.signal tr | Left_solved _ | Both_solved _ -> assert false let rec set_right_ (self : _ t A.t) (right : _ or_error) : unit = @@ -45,27 +44,27 @@ module State_ = struct if not (A.compare_and_set self old_st new_st) then set_right_ self right | Init -> (* we are first arrived, we suspend until the left computation is done *) - Suspend_.suspend - { - Suspend_.handle = - (fun ~run:_ ~resume suspension -> - while - let old_st = A.get self in - match old_st with - | Init -> - not - (A.compare_and_set self old_st - (Right_solved (right, suspension))) - | Left_solved left -> - (* other thread is done, no risk of race condition *) - A.set self (Both_solved (left, right)); - resume suspension (Ok ()); - false - | Right_solved _ | Both_solved _ -> assert false - do - () - done); - } + let trigger = Trigger.create () in + let must_await = ref true in + + while + let old_st = A.get self in + match old_st with + | Init -> + (* setup trigger so that left computation will wake us up *) + not (A.compare_and_set self old_st (Right_solved (right, trigger))) + | Left_solved left -> + (* other thread is done, no risk of race condition *) + A.set self (Both_solved (left, right)); + must_await := false; + false + | Right_solved _ | Both_solved _ -> assert false + do + () + done; + + (* wait for the other computation to be done *) + if !must_await then Trigger.await trigger |> Option.iter Exn_bt.raise | Right_solved _ | Both_solved _ -> assert false end @@ -102,7 +101,12 @@ let both_ignore f g = ignore (both f g : _ * _) let for_ ?chunk_size n (f : int -> int -> unit) : unit = if n > 0 then ( - let has_failed = A.make false in + let runner = + match Runner.get_current_runner () with + | None -> failwith "forkjoin.for_: must be run inside a moonpool runner." + | Some r -> r + in + let failure = A.make None in let missing = A.make n in let chunk_size = @@ -113,40 +117,36 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit = max 1 (1 + (n / Moonpool.Private.num_domains ())) in - let start_tasks ~run ~resume (suspension : Suspend_.suspension) = - let task_for ~offset ~len_range = - match f offset (offset + len_range - 1) with - | () -> - if A.fetch_and_add missing (-len_range) = len_range then - (* all tasks done successfully *) - resume suspension (Ok ()) - | exception exn -> - let bt = Printexc.get_raw_backtrace () in - if not (A.exchange has_failed true) then - (* first one to fail, and [missing] must be >= 2 - because we're not decreasing it. *) - resume suspension (Error (exn, bt)) - in + let trigger = Trigger.create () in - let i = ref 0 in - while !i < n do - let offset = !i in - - let len_range = min chunk_size (n - offset) in - assert (offset + len_range <= n); - - run (fun () -> task_for ~offset ~len_range); - i := !i + len_range - done + let task_for ~offset ~len_range = + match f offset (offset + len_range - 1) with + | () -> + if A.fetch_and_add missing (-len_range) = len_range then + (* all tasks done successfully *) + Trigger.signal trigger + | exception exn -> + let bt = Printexc.get_raw_backtrace () in + if Option.is_none (A.exchange failure (Some (Exn_bt.make exn bt))) then + (* first one to fail, and [missing] must be >= 2 + because we're not decreasing it. *) + Trigger.signal trigger in - Suspend_.suspend - { - Suspend_.handle = - (fun ~run ~resume suspension -> - (* run tasks, then we'll resume [suspension] *) - start_tasks ~run ~resume suspension); - } + let i = ref 0 in + while !i < n do + let offset = !i in + + let len_range = min chunk_size (n - offset) in + assert (offset + len_range <= n); + + Runner.run_async runner (fun () -> task_for ~offset ~len_range); + i := !i + len_range + done; + + Trigger.await trigger |> Option.iter Exn_bt.raise; + Option.iter Exn_bt.raise @@ A.get failure; + () ) let all_array ?chunk_size (fs : _ array) : _ array = diff --git a/src/io/dune b/src/io/dune new file mode 100644 index 00000000..8ba05ca5 --- /dev/null +++ b/src/io/dune @@ -0,0 +1,7 @@ +(library + (name moonpool_io) + (public_name moonpool-io) + (synopsis "Async IO for moonpool, using Picos") + (enabled_if + (>= %{ocaml_version} 5.0)) + (libraries moonpool picos_io picos_io.select picos_io.fd)) diff --git a/src/io/moonpool_io.ml b/src/io/moonpool_io.ml new file mode 100644 index 00000000..9b249347 --- /dev/null +++ b/src/io/moonpool_io.ml @@ -0,0 +1,11 @@ +module Fd = Picos_io_fd +module Unix = Picos_io.Unix +module Select = Picos_io_select + +let fd_of_unix_fd : Unix.file_descr -> Fd.t = Fun.id +let configure = Select.configure + +(** {2 Async read/write} *) + +let read = Unix.read +let write = Unix.write diff --git a/src/lwt/IO.ml b/src/lwt/IO.ml index 4a8acc69..6ae09506 100644 --- a/src/lwt/IO.ml +++ b/src/lwt/IO.ml @@ -1,17 +1,14 @@ open Base let await_readable fd : unit = - Moonpool.Private.Suspend_.suspend - { - handle = - (fun ~run:_ ~resume sus -> - Perform_action_in_lwt.schedule - @@ Action.Wait_readable - ( fd, - fun cancel -> - resume sus @@ Ok (); - Lwt_engine.stop_event cancel )); - } + let trigger = Trigger.create () in + Perform_action_in_lwt.schedule + @@ Action.Wait_readable + ( fd, + fun cancel -> + Trigger.signal trigger; + Lwt_engine.stop_event cancel ); + Trigger.await_exn trigger let rec read fd buf i len : int = if len = 0 then @@ -25,17 +22,14 @@ let rec read fd buf i len : int = ) let await_writable fd = - Moonpool.Private.Suspend_.suspend - { - handle = - (fun ~run:_ ~resume sus -> - Perform_action_in_lwt.schedule - @@ Action.Wait_writable - ( fd, - fun cancel -> - resume sus @@ Ok (); - Lwt_engine.stop_event cancel )); - } + let trigger = Trigger.create () in + Perform_action_in_lwt.schedule + @@ Action.Wait_writable + ( fd, + fun cancel -> + Trigger.signal trigger; + Lwt_engine.stop_event cancel ); + Trigger.await_exn trigger let rec write_once fd buf i len : int = if len = 0 then @@ -59,16 +53,14 @@ let write fd buf i len : unit = (** Sleep for the given amount of seconds *) let sleep_s (f : float) : unit = - if f > 0. then - Moonpool.Private.Suspend_.suspend - { - handle = - (fun ~run:_ ~resume sus -> - Perform_action_in_lwt.schedule - @@ Action.Sleep - ( f, - false, - fun cancel -> - resume sus @@ Ok (); - Lwt_engine.stop_event cancel )); - } + if f > 0. then ( + let trigger = Trigger.create () in + Perform_action_in_lwt.schedule + @@ Action.Sleep + ( f, + false, + fun cancel -> + Trigger.signal trigger; + Lwt_engine.stop_event cancel ); + Trigger.await_exn trigger + ) diff --git a/src/lwt/base.ml b/src/lwt/base.ml index 88e7ed3d..e859f06e 100644 --- a/src/lwt/base.ml +++ b/src/lwt/base.ml @@ -1,4 +1,5 @@ open Common_ +module Trigger = M.Trigger module Fiber = Moonpool_fib.Fiber module FLS = Moonpool_fib.Fls @@ -14,7 +15,7 @@ module Action = struct | Sleep of float * bool * cb (* TODO: provide actions with cancellation, alongside a "select" operation *) (* | Cancel of event *) - | On_termination : 'a Lwt.t * ('a Exn_bt.result -> unit) -> t + | On_termination : 'a Lwt.t * 'a Exn_bt.result ref * Trigger.t -> t | Wakeup : 'a Lwt.u * 'a -> t | Wakeup_exn : _ Lwt.u * exn -> t | Other of (unit -> unit) @@ -26,10 +27,14 @@ module Action = struct | Wait_writable (fd, cb) -> ignore (Lwt_engine.on_writable fd cb : event) | Sleep (f, repeat, cb) -> ignore (Lwt_engine.on_timer f repeat cb : event) (* | Cancel ev -> Lwt_engine.stop_event ev *) - | On_termination (fut, f) -> + | On_termination (fut, res, trigger) -> Lwt.on_any fut - (fun x -> f @@ Ok x) - (fun exn -> f @@ Error (Exn_bt.get_callstack 10 exn)) + (fun x -> + res := Ok x; + Trigger.signal trigger) + (fun exn -> + res := Error (Exn_bt.get_callstack 10 exn); + Trigger.signal trigger) | Wakeup (prom, x) -> Lwt.wakeup prom x | Wakeup_exn (prom, e) -> Lwt.wakeup_exn prom e | Other f -> f () @@ -90,7 +95,8 @@ let lwt_of_fut (fut : 'a M.Fut.t) : 'a Lwt.t = let lwt_fut, lwt_prom = Lwt.wait () in M.Fut.on_result fut (function | Ok x -> Perform_action_in_lwt.schedule @@ Action.Wakeup (lwt_prom, x) - | Error (exn, _) -> + | Error ebt -> + let exn = Exn_bt.exn ebt in Perform_action_in_lwt.schedule @@ Action.Wakeup_exn (lwt_prom, exn)); lwt_fut @@ -101,26 +107,24 @@ let fut_of_lwt (lwt_fut : _ Lwt.t) : _ M.Fut.t = let fut, prom = M.Fut.make () in Lwt.on_any lwt_fut (fun x -> M.Fut.fulfill prom (Ok x)) - (fun e -> M.Fut.fulfill prom (Error (e, Printexc.get_callstack 10))); + (fun exn -> + let bt = Printexc.get_callstack 10 in + M.Fut.fulfill prom (Error (Exn_bt.make exn bt))); fut +let _dummy_exn_bt : Exn_bt.t = + Exn_bt.get_callstack 0 (Failure "dummy Exn_bt from moonpool-lwt") + let await_lwt (fut : _ Lwt.t) = match Lwt.poll fut with | Some x -> x | None -> (* suspend fiber, wake it up when [fut] resolves *) - M.Private.Suspend_.suspend - { - handle = - (fun ~run:_ ~resume sus -> - let on_lwt_done _ = resume sus @@ Ok () in - Perform_action_in_lwt.( - schedule Action.(On_termination (fut, on_lwt_done)))); - }; - - (match Lwt.poll fut with - | Some x -> x - | None -> assert false) + let trigger = M.Trigger.create () in + let res = ref (Error _dummy_exn_bt) in + Perform_action_in_lwt.(schedule Action.(On_termination (fut, res, trigger))); + Trigger.await trigger |> Option.iter Exn_bt.raise; + Exn_bt.unwrap !res let run_in_lwt f : _ M.Fut.t = let fut, prom = M.Fut.make () in diff --git a/src/lwt/dune b/src/lwt/dune index 9038747a..b03d03d6 100644 --- a/src/lwt/dune +++ b/src/lwt/dune @@ -4,4 +4,9 @@ (private_modules common_) (enabled_if (>= %{ocaml_version} 5.0)) - (libraries moonpool moonpool.fib lwt lwt.unix)) + (libraries + (re_export moonpool) + (re_export moonpool.fib) + picos + (re_export lwt) + lwt.unix)) diff --git a/src/private/dla_.dummy.ml b/src/private/dla_.dummy.ml deleted file mode 100644 index 3991ff1a..00000000 --- a/src/private/dla_.dummy.ml +++ /dev/null @@ -1,13 +0,0 @@ -(** Interface to Domain-local-await. - - This is used to handle the presence or absence of DLA. *) - -type t = { - release: unit -> unit; - await: unit -> unit; -} - -let using : prepare_for_await:(unit -> t) -> while_running:(unit -> 'a) -> 'a = - fun ~prepare_for_await:_ ~while_running -> while_running () - -let setup_domain () = () diff --git a/src/private/dla_.real.ml b/src/private/dla_.real.ml deleted file mode 100644 index 5f99d714..00000000 --- a/src/private/dla_.real.ml +++ /dev/null @@ -1,9 +0,0 @@ -type t = Domain_local_await.t = { - release: unit -> unit; - await: unit -> unit; -} - -let using : prepare_for_await:(unit -> t) -> while_running:(unit -> 'a) -> 'a = - Domain_local_await.using - -let setup_domain () = Domain_local_await.per_thread (module Thread) diff --git a/src/private/dune b/src/private/dune index 37b5a925..4555122c 100644 --- a/src/private/dune +++ b/src/private/dune @@ -8,16 +8,6 @@ (libraries threads either - (select - thread_local_storage_.ml - from - (thread-local-storage -> thread_local_storage_.stub.ml) - (-> thread_local_storage_.real.ml)) - (select - dla_.ml - from - (domain-local-await -> dla_.real.ml) - (-> dla_.dummy.ml)) (select tracing_.ml from diff --git a/src/private/thread_local_storage_.mli b/src/private/thread_local_storage_.mli deleted file mode 100644 index 2769f4cd..00000000 --- a/src/private/thread_local_storage_.mli +++ /dev/null @@ -1,15 +0,0 @@ -(** Thread local storage *) - -type 'a t -(** A TLS slot for values of type ['a]. This allows the storage of a - single value of type ['a] per thread. *) - -exception Not_set - -val create : unit -> 'a t - -val get_exn : 'a t -> 'a -(** @raise Not_set if not present *) - -val get_opt : 'a t -> 'a option -val set : 'a t -> 'a -> unit diff --git a/src/private/thread_local_storage_.real.ml b/src/private/thread_local_storage_.real.ml deleted file mode 100644 index 14f14ffb..00000000 --- a/src/private/thread_local_storage_.real.ml +++ /dev/null @@ -1,122 +0,0 @@ -(* vendored from https://github.com/c-cube/thread-local-storage *) - -module Atomic = Atomic_ - -(* sanity check *) -let () = assert (Obj.field (Obj.repr (Thread.self ())) 1 = Obj.repr ()) - -type 'a t = int -(** Unique index for this TLS slot. *) - -let tls_length index = - let ceil_pow_2_minus_1 (n : int) : int = - let n = n lor (n lsr 1) in - let n = n lor (n lsr 2) in - let n = n lor (n lsr 4) in - let n = n lor (n lsr 8) in - let n = n lor (n lsr 16) in - if Sys.int_size > 32 then - n lor (n lsr 32) - else - n - in - let size = ceil_pow_2_minus_1 (index + 1) in - assert (size > index); - size - -(** Counter used to allocate new keys *) -let counter = Atomic.make 0 - -(** Value used to detect a TLS slot that was not initialized yet. - Because [counter] is private and lives forever, no other - object the user can see will have the same address. *) -let sentinel_value_for_uninit_tls : Obj.t = Obj.repr counter - -external max_wosize : unit -> int = "caml_sys_const_max_wosize" - -let max_word_size = max_wosize () - -let create () : _ t = - let index = Atomic.fetch_and_add counter 1 in - if tls_length index <= max_word_size then - index - else ( - (* Some platforms have a small max word size. *) - ignore (Atomic.fetch_and_add counter (-1)); - failwith "Thread_local_storage.create: out of TLS slots" - ) - -type thread_internal_state = { - _id: int; (** Thread ID (here for padding reasons) *) - mutable tls: Obj.t; (** Our data, stowed away in this unused field *) - _other: Obj.t; - (** Here to avoid lying to ocamlopt/flambda about the size of [Thread.t] *) -} -(** A partial representation of the internal type [Thread.t], allowing - us to access the second field (unused after the thread - has started) and stash TLS data in it. *) - -let[@inline] get_raw index : Obj.t = - let thread : thread_internal_state = Obj.magic (Thread.self ()) in - let tls = thread.tls in - if Obj.is_block tls && index < Array.length (Obj.obj tls : Obj.t array) then - Array.unsafe_get (Obj.obj tls : Obj.t array) index - else - sentinel_value_for_uninit_tls - -exception Not_set - -let[@inline] get_exn slot = - let v = get_raw slot in - if v != sentinel_value_for_uninit_tls then - Obj.obj v - else - raise_notrace Not_set - -let[@inline] get_opt slot = - let v = get_raw slot in - if v != sentinel_value_for_uninit_tls then - Some (Obj.obj v) - else - None - -(** Allocating and setting *) - -(** Grow the array so that [index] is valid. *) -let grow (old : Obj.t array) (index : int) : Obj.t array = - let new_length = tls_length index in - let new_ = Array.make new_length sentinel_value_for_uninit_tls in - Array.blit old 0 new_ 0 (Array.length old); - new_ - -let get_tls_with_capacity index : Obj.t array = - let thread : thread_internal_state = Obj.magic (Thread.self ()) in - let tls = thread.tls in - if Obj.is_int tls then ( - let new_tls = grow [||] index in - thread.tls <- Obj.repr new_tls; - new_tls - ) else ( - let tls = (Obj.obj tls : Obj.t array) in - if index < Array.length tls then - tls - else ( - let new_tls = grow tls index in - thread.tls <- Obj.repr new_tls; - new_tls - ) - ) - -let[@inline] set slot value : unit = - let tls = get_tls_with_capacity slot in - Array.unsafe_set tls slot (Obj.repr (Sys.opaque_identity value)) - -let[@inline] get_default ~default slot = - let v = get_raw slot in - if v != sentinel_value_for_uninit_tls then - Obj.obj v - else ( - let v = default () in - set slot v; - v - ) diff --git a/src/private/thread_local_storage_.stub.ml b/src/private/thread_local_storage_.stub.ml deleted file mode 100644 index 82d3ff6d..00000000 --- a/src/private/thread_local_storage_.stub.ml +++ /dev/null @@ -1,2 +0,0 @@ -(* just defer to library *) -include Thread_local_storage diff --git a/src/private/ws_deque_.ml b/src/private/ws_deque_.ml index 6c5d1419..368cc8b0 100644 --- a/src/private/ws_deque_.ml +++ b/src/private/ws_deque_.ml @@ -72,7 +72,9 @@ let push (self : 'a t) (x : 'a) : bool = true with Full -> false -let pop (self : 'a t) : 'a option = +exception Empty + +let pop_exn (self : 'a t) : 'a = let b = A.get self.bottom in let b = b - 1 in A.set self.bottom b; @@ -84,11 +86,11 @@ let pop (self : 'a t) : 'a option = if size < 0 then ( (* reset to basic empty state *) A.set self.bottom t; - None + raise_notrace Empty ) else if size > 0 then ( (* can pop without modifying [top] *) let x = CA.get self.arr b in - Some x + x ) else ( assert (size = 0); (* there was exactly one slot, so we might be racing against stealers @@ -96,13 +98,18 @@ let pop (self : 'a t) : 'a option = if A.compare_and_set self.top t (t + 1) then ( let x = CA.get self.arr b in A.set self.bottom (t + 1); - Some x + x ) else ( A.set self.bottom (t + 1); - None + raise_notrace Empty ) ) +let[@inline] pop self : _ option = + match pop_exn self with + | exception Empty -> None + | t -> Some t + let steal (self : 'a t) : 'a option = (* read [top], but do not update [top_cached] as we're in another thread *) diff --git a/src/private/ws_deque_.mli b/src/private/ws_deque_.mli index b696224e..0b9fd84a 100644 --- a/src/private/ws_deque_.mli +++ b/src/private/ws_deque_.mli @@ -21,6 +21,10 @@ val pop : 'a t -> 'a option (** Pop value from the bottom of deque. This must be called only by the owner thread. *) +exception Empty + +val pop_exn : 'a t -> 'a + val steal : 'a t -> 'a option (** Try to steal from the top of deque. This is thread-safe. *) diff --git a/src/sync/dune b/src/sync/dune new file mode 100644 index 00000000..365d310b --- /dev/null +++ b/src/sync/dune @@ -0,0 +1,5 @@ +(library + (name moonpool_sync) + (public_name moonpool.sync) + (synopsis "Cooperative synchronization primitives for Moonpool") + (libraries moonpool picos picos_std.sync picos_std.event)) diff --git a/src/sync/event.ml b/src/sync/event.ml new file mode 100644 index 00000000..90446648 --- /dev/null +++ b/src/sync/event.ml @@ -0,0 +1,11 @@ +include Picos_std_event.Event + +let[@inline] of_fut (fut : _ Moonpool.Fut.t) : _ t = + from_computation (Moonpool.Fut.Private_.as_computation fut) + +module Infix = struct + let[@inline] ( let+ ) x f = map f x + let ( >|= ) = ( let+ ) +end + +include Infix diff --git a/src/sync/event.mli b/src/sync/event.mli new file mode 100644 index 00000000..309edbc7 --- /dev/null +++ b/src/sync/event.mli @@ -0,0 +1,12 @@ +include module type of struct + include Picos_std_event.Event +end + +val of_fut : 'a Moonpool.Fut.t -> 'a t + +module Infix : sig + val ( >|= ) : 'a t -> ('a -> 'b) -> 'b t + val ( let+ ) : 'a t -> ('a -> 'b) -> 'b t +end + +include module type of Infix diff --git a/src/sync/lock.ml b/src/sync/lock.ml new file mode 100644 index 00000000..fb70e3ac --- /dev/null +++ b/src/sync/lock.ml @@ -0,0 +1,38 @@ +module Mutex = Picos_std_sync.Mutex + +type 'a t = { + mutex: Mutex.t; + mutable content: 'a; +} + +let create content : _ t = { mutex = Mutex.create (); content } + +let with_ (self : _ t) f = + Mutex.lock self.mutex; + try + let x = f self.content in + Mutex.unlock self.mutex; + x + with e -> + Mutex.unlock self.mutex; + raise e + +let[@inline] mutex self = self.mutex +let[@inline] update self f = with_ self (fun x -> self.content <- f x) + +let[@inline] update_map l f = + with_ l (fun x -> + let x', y = f x in + l.content <- x'; + y) + +let get l = + Mutex.lock l.mutex; + let x = l.content in + Mutex.unlock l.mutex; + x + +let set l x = + Mutex.lock l.mutex; + l.content <- x; + Mutex.unlock l.mutex diff --git a/src/sync/lock.mli b/src/sync/lock.mli new file mode 100644 index 00000000..51754a39 --- /dev/null +++ b/src/sync/lock.mli @@ -0,0 +1,56 @@ +(** Mutex-protected resource. + + This lock is a synchronous concurrency primitive, as a thin wrapper + around {!Mutex} that encourages proper management of the critical + section in RAII style: + + {[ + let (let@) = (@@) + + + … + let compute_foo = + (* enter critical section *) + let@ x = Lock.with_ protected_resource in + use_x; + return_foo () + (* exit critical section *) + in + … + ]} + + This lock is based on {!Picos_sync.Mutex} so it is [await]-safe. + + @since NEXT_RELEASE *) + +type 'a t +(** A value protected by a cooperative mutex *) + +val create : 'a -> 'a t +(** Create a new protected value. *) + +val with_ : 'a t -> ('a -> 'b) -> 'b +(** [with_ l f] runs [f x] where [x] is the value protected with + the lock [l], in a critical section. If [f x] fails, [with_lock l f] + fails too but the lock is released. *) + +val update : 'a t -> ('a -> 'a) -> unit +(** [update l f] replaces the content [x] of [l] with [f x], while protected + by the mutex. *) + +val update_map : 'a t -> ('a -> 'a * 'b) -> 'b +(** [update_map l f] computes [x', y = f (get l)], then puts [x'] in [l] + and returns [y], while protected by the mutex. *) + +val mutex : _ t -> Picos_std_sync.Mutex.t +(** Underlying mutex. *) + +val get : 'a t -> 'a +(** Atomically get the value in the lock. The value that is returned + isn't protected! *) + +val set : 'a t -> 'a -> unit +(** Atomically set the value. + + {b NOTE} caution: using {!get} and {!set} as if this were a {!ref} + is an anti pattern and will not protect data against some race conditions. *) diff --git a/src/sync/moonpool_sync.ml b/src/sync/moonpool_sync.ml new file mode 100644 index 00000000..99065305 --- /dev/null +++ b/src/sync/moonpool_sync.ml @@ -0,0 +1,9 @@ +module Mutex = Picos_std_sync.Mutex +module Condition = Picos_std_sync.Condition +module Lock = Lock +module Event = Event +module Semaphore = Picos_std_sync.Semaphore +module Lazy = Picos_std_sync.Lazy +module Latch = Picos_std_sync.Latch +module Ivar = Picos_std_sync.Ivar +module Stream = Picos_std_sync.Stream diff --git a/test/fiber/dune b/test/fiber/dune index 1b6521ad..42845ff5 100644 --- a/test/fiber/dune +++ b/test/fiber/dune @@ -1,5 +1,5 @@ (tests - (names t_fib1 t_fls t_main) + (names t_fls t_main t_fib1) (enabled_if (>= %{ocaml_version} 5.0)) (package moonpool) diff --git a/test/fiber/t_fib1.ml b/test/fiber/t_fib1.ml index 9235fcf5..7ceedbf4 100644 --- a/test/fiber/t_fib1.ml +++ b/test/fiber/t_fib1.ml @@ -3,7 +3,7 @@ module A = Atomic module F = Moonpool_fib.Fiber let ( let@ ) = ( @@ ) -let runner = Ws_pool.create ~num_threads:8 () +let runner = Fifo_pool.create ~num_threads:1 () module TS = struct type t = int list @@ -80,10 +80,10 @@ let () = let clock = ref (0 :: i :: clock0) in logf !clock "await fiber %d" i; logf (TS.tick_get clock) "cur fiber[%d] is some: %b" i - (Option.is_some @@ F.Private_.get_cur ()); + (Option.is_some @@ F.Private_.get_cur_opt ()); let res = F.await f in logf (TS.tick_get clock) "cur fiber[%d] is some: %b" i - (Option.is_some @@ F.Private_.get_cur ()); + (Option.is_some @@ F.Private_.get_cur_opt ()); F.yield (); logf (TS.tick_get clock) "res %d = %d" i res) subs); diff --git a/test/fiber/t_fls.ml b/test/fiber/t_fls.ml index 01ee96ef..ca397ed0 100644 --- a/test/fiber/t_fls.ml +++ b/test/fiber/t_fls.ml @@ -7,7 +7,7 @@ module FLS = Moonpool_fib.Fls type span_id = int -let k_parent : span_id option FLS.key = FLS.new_key ~init:(fun () -> None) () +let k_parent : span_id Hmap.key = Hmap.Key.create () let ( let@ ) = ( @@ ) let spf = Printf.sprintf @@ -39,10 +39,10 @@ module Tracer = struct let with_span self name f = let id = Span.new_id_ () in - let parent = FLS.get k_parent in + let parent = FLS.get_in_local_hmap_opt k_parent in let span = { Span.id; parent; msg = name } in add self span; - FLS.with_value k_parent (Some id) f + FLS.with_in_local_hmap k_parent id f end module Render = struct