diff --git a/src/core/exn_bt.ml b/src/core/exn_bt.ml index e3d4c520..170db278 100644 --- a/src/core/exn_bt.ml +++ b/src/core/exn_bt.ml @@ -7,3 +7,7 @@ let show self = Printexc.to_string (exn self) let pp out self = Format.pp_print_string out (show self) type nonrec 'a result = ('a, t) result + +let[@inline] unwrap = function + | Ok x -> x + | Error ebt -> raise ebt diff --git a/src/core/exn_bt.mli b/src/core/exn_bt.mli index 37d075a8..55ff8177 100644 --- a/src/core/exn_bt.mli +++ b/src/core/exn_bt.mli @@ -21,3 +21,7 @@ val show : t -> string val pp : Format.formatter -> t -> unit type nonrec 'a result = ('a, t) result + +val unwrap : 'a result -> 'a +(** [unwrap (Ok x)] is [x], [unwrap (Error ebt)] re-raises [ebt]. + @since NEXT_RELEASE *) diff --git a/src/core/fifo_pool.ml b/src/core/fifo_pool.ml index 1fe4b708..ab2ed5c1 100644 --- a/src/core/fifo_pool.ml +++ b/src/core/fifo_pool.ml @@ -11,7 +11,7 @@ type state = { threads: Thread.t array; q: task_full Bb_queue.t; (** Queue for tasks. *) around_task: WL.around_task; - as_runner: t lazy_t; + mutable as_runner: t; (* init options *) name: string option; on_init_thread: dom_id:int -> t_id:int -> unit -> unit; @@ -24,7 +24,6 @@ type worker_state = { idx: int; dom_idx: int; st: state; - mutable current: fiber; } let[@inline] size_ (self : state) = Array.length self.threads @@ -95,7 +94,7 @@ let cleanup (self : worker_state) : unit = self.st.on_exit_thread ~dom_id:self.dom_idx ~t_id () let worker_ops : worker_state WL.ops = - let runner (st : worker_state) = Lazy.force st.st.as_runner in + let runner (st : worker_state) = st.st.as_runner in let around_task st = st.st.around_task in let on_exn (st : worker_state) (ebt : Exn_bt.t) = st.st.on_exn ebt.exn ebt.bt @@ -111,9 +110,9 @@ let worker_ops : worker_state WL.ops = cleanup; } -let create ?(on_init_thread = default_thread_init_exit_) +let create_ ?(on_init_thread = default_thread_init_exit_) ?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ()) - ?around_task ?num_threads ?name () : t = + ?around_task ~threads ?name () : state = (* wrapper *) let around_task = match around_task with @@ -121,6 +120,23 @@ let create ?(on_init_thread = default_thread_init_exit_) | None -> default_around_task_ in + let self = + { + threads; + q = Bb_queue.create (); + around_task; + as_runner = Runner.dummy; + name; + on_init_thread; + on_exit_thread; + on_exn; + } + in + self.as_runner <- runner_of_state self; + self + +let create ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads + ?name () : t = let num_domains = Domain_pool_.max_number_of_domains () in (* number of threads to run *) @@ -129,20 +145,12 @@ let create ?(on_init_thread = default_thread_init_exit_) (* make sure we don't bias towards the first domain(s) in {!D_pool_} *) let offset = Random.int num_domains in - let rec pool = + let pool = let dummy_thread = Thread.self () in - { - threads = Array.make num_threads dummy_thread; - q = Bb_queue.create (); - around_task; - as_runner = lazy (runner_of_state pool); - name; - on_init_thread; - on_exit_thread; - on_exn; - } + let threads = Array.make num_threads dummy_thread in + create_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ~threads ?name + () in - let runner = runner_of_state pool in (* temporary queue used to obtain thread handles from domains @@ -156,7 +164,7 @@ let create ?(on_init_thread = default_thread_init_exit_) (* function called in domain with index [i], to create the thread and push it into [receive_threads] *) let create_thread_in_domain () = - let st = { idx = i; dom_idx; st = pool; current = _dummy_fiber } in + let st = { idx = i; dom_idx; st = pool } in let thread = Thread.create (WL.worker_loop ~ops:worker_ops) st in (* send the thread from the domain back to us *) Bb_queue.push receive_threads (i, thread) @@ -187,3 +195,14 @@ let with_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads in let@ () = Fun.protect ~finally:(fun () -> shutdown pool) in f pool + +module Private_ = struct + type nonrec worker_state = worker_state + + let worker_ops = worker_ops + let runner_of_state (self : worker_state) = worker_ops.runner self + + let create_single_threaded_state ~thread ?on_exn () : worker_state = + let st : state = create_ ?on_exn ~threads:[| thread |] () in + { idx = 0; dom_idx = 0; st } +end diff --git a/src/core/fifo_pool.mli b/src/core/fifo_pool.mli index 637586a9..11ba4ed5 100644 --- a/src/core/fifo_pool.mli +++ b/src/core/fifo_pool.mli @@ -44,3 +44,21 @@ val with_ : (unit -> (t -> 'a) -> 'a, _) create_args When [f pool] returns or fails, [pool] is shutdown and its resources are released. Most parameters are the same as in {!create}. *) + +(**/**) + +module Private_ : sig + type worker_state + + val worker_ops : worker_state Worker_loop_.ops + + val create_single_threaded_state : + thread:Thread.t -> + ?on_exn:(exn -> Printexc.raw_backtrace -> unit) -> + unit -> + worker_state + + val runner_of_state : worker_state -> Runner.t +end + +(**/**) diff --git a/src/core/task_local_storage.ml b/src/core/task_local_storage.ml index f9dd98e6..b66448af 100644 --- a/src/core/task_local_storage.ml +++ b/src/core/task_local_storage.ml @@ -1,83 +1,41 @@ open Types_ -(* -module A = Atomic_ +module PF = Picos.Fiber -type 'a key = 'a ls_key +type 'a t = 'a PF.FLS.t -let key_count_ = A.make 0 +exception Not_set = PF.FLS.Not_set -type t = local_storage -type ls_value += Dummy +let create = PF.FLS.create -let dummy : t = _dummy_ls +let[@inline] get_exn k = + let fiber = get_current_fiber_exn () in + PF.FLS.get_exn fiber k -(** Resize array of TLS values *) -let[@inline never] resize_ (cur : ls_value array ref) n = - if n > Sys.max_array_length then failwith "too many task local storage keys"; - let len = Array.length !cur in - let new_ls = - Array.make (min Sys.max_array_length (max n ((len * 2) + 2))) Dummy - in - Array.blit !cur 0 new_ls 0 len; - cur := new_ls - -module Direct = struct - type nonrec t = t - - let create = create_local_storage - let[@inline] copy (self : t) = ref (Array.copy !self) - - let get (type a) (self : t) ((module K) : a key) : a = - if K.offset >= Array.length !self then resize_ self (K.offset + 1); - match !self.(K.offset) with - | K.V x -> (* common case first *) x - | Dummy -> - (* first time we access this *) - let v = K.init () in - !self.(K.offset) <- K.V v; - v - | _ -> assert false - - let set (type a) (self : t) ((module K) : a key) (v : a) : unit = - assert (self != dummy); - if K.offset >= Array.length !self then resize_ self (K.offset + 1); - !self.(K.offset) <- K.V v; - () -end - -let new_key (type t) ~init () : t key = - let offset = A.fetch_and_add key_count_ 1 in - (module struct - type nonrec t = t - type ls_value += V of t - - let offset = offset - let init = init - end : LS_KEY - with type t = t) - -let[@inline] get_cur_ () : ls_value array ref = - match get_current_storage () with - | Some r when r != dummy -> r - | _ -> failwith "Task local storage must be accessed from within a runner." - -let[@inline] get (key : 'a key) : 'a = - let cur = get_cur_ () in - Direct.get cur key - -let[@inline] get_opt key = - match get_current_storage () with +let get_opt k = + match get_current_fiber () with | None -> None - | Some cur -> Some (Direct.get cur key) + | Some fiber -> + (match PF.FLS.get_exn fiber k with + | x -> Some x + | exception Not_set -> None) -let[@inline] set key v : unit = - let cur = get_cur_ () in - Direct.set cur key v +let[@inline] get k ~default = + let fiber = get_current_fiber_exn () in + PF.FLS.get fiber ~default k -let with_value key x f = - let old = get key in - set key x; - Fun.protect ~finally:(fun () -> set key old) f +let[@inline] set k v : unit = + let fiber = get_current_fiber_exn () in + PF.FLS.set fiber k v -let get_current = get_current_storage -*) +let with_value k v (f : _ -> 'b) : 'b = + let fiber = get_current_fiber_exn () in + + match PF.FLS.get_exn fiber k with + | exception Not_set -> + PF.FLS.set fiber k v; + (* nothing to restore back to, just call [f] *) + f () + | old_v -> + PF.FLS.set fiber k v; + let finally () = PF.FLS.set fiber k old_v in + Fun.protect f ~finally diff --git a/src/core/task_local_storage.mli b/src/core/task_local_storage.mli index 4fad8e0e..69c07039 100644 --- a/src/core/task_local_storage.mli +++ b/src/core/task_local_storage.mli @@ -8,62 +8,31 @@ @since 0.6 *) -(* -type t = Types_.local_storage -(** Underlying storage for a task. This is mutable and - not thread-safe. *) +type 'a t = 'a Picos.Fiber.FLS.t -val dummy : t +val create : unit -> 'a t +(** [create ()] makes a new key. Keys are expensive and + should never be allocated dynamically or in a loop. *) -type 'a key -(** A key used to access a particular (typed) storage slot on every task. *) +exception Not_set -val new_key : init:(unit -> 'a) -> unit -> 'a key -(** [new_key ~init ()] makes a new key. Keys are expensive and - should never be allocated dynamically or in a loop. - The correct pattern is, at toplevel: - - {[ - let k_foo : foo Task_ocal_storage.key = - Task_local_storage.new_key ~init:(fun () -> make_foo ()) () - - (* … *) - - (* use it: *) - let … = Task_local_storage.get k_foo - ]} -*) - -val get : 'a key -> 'a +val get_exn : 'a t -> 'a (** [get k] gets the value for the current task for key [k]. Must be run from inside a task running on a runner. - @raise Failure otherwise *) + @raise Not_set otherwise *) -val get_opt : 'a key -> 'a option +val get_opt : 'a t -> 'a option (** [get_opt k] gets the current task's value for key [k], or [None] if not run from inside the task. *) -val set : 'a key -> 'a -> unit +val get : 'a t -> default:'a -> 'a + +val set : 'a t -> 'a -> unit (** [set k v] sets the storage for [k] to [v]. Must be run from inside a task running on a runner. @raise Failure otherwise *) -val with_value : 'a key -> 'a -> (unit -> 'b) -> 'b +val with_value : 'a t -> 'a -> (unit -> 'b) -> 'b (** [with_value k v f] sets [k] to [v] for the duration of the call to [f()]. When [f()] returns (or fails), [k] is restored to its old value. *) - -val get_current : unit -> t option -(** Access the current storage, or [None] if not run from - within a task. *) - -(** Direct access to values from a storage handle *) -module Direct : sig - val get : t -> 'a key -> 'a - (** Access a key *) - - val set : t -> 'a key -> 'a -> unit - val create : unit -> t - val copy : t -> t -end -*) diff --git a/src/core/trigger.ml b/src/core/trigger.ml index baad75ce..f7fda452 100644 --- a/src/core/trigger.ml +++ b/src/core/trigger.ml @@ -2,3 +2,5 @@ @since NEXT_RELEASE *) include Picos.Trigger + +let[@inline] await_exn (self : t) = await self |> Option.iter Exn_bt.raise diff --git a/src/core/ws_pool.ml b/src/core/ws_pool.ml index ca31ef8a..70e4bb80 100644 --- a/src/core/ws_pool.ml +++ b/src/core/ws_pool.ml @@ -25,7 +25,7 @@ type state = { mutable n_waiting_nonzero: bool; (** [n_waiting > 0] *) mutex: Mutex.t; cond: Condition.t; - as_runner: t lazy_t; + mutable as_runner: t; (* init options *) around_task: WL.around_task; name: string option; @@ -167,7 +167,9 @@ and wait_on_worker (self : worker_state) : WL.task_full = | task -> Mutex.unlock self.st.mutex; task - | exception Queue.Empty -> try_steal_from_other_workers_ self + | exception Queue.Empty -> + Mutex.unlock self.st.mutex; + try_steal_from_other_workers_ self ) else ( (* do nothing more: no task in main queue, and we are shutting down so no new task should arrive. @@ -183,8 +185,7 @@ let before_start (self : worker_state) : unit = let t_id = Thread.id @@ Thread.self () in self.st.on_init_thread ~dom_id:self.dom_id ~t_id (); TLS.set k_cur_fiber _dummy_fiber; - TLS.set Runner.For_runner_implementors.k_cur_runner - (Lazy.force self.st.as_runner); + TLS.set Runner.For_runner_implementors.k_cur_runner self.st.as_runner; TLS.set k_worker_state self; (* set thread name *) @@ -200,7 +201,7 @@ let cleanup (self : worker_state) : unit = self.st.on_exit_thread ~dom_id:self.dom_id ~t_id () let worker_ops : worker_state WL.ops = - let runner (st : worker_state) = Lazy.force st.st.as_runner in + let runner (st : worker_state) = st.st.as_runner in let around_task st = st.st.around_task in let on_exn (st : worker_state) (ebt : Exn_bt.t) = st.st.on_exn ebt.exn ebt.bt @@ -261,7 +262,7 @@ let create ?(on_init_thread = default_thread_init_exit_) (* make sure we don't bias towards the first domain(s) in {!D_pool_} *) let offset = Random.int num_domains in - let rec pool = + let pool = { id_ = pool_id_; active = A.make true; @@ -276,28 +277,32 @@ let create ?(on_init_thread = default_thread_init_exit_) on_init_thread; on_exit_thread; name; - as_runner = lazy (as_runner_ pool); + as_runner = Runner.dummy; } in + pool.as_runner <- as_runner_ pool; (* temporary queue used to obtain thread handles from domains on which the thread are started. *) let receive_threads = Bb_queue.create () in (* start the thread with index [i] *) - let start_thread_with_idx idx = + let create_worker_state idx = let dom_id = (offset + idx) mod num_domains in - let st = - { - st = pool; - thread = (* dummy *) Thread.self (); - q = WSQ.create ~dummy:WL._dummy_task (); - rng = Random.State.make [| idx |]; - dom_id; - idx; - } - in + { + st = pool; + thread = (* dummy *) Thread.self (); + q = WSQ.create ~dummy:WL._dummy_task (); + rng = Random.State.make [| idx |]; + dom_id; + idx; + } + in + pool.workers <- Array.init num_threads create_worker_state; + + (* start the thread with index [i] *) + let start_thread_with_idx idx (st : worker_state) = (* function called in domain with index [i], to create the thread and push it into [receive_threads] *) let create_thread_in_domain () = @@ -305,15 +310,12 @@ let create ?(on_init_thread = default_thread_init_exit_) (* send the thread from the domain back to us *) Bb_queue.push receive_threads (idx, thread) in - - Domain_pool_.run_on dom_id create_thread_in_domain; - - st + Domain_pool_.run_on st.dom_id create_thread_in_domain in (* start all worker threads, placing them on the domains according to their index and [offset] in a round-robin fashion. *) - pool.workers <- Array.init num_threads start_thread_with_idx; + Array.iteri start_thread_with_idx pool.workers; (* receive the newly created threads back from domains *) for _j = 1 to num_threads do @@ -322,7 +324,7 @@ let create ?(on_init_thread = default_thread_init_exit_) worker_state.thread <- th done; - Lazy.force pool.as_runner + pool.as_runner let with_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads ?name () f = diff --git a/src/fib/dune b/src/fib/dune index 17ed239a..4c6594f8 100644 --- a/src/fib/dune +++ b/src/fib/dune @@ -2,7 +2,7 @@ (name moonpool_fib) (public_name moonpool.fib) (synopsis "Fibers and structured concurrency for Moonpool") - (libraries moonpool) + (libraries moonpool picos) (enabled_if (>= %{ocaml_version} 5.0)) (flags :standard -open Moonpool_private -open Moonpool) diff --git a/src/fib/fiber.ml b/src/fib/fiber.ml index c50325b1..6fe034c0 100644 --- a/src/fib/fiber.ml +++ b/src/fib/fiber.ml @@ -1,6 +1,9 @@ +open Moonpool.Private.Types_ module A = Atomic module FM = Handle.Map module Int_map = Map.Make (Int) +module PF = Picos.Fiber +module FLS = Picos.Fiber.FLS type 'a callback = 'a Exn_bt.result -> unit (** Callbacks that are called when a fiber is done. *) @@ -10,13 +13,16 @@ type cancel_callback = Exn_bt.t -> unit let prom_of_fut : 'a Fut.t -> 'a Fut.promise = Fut.Private_.unsafe_promise_of_fut +(* TODO: replace with picos structured at some point? *) module Private_ = struct + type pfiber = PF.t + type 'a t = { id: Handle.t; (** unique identifier for this fiber *) state: 'a state A.t; (** Current state in the lifetime of the fiber *) res: 'a Fut.t; runner: Runner.t; - ls: Task_local_storage.t; + pfiber: pfiber; (** Associated picos fiber *) } and 'a state = @@ -30,11 +36,18 @@ module Private_ = struct and children = any FM.t and any = Any : _ t -> any [@@unboxed] - (** Key to access the current fiber. *) - let k_current_fiber : any option Task_local_storage.key = - Task_local_storage.new_key ~init:(fun () -> None) () + (** Key to access the current moonpool.fiber. *) + let k_current_fiber : any FLS.t = FLS.create () - let[@inline] get_cur () : any option = Task_local_storage.get k_current_fiber + exception Not_set = FLS.Not_set + + let[@inline] get_cur_from_exn (pfiber : pfiber) : any = + FLS.get_exn pfiber k_current_fiber + + let[@inline] get_cur_exn () : any = + get_cur_from_exn @@ get_current_fiber_exn () + + let[@inline] get_cur_opt () = try Some (get_cur_exn ()) with _ -> None let[@inline] is_closed (self : _ t) = match A.get self.state with @@ -44,9 +57,9 @@ end include Private_ -let create_ ~ls ~runner () : 'a t = +let create_ ~pfiber ~runner () : 'a t = let id = Handle.generate_fresh () in - let res, _promise = Fut.make () in + let res, _ = Fut.make () in { state = A.make @@ -54,7 +67,7 @@ let create_ ~ls ~runner () : 'a t = id; res; runner; - ls; + pfiber; } let create_done_ ~res () : _ t = @@ -66,7 +79,7 @@ let create_done_ ~res () : _ t = id; res; runner = Runner.dummy; - ls = Task_local_storage.dummy; + pfiber = Moonpool.Private.Types_._dummy_fiber; } let[@inline] return x = create_done_ ~res:(Fut.return x) () @@ -175,7 +188,8 @@ let with_on_cancel (self : _ t) cb (k : unit -> 'a) : 'a = let h = add_on_cancel self cb in Fun.protect k ~finally:(fun () -> remove_on_cancel self h) -(** Successfully resolve the fiber *) +(** Successfully resolve the fiber. This might still fail if + some children failed. *) let resolve_ok_ (self : 'a t) (r : 'a) : unit = let r = A.make @@ Ok r in let promise = prom_of_fut self.res in @@ -239,15 +253,21 @@ let add_child_ ~protect (self : _ t) (child : _ t) = () done -let spawn_ ~ls ~parent ~runner (f : unit -> 'a) : 'a t = +let spawn_ ~parent ~runner (f : unit -> 'a) : 'a t = + let comp = Picos.Computation.create () in + let pfiber = PF.create ~forbid:false comp in + + (* inherit FLS from parent, if present *) + Option.iter (fun (p : _ t) -> PF.copy_fls p.pfiber pfiber) parent; + (match parent with | Some p when is_closed p -> failwith "spawn: nursery is closed" | _ -> ()); - let fib = create_ ~ls ~runner () in + let fib = create_ ~pfiber ~runner () in let run () = (* make sure the fiber is accessible from inside itself *) - Task_local_storage.set k_current_fiber (Some (Any fib)); + FLS.set pfiber k_current_fiber (Any fib); try let res = f () in resolve_ok_ fib res @@ -257,63 +277,54 @@ let spawn_ ~ls ~parent ~runner (f : unit -> 'a) : 'a t = resolve_as_failed_ fib ebt in - Runner.run_async ~ls runner run; + Runner.run_async ~fiber:pfiber runner run; fib -let spawn_top ~on f : _ t = - let ls = Task_local_storage.Direct.create () in - spawn_ ~ls ~runner:on ~parent:None f +let spawn_top ~on f : _ t = spawn_ ~runner:on ~parent:None f let spawn ?on ?(protect = true) f : _ t = (* spawn [f()] with a copy of our local storage *) let (Any p) = - match get_cur () with - | None -> failwith "Fiber.spawn: must be run from within another fiber." - | Some p -> p + try get_cur_exn () + with Not_set -> + failwith "Fiber.spawn: must be run from within another fiber." in - let ls = Task_local_storage.Direct.copy p.ls in + let runner = match on with | Some r -> r | None -> p.runner in - let child = spawn_ ~ls ~parent:(Some p) ~runner f in + let child = spawn_ ~parent:(Some p) ~runner f in add_child_ ~protect p child; child let[@inline] spawn_ignore ?protect f : unit = ignore (spawn ?protect f : _ t) let[@inline] self () : any = - match Task_local_storage.get k_current_fiber with - | None -> failwith "Fiber.self: must be run from inside a fiber." - | Some f -> f + match get_cur_exn () with + | exception Not_set -> failwith "Fiber.self: must be run from inside a fiber." + | f -> f let with_on_self_cancel cb (k : unit -> 'a) : 'a = let (Any self) = self () in let h = add_on_cancel self cb in Fun.protect k ~finally:(fun () -> remove_on_cancel self h) -module Suspend_ = Moonpool.Private.Suspend_ - -let check_if_cancelled_ (self : _ t) = - match A.get self.state with - | Terminating_or_done r -> - (match A.get r with - | Error ebt -> Exn_bt.raise ebt - | _ -> ()) - | _ -> () +let[@inline] check_if_cancelled_ (self : _ t) = PF.check self.pfiber let check_if_cancelled () = - match Task_local_storage.get k_current_fiber with - | None -> + match get_cur_exn () with + | exception Not_set -> failwith "Fiber.check_if_cancelled: must be run from inside a fiber." - | Some (Any self) -> check_if_cancelled_ self + | Any self -> check_if_cancelled_ self let yield () : unit = - match Task_local_storage.get k_current_fiber with - | None -> failwith "Fiber.yield: must be run from inside a fiber." - | Some (Any self) -> + match get_cur_exn () with + | exception Not_set -> + failwith "Fiber.yield: must be run from inside a fiber." + | Any self -> check_if_cancelled_ self; - Suspend_.yield (); + PF.yield (); check_if_cancelled_ self diff --git a/src/fib/fiber.mli b/src/fib/fiber.mli index d02c4e56..0da300e7 100644 --- a/src/fib/fiber.mli +++ b/src/fib/fiber.mli @@ -17,20 +17,27 @@ type cancel_callback = Exn_bt.t -> unit (** Do not rely on this, it is internal implementation details. *) module Private_ : sig type 'a state + type pfiber type 'a t = private { id: Handle.t; (** unique identifier for this fiber *) state: 'a state Atomic.t; (** Current state in the lifetime of the fiber *) res: 'a Fut.t; runner: Runner.t; - ls: Task_local_storage.t; + pfiber: pfiber; } (** Type definition, exposed so that {!any} can be unboxed. Please do not rely on that. *) type any = Any : _ t -> any [@@unboxed] - val get_cur : unit -> any option + exception Not_set + + val get_cur_exn : unit -> any + (** [get_cur_exn ()] either returns the current fiber, or + @raise Not_set if run outside a fiber. *) + + val get_cur_opt : unit -> any option end (**/**) diff --git a/src/fib/main.ml b/src/fib/main.ml index 26112015..0ec22be8 100644 --- a/src/fib/main.ml +++ b/src/fib/main.ml @@ -1,14 +1,20 @@ exception Oh_no of Exn_bt.t let main (f : Runner.t -> 'a) : 'a = - let st = Fifo_pool.Private_.create_state ~threads:[| Thread.self () |] () in - let runner = Fifo_pool.Private_.runner_of_state st in + let worker_st = + Fifo_pool.Private_.create_single_threaded_state ~thread:(Thread.self ()) + ~on_exn:(fun e bt -> raise (Oh_no (Exn_bt.make e bt))) + () + in + let runner = Fifo_pool.Private_.runner_of_state worker_st in try let fiber = Fiber.spawn_top ~on:runner (fun () -> f runner) in Fiber.on_result fiber (fun _ -> Runner.shutdown_without_waiting runner); + (* run the main thread *) - Fifo_pool.Private_.run_thread st runner ~on_exn:(fun e bt -> - raise (Oh_no (Exn_bt.make e bt))); + Moonpool.Private.Worker_loop_.worker_loop worker_st + ~ops:Fifo_pool.Private_.worker_ops; + match Fiber.peek fiber with | Some (Ok x) -> x | Some (Error ebt) -> Exn_bt.raise ebt diff --git a/src/forkjoin/dune b/src/forkjoin/dune index b17a163d..84849c9b 100644 --- a/src/forkjoin/dune +++ b/src/forkjoin/dune @@ -6,4 +6,4 @@ (optional) (enabled_if (>= %{ocaml_version} 5.0)) - (libraries moonpool moonpool.private)) + (libraries moonpool moonpool.private picos)) diff --git a/src/forkjoin/moonpool_forkjoin.ml b/src/forkjoin/moonpool_forkjoin.ml index d5961d95..80cead1d 100644 --- a/src/forkjoin/moonpool_forkjoin.ml +++ b/src/forkjoin/moonpool_forkjoin.ml @@ -1,5 +1,4 @@ module A = Moonpool.Atomic -module Suspend_ = Moonpool.Private.Suspend_ module Domain_ = Moonpool_private.Domain_ module State_ = struct @@ -9,7 +8,7 @@ module State_ = struct type ('a, 'b) t = | Init | Left_solved of 'a or_error - | Right_solved of 'b or_error * Suspend_.suspension + | Right_solved of 'b or_error * Trigger.t | Both_solved of 'a or_error * 'b or_error let get_exn_ (self : _ t A.t) = @@ -28,13 +27,13 @@ module State_ = struct Domain_.relax (); set_left_ self left ) - | Right_solved (right, cont) -> + | Right_solved (right, tr) -> let new_st = Both_solved (left, right) in if not (A.compare_and_set self old_st new_st) then ( Domain_.relax (); set_left_ self left ) else - cont (Ok ()) + Trigger.signal tr | Left_solved _ | Both_solved _ -> assert false let rec set_right_ (self : _ t A.t) (right : _ or_error) : unit = @@ -45,27 +44,27 @@ module State_ = struct if not (A.compare_and_set self old_st new_st) then set_right_ self right | Init -> (* we are first arrived, we suspend until the left computation is done *) - Suspend_.suspend - { - Suspend_.handle = - (fun ~run:_ ~resume suspension -> - while - let old_st = A.get self in - match old_st with - | Init -> - not - (A.compare_and_set self old_st - (Right_solved (right, suspension))) - | Left_solved left -> - (* other thread is done, no risk of race condition *) - A.set self (Both_solved (left, right)); - resume suspension (Ok ()); - false - | Right_solved _ | Both_solved _ -> assert false - do - () - done); - } + let trigger = Trigger.create () in + let must_await = ref true in + + while + let old_st = A.get self in + match old_st with + | Init -> + (* setup trigger so that left computation will wake us up *) + not (A.compare_and_set self old_st (Right_solved (right, trigger))) + | Left_solved left -> + (* other thread is done, no risk of race condition *) + A.set self (Both_solved (left, right)); + must_await := false; + false + | Right_solved _ | Both_solved _ -> assert false + do + () + done; + + (* wait for the other computation to be done *) + if !must_await then Trigger.await trigger |> Option.iter Exn_bt.raise | Right_solved _ | Both_solved _ -> assert false end @@ -102,7 +101,12 @@ let both_ignore f g = ignore (both f g : _ * _) let for_ ?chunk_size n (f : int -> int -> unit) : unit = if n > 0 then ( - let has_failed = A.make false in + let runner = + match Runner.get_current_runner () with + | None -> failwith "forkjoin.for_: must be run inside a moonpool runner." + | Some r -> r + in + let failure = A.make None in let missing = A.make n in let chunk_size = @@ -113,40 +117,36 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit = max 1 (1 + (n / Moonpool.Private.num_domains ())) in - let start_tasks ~run ~resume (suspension : Suspend_.suspension) = - let task_for ~offset ~len_range = - match f offset (offset + len_range - 1) with - | () -> - if A.fetch_and_add missing (-len_range) = len_range then - (* all tasks done successfully *) - resume suspension (Ok ()) - | exception exn -> - let bt = Printexc.get_raw_backtrace () in - if not (A.exchange has_failed true) then - (* first one to fail, and [missing] must be >= 2 - because we're not decreasing it. *) - resume suspension (Error { Exn_bt.exn; bt }) - in + let trigger = Trigger.create () in - let i = ref 0 in - while !i < n do - let offset = !i in - - let len_range = min chunk_size (n - offset) in - assert (offset + len_range <= n); - - run (fun () -> task_for ~offset ~len_range); - i := !i + len_range - done + let task_for ~offset ~len_range = + match f offset (offset + len_range - 1) with + | () -> + if A.fetch_and_add missing (-len_range) = len_range then + (* all tasks done successfully *) + Trigger.signal trigger + | exception exn -> + let bt = Printexc.get_raw_backtrace () in + if Option.is_none (A.exchange failure (Some { Exn_bt.exn; bt })) then + (* first one to fail, and [missing] must be >= 2 + because we're not decreasing it. *) + Trigger.signal trigger in - Suspend_.suspend - { - Suspend_.handle = - (fun ~run ~resume suspension -> - (* run tasks, then we'll resume [suspension] *) - start_tasks ~run ~resume suspension); - } + let i = ref 0 in + while !i < n do + let offset = !i in + + let len_range = min chunk_size (n - offset) in + assert (offset + len_range <= n); + + Runner.run_async runner (fun () -> task_for ~offset ~len_range); + i := !i + len_range + done; + + Trigger.await trigger |> Option.iter Exn_bt.raise; + Option.iter Exn_bt.raise @@ A.get failure; + () ) let all_array ?chunk_size (fs : _ array) : _ array = diff --git a/src/lwt/IO.ml b/src/lwt/IO.ml index 4a8acc69..6ae09506 100644 --- a/src/lwt/IO.ml +++ b/src/lwt/IO.ml @@ -1,17 +1,14 @@ open Base let await_readable fd : unit = - Moonpool.Private.Suspend_.suspend - { - handle = - (fun ~run:_ ~resume sus -> - Perform_action_in_lwt.schedule - @@ Action.Wait_readable - ( fd, - fun cancel -> - resume sus @@ Ok (); - Lwt_engine.stop_event cancel )); - } + let trigger = Trigger.create () in + Perform_action_in_lwt.schedule + @@ Action.Wait_readable + ( fd, + fun cancel -> + Trigger.signal trigger; + Lwt_engine.stop_event cancel ); + Trigger.await_exn trigger let rec read fd buf i len : int = if len = 0 then @@ -25,17 +22,14 @@ let rec read fd buf i len : int = ) let await_writable fd = - Moonpool.Private.Suspend_.suspend - { - handle = - (fun ~run:_ ~resume sus -> - Perform_action_in_lwt.schedule - @@ Action.Wait_writable - ( fd, - fun cancel -> - resume sus @@ Ok (); - Lwt_engine.stop_event cancel )); - } + let trigger = Trigger.create () in + Perform_action_in_lwt.schedule + @@ Action.Wait_writable + ( fd, + fun cancel -> + Trigger.signal trigger; + Lwt_engine.stop_event cancel ); + Trigger.await_exn trigger let rec write_once fd buf i len : int = if len = 0 then @@ -59,16 +53,14 @@ let write fd buf i len : unit = (** Sleep for the given amount of seconds *) let sleep_s (f : float) : unit = - if f > 0. then - Moonpool.Private.Suspend_.suspend - { - handle = - (fun ~run:_ ~resume sus -> - Perform_action_in_lwt.schedule - @@ Action.Sleep - ( f, - false, - fun cancel -> - resume sus @@ Ok (); - Lwt_engine.stop_event cancel )); - } + if f > 0. then ( + let trigger = Trigger.create () in + Perform_action_in_lwt.schedule + @@ Action.Sleep + ( f, + false, + fun cancel -> + Trigger.signal trigger; + Lwt_engine.stop_event cancel ); + Trigger.await_exn trigger + ) diff --git a/src/lwt/base.ml b/src/lwt/base.ml index 73ccc049..e61b71aa 100644 --- a/src/lwt/base.ml +++ b/src/lwt/base.ml @@ -1,4 +1,5 @@ open Common_ +module Trigger = M.Trigger module Fiber = Moonpool_fib.Fiber module FLS = Moonpool_fib.Fls @@ -14,7 +15,7 @@ module Action = struct | Sleep of float * bool * cb (* TODO: provide actions with cancellation, alongside a "select" operation *) (* | Cancel of event *) - | On_termination : 'a Lwt.t * ('a Exn_bt.result -> unit) -> t + | On_termination : 'a Lwt.t * 'a Exn_bt.result ref * Trigger.t -> t | Wakeup : 'a Lwt.u * 'a -> t | Wakeup_exn : _ Lwt.u * exn -> t | Other of (unit -> unit) @@ -26,10 +27,14 @@ module Action = struct | Wait_writable (fd, cb) -> ignore (Lwt_engine.on_writable fd cb : event) | Sleep (f, repeat, cb) -> ignore (Lwt_engine.on_timer f repeat cb : event) (* | Cancel ev -> Lwt_engine.stop_event ev *) - | On_termination (fut, f) -> + | On_termination (fut, res, trigger) -> Lwt.on_any fut - (fun x -> f @@ Ok x) - (fun exn -> f @@ Error (Exn_bt.get_callstack 10 exn)) + (fun x -> + res := Ok x; + Trigger.signal trigger) + (fun exn -> + res := Error (Exn_bt.get_callstack 10 exn); + Trigger.signal trigger) | Wakeup (prom, x) -> Lwt.wakeup prom x | Wakeup_exn (prom, e) -> Lwt.wakeup_exn prom e | Other f -> f () @@ -106,23 +111,19 @@ let fut_of_lwt (lwt_fut : _ Lwt.t) : _ M.Fut.t = M.Fut.fulfill prom (Error { Exn_bt.exn; bt })); fut +let _dummy_exn_bt : Exn_bt.t = + Exn_bt.get_callstack 0 (Failure "dummy Exn_bt from moonpool-lwt") + let await_lwt (fut : _ Lwt.t) = match Lwt.poll fut with | Some x -> x | None -> (* suspend fiber, wake it up when [fut] resolves *) - M.Private.Suspend_.suspend - { - handle = - (fun ~run:_ ~resume sus -> - let on_lwt_done _ = resume sus @@ Ok () in - Perform_action_in_lwt.( - schedule Action.(On_termination (fut, on_lwt_done)))); - }; - - (match Lwt.poll fut with - | Some x -> x - | None -> assert false) + let trigger = M.Trigger.create () in + let res = ref (Error _dummy_exn_bt) in + Perform_action_in_lwt.(schedule Action.(On_termination (fut, res, trigger))); + Trigger.await trigger |> Option.iter Exn_bt.raise; + Exn_bt.unwrap !res let run_in_lwt f : _ M.Fut.t = let fut, prom = M.Fut.make () in diff --git a/src/lwt/dune b/src/lwt/dune index 9038747a..b03d03d6 100644 --- a/src/lwt/dune +++ b/src/lwt/dune @@ -4,4 +4,9 @@ (private_modules common_) (enabled_if (>= %{ocaml_version} 5.0)) - (libraries moonpool moonpool.fib lwt lwt.unix)) + (libraries + (re_export moonpool) + (re_export moonpool.fib) + picos + (re_export lwt) + lwt.unix))