wip: port to picos

This commit is contained in:
Simon Cruanes 2024-08-28 16:09:45 -04:00
parent a0068b09b3
commit 07a7fc3a1c
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
17 changed files with 314 additions and 316 deletions

View file

@ -7,3 +7,7 @@ let show self = Printexc.to_string (exn self)
let pp out self = Format.pp_print_string out (show self) let pp out self = Format.pp_print_string out (show self)
type nonrec 'a result = ('a, t) result type nonrec 'a result = ('a, t) result
let[@inline] unwrap = function
| Ok x -> x
| Error ebt -> raise ebt

View file

@ -21,3 +21,7 @@ val show : t -> string
val pp : Format.formatter -> t -> unit val pp : Format.formatter -> t -> unit
type nonrec 'a result = ('a, t) result type nonrec 'a result = ('a, t) result
val unwrap : 'a result -> 'a
(** [unwrap (Ok x)] is [x], [unwrap (Error ebt)] re-raises [ebt].
@since NEXT_RELEASE *)

View file

@ -11,7 +11,7 @@ type state = {
threads: Thread.t array; threads: Thread.t array;
q: task_full Bb_queue.t; (** Queue for tasks. *) q: task_full Bb_queue.t; (** Queue for tasks. *)
around_task: WL.around_task; around_task: WL.around_task;
as_runner: t lazy_t; mutable as_runner: t;
(* init options *) (* init options *)
name: string option; name: string option;
on_init_thread: dom_id:int -> t_id:int -> unit -> unit; on_init_thread: dom_id:int -> t_id:int -> unit -> unit;
@ -24,7 +24,6 @@ type worker_state = {
idx: int; idx: int;
dom_idx: int; dom_idx: int;
st: state; st: state;
mutable current: fiber;
} }
let[@inline] size_ (self : state) = Array.length self.threads let[@inline] size_ (self : state) = Array.length self.threads
@ -95,7 +94,7 @@ let cleanup (self : worker_state) : unit =
self.st.on_exit_thread ~dom_id:self.dom_idx ~t_id () self.st.on_exit_thread ~dom_id:self.dom_idx ~t_id ()
let worker_ops : worker_state WL.ops = let worker_ops : worker_state WL.ops =
let runner (st : worker_state) = Lazy.force st.st.as_runner in let runner (st : worker_state) = st.st.as_runner in
let around_task st = st.st.around_task in let around_task st = st.st.around_task in
let on_exn (st : worker_state) (ebt : Exn_bt.t) = let on_exn (st : worker_state) (ebt : Exn_bt.t) =
st.st.on_exn ebt.exn ebt.bt st.st.on_exn ebt.exn ebt.bt
@ -111,9 +110,9 @@ let worker_ops : worker_state WL.ops =
cleanup; cleanup;
} }
let create ?(on_init_thread = default_thread_init_exit_) let create_ ?(on_init_thread = default_thread_init_exit_)
?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ()) ?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ())
?around_task ?num_threads ?name () : t = ?around_task ~threads ?name () : state =
(* wrapper *) (* wrapper *)
let around_task = let around_task =
match around_task with match around_task with
@ -121,6 +120,23 @@ let create ?(on_init_thread = default_thread_init_exit_)
| None -> default_around_task_ | None -> default_around_task_
in in
let self =
{
threads;
q = Bb_queue.create ();
around_task;
as_runner = Runner.dummy;
name;
on_init_thread;
on_exit_thread;
on_exn;
}
in
self.as_runner <- runner_of_state self;
self
let create ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads
?name () : t =
let num_domains = Domain_pool_.max_number_of_domains () in let num_domains = Domain_pool_.max_number_of_domains () in
(* number of threads to run *) (* number of threads to run *)
@ -129,20 +145,12 @@ let create ?(on_init_thread = default_thread_init_exit_)
(* make sure we don't bias towards the first domain(s) in {!D_pool_} *) (* make sure we don't bias towards the first domain(s) in {!D_pool_} *)
let offset = Random.int num_domains in let offset = Random.int num_domains in
let rec pool = let pool =
let dummy_thread = Thread.self () in let dummy_thread = Thread.self () in
{ let threads = Array.make num_threads dummy_thread in
threads = Array.make num_threads dummy_thread; create_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ~threads ?name
q = Bb_queue.create (); ()
around_task;
as_runner = lazy (runner_of_state pool);
name;
on_init_thread;
on_exit_thread;
on_exn;
}
in in
let runner = runner_of_state pool in let runner = runner_of_state pool in
(* temporary queue used to obtain thread handles from domains (* temporary queue used to obtain thread handles from domains
@ -156,7 +164,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
(* function called in domain with index [i], to (* function called in domain with index [i], to
create the thread and push it into [receive_threads] *) create the thread and push it into [receive_threads] *)
let create_thread_in_domain () = let create_thread_in_domain () =
let st = { idx = i; dom_idx; st = pool; current = _dummy_fiber } in let st = { idx = i; dom_idx; st = pool } in
let thread = Thread.create (WL.worker_loop ~ops:worker_ops) st in let thread = Thread.create (WL.worker_loop ~ops:worker_ops) st in
(* send the thread from the domain back to us *) (* send the thread from the domain back to us *)
Bb_queue.push receive_threads (i, thread) Bb_queue.push receive_threads (i, thread)
@ -187,3 +195,14 @@ let with_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads
in in
let@ () = Fun.protect ~finally:(fun () -> shutdown pool) in let@ () = Fun.protect ~finally:(fun () -> shutdown pool) in
f pool f pool
module Private_ = struct
type nonrec worker_state = worker_state
let worker_ops = worker_ops
let runner_of_state (self : worker_state) = worker_ops.runner self
let create_single_threaded_state ~thread ?on_exn () : worker_state =
let st : state = create_ ?on_exn ~threads:[| thread |] () in
{ idx = 0; dom_idx = 0; st }
end

View file

@ -44,3 +44,21 @@ val with_ : (unit -> (t -> 'a) -> 'a, _) create_args
When [f pool] returns or fails, [pool] is shutdown and its resources When [f pool] returns or fails, [pool] is shutdown and its resources
are released. are released.
Most parameters are the same as in {!create}. *) Most parameters are the same as in {!create}. *)
(**/**)
module Private_ : sig
type worker_state
val worker_ops : worker_state Worker_loop_.ops
val create_single_threaded_state :
thread:Thread.t ->
?on_exn:(exn -> Printexc.raw_backtrace -> unit) ->
unit ->
worker_state
val runner_of_state : worker_state -> Runner.t
end
(**/**)

View file

@ -1,83 +1,41 @@
open Types_ open Types_
(* module PF = Picos.Fiber
module A = Atomic_
type 'a key = 'a ls_key type 'a t = 'a PF.FLS.t
let key_count_ = A.make 0 exception Not_set = PF.FLS.Not_set
type t = local_storage let create = PF.FLS.create
type ls_value += Dummy
let dummy : t = _dummy_ls let[@inline] get_exn k =
let fiber = get_current_fiber_exn () in
PF.FLS.get_exn fiber k
(** Resize array of TLS values *) let get_opt k =
let[@inline never] resize_ (cur : ls_value array ref) n = match get_current_fiber () with
if n > Sys.max_array_length then failwith "too many task local storage keys";
let len = Array.length !cur in
let new_ls =
Array.make (min Sys.max_array_length (max n ((len * 2) + 2))) Dummy
in
Array.blit !cur 0 new_ls 0 len;
cur := new_ls
module Direct = struct
type nonrec t = t
let create = create_local_storage
let[@inline] copy (self : t) = ref (Array.copy !self)
let get (type a) (self : t) ((module K) : a key) : a =
if K.offset >= Array.length !self then resize_ self (K.offset + 1);
match !self.(K.offset) with
| K.V x -> (* common case first *) x
| Dummy ->
(* first time we access this *)
let v = K.init () in
!self.(K.offset) <- K.V v;
v
| _ -> assert false
let set (type a) (self : t) ((module K) : a key) (v : a) : unit =
assert (self != dummy);
if K.offset >= Array.length !self then resize_ self (K.offset + 1);
!self.(K.offset) <- K.V v;
()
end
let new_key (type t) ~init () : t key =
let offset = A.fetch_and_add key_count_ 1 in
(module struct
type nonrec t = t
type ls_value += V of t
let offset = offset
let init = init
end : LS_KEY
with type t = t)
let[@inline] get_cur_ () : ls_value array ref =
match get_current_storage () with
| Some r when r != dummy -> r
| _ -> failwith "Task local storage must be accessed from within a runner."
let[@inline] get (key : 'a key) : 'a =
let cur = get_cur_ () in
Direct.get cur key
let[@inline] get_opt key =
match get_current_storage () with
| None -> None | None -> None
| Some cur -> Some (Direct.get cur key) | Some fiber ->
(match PF.FLS.get_exn fiber k with
| x -> Some x
| exception Not_set -> None)
let[@inline] set key v : unit = let[@inline] get k ~default =
let cur = get_cur_ () in let fiber = get_current_fiber_exn () in
Direct.set cur key v PF.FLS.get fiber ~default k
let with_value key x f = let[@inline] set k v : unit =
let old = get key in let fiber = get_current_fiber_exn () in
set key x; PF.FLS.set fiber k v
Fun.protect ~finally:(fun () -> set key old) f
let get_current = get_current_storage let with_value k v (f : _ -> 'b) : 'b =
*) let fiber = get_current_fiber_exn () in
match PF.FLS.get_exn fiber k with
| exception Not_set ->
PF.FLS.set fiber k v;
(* nothing to restore back to, just call [f] *)
f ()
| old_v ->
PF.FLS.set fiber k v;
let finally () = PF.FLS.set fiber k old_v in
Fun.protect f ~finally

View file

@ -8,62 +8,31 @@
@since 0.6 @since 0.6
*) *)
(* type 'a t = 'a Picos.Fiber.FLS.t
type t = Types_.local_storage
(** Underlying storage for a task. This is mutable and
not thread-safe. *)
val dummy : t val create : unit -> 'a t
(** [create ()] makes a new key. Keys are expensive and
should never be allocated dynamically or in a loop. *)
type 'a key exception Not_set
(** A key used to access a particular (typed) storage slot on every task. *)
val new_key : init:(unit -> 'a) -> unit -> 'a key val get_exn : 'a t -> 'a
(** [new_key ~init ()] makes a new key. Keys are expensive and
should never be allocated dynamically or in a loop.
The correct pattern is, at toplevel:
{[
let k_foo : foo Task_ocal_storage.key =
Task_local_storage.new_key ~init:(fun () -> make_foo ()) ()
(**)
(* use it: *)
let = Task_local_storage.get k_foo
]}
*)
val get : 'a key -> 'a
(** [get k] gets the value for the current task for key [k]. (** [get k] gets the value for the current task for key [k].
Must be run from inside a task running on a runner. Must be run from inside a task running on a runner.
@raise Failure otherwise *) @raise Not_set otherwise *)
val get_opt : 'a key -> 'a option val get_opt : 'a t -> 'a option
(** [get_opt k] gets the current task's value for key [k], (** [get_opt k] gets the current task's value for key [k],
or [None] if not run from inside the task. *) or [None] if not run from inside the task. *)
val set : 'a key -> 'a -> unit val get : 'a t -> default:'a -> 'a
val set : 'a t -> 'a -> unit
(** [set k v] sets the storage for [k] to [v]. (** [set k v] sets the storage for [k] to [v].
Must be run from inside a task running on a runner. Must be run from inside a task running on a runner.
@raise Failure otherwise *) @raise Failure otherwise *)
val with_value : 'a key -> 'a -> (unit -> 'b) -> 'b val with_value : 'a t -> 'a -> (unit -> 'b) -> 'b
(** [with_value k v f] sets [k] to [v] for the duration of the call (** [with_value k v f] sets [k] to [v] for the duration of the call
to [f()]. When [f()] returns (or fails), [k] is restored to [f()]. When [f()] returns (or fails), [k] is restored
to its old value. *) to its old value. *)
val get_current : unit -> t option
(** Access the current storage, or [None] if not run from
within a task. *)
(** Direct access to values from a storage handle *)
module Direct : sig
val get : t -> 'a key -> 'a
(** Access a key *)
val set : t -> 'a key -> 'a -> unit
val create : unit -> t
val copy : t -> t
end
*)

View file

@ -2,3 +2,5 @@
@since NEXT_RELEASE *) @since NEXT_RELEASE *)
include Picos.Trigger include Picos.Trigger
let[@inline] await_exn (self : t) = await self |> Option.iter Exn_bt.raise

View file

@ -25,7 +25,7 @@ type state = {
mutable n_waiting_nonzero: bool; (** [n_waiting > 0] *) mutable n_waiting_nonzero: bool; (** [n_waiting > 0] *)
mutex: Mutex.t; mutex: Mutex.t;
cond: Condition.t; cond: Condition.t;
as_runner: t lazy_t; mutable as_runner: t;
(* init options *) (* init options *)
around_task: WL.around_task; around_task: WL.around_task;
name: string option; name: string option;
@ -167,7 +167,9 @@ and wait_on_worker (self : worker_state) : WL.task_full =
| task -> | task ->
Mutex.unlock self.st.mutex; Mutex.unlock self.st.mutex;
task task
| exception Queue.Empty -> try_steal_from_other_workers_ self | exception Queue.Empty ->
Mutex.unlock self.st.mutex;
try_steal_from_other_workers_ self
) else ( ) else (
(* do nothing more: no task in main queue, and we are shutting (* do nothing more: no task in main queue, and we are shutting
down so no new task should arrive. down so no new task should arrive.
@ -183,8 +185,7 @@ let before_start (self : worker_state) : unit =
let t_id = Thread.id @@ Thread.self () in let t_id = Thread.id @@ Thread.self () in
self.st.on_init_thread ~dom_id:self.dom_id ~t_id (); self.st.on_init_thread ~dom_id:self.dom_id ~t_id ();
TLS.set k_cur_fiber _dummy_fiber; TLS.set k_cur_fiber _dummy_fiber;
TLS.set Runner.For_runner_implementors.k_cur_runner TLS.set Runner.For_runner_implementors.k_cur_runner self.st.as_runner;
(Lazy.force self.st.as_runner);
TLS.set k_worker_state self; TLS.set k_worker_state self;
(* set thread name *) (* set thread name *)
@ -200,7 +201,7 @@ let cleanup (self : worker_state) : unit =
self.st.on_exit_thread ~dom_id:self.dom_id ~t_id () self.st.on_exit_thread ~dom_id:self.dom_id ~t_id ()
let worker_ops : worker_state WL.ops = let worker_ops : worker_state WL.ops =
let runner (st : worker_state) = Lazy.force st.st.as_runner in let runner (st : worker_state) = st.st.as_runner in
let around_task st = st.st.around_task in let around_task st = st.st.around_task in
let on_exn (st : worker_state) (ebt : Exn_bt.t) = let on_exn (st : worker_state) (ebt : Exn_bt.t) =
st.st.on_exn ebt.exn ebt.bt st.st.on_exn ebt.exn ebt.bt
@ -261,7 +262,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
(* make sure we don't bias towards the first domain(s) in {!D_pool_} *) (* make sure we don't bias towards the first domain(s) in {!D_pool_} *)
let offset = Random.int num_domains in let offset = Random.int num_domains in
let rec pool = let pool =
{ {
id_ = pool_id_; id_ = pool_id_;
active = A.make true; active = A.make true;
@ -276,18 +277,18 @@ let create ?(on_init_thread = default_thread_init_exit_)
on_init_thread; on_init_thread;
on_exit_thread; on_exit_thread;
name; name;
as_runner = lazy (as_runner_ pool); as_runner = Runner.dummy;
} }
in in
pool.as_runner <- as_runner_ pool;
(* temporary queue used to obtain thread handles from domains (* temporary queue used to obtain thread handles from domains
on which the thread are started. *) on which the thread are started. *)
let receive_threads = Bb_queue.create () in let receive_threads = Bb_queue.create () in
(* start the thread with index [i] *) (* start the thread with index [i] *)
let start_thread_with_idx idx = let create_worker_state idx =
let dom_id = (offset + idx) mod num_domains in let dom_id = (offset + idx) mod num_domains in
let st =
{ {
st = pool; st = pool;
thread = (* dummy *) Thread.self (); thread = (* dummy *) Thread.self ();
@ -298,6 +299,10 @@ let create ?(on_init_thread = default_thread_init_exit_)
} }
in in
pool.workers <- Array.init num_threads create_worker_state;
(* start the thread with index [i] *)
let start_thread_with_idx idx (st : worker_state) =
(* function called in domain with index [i], to (* function called in domain with index [i], to
create the thread and push it into [receive_threads] *) create the thread and push it into [receive_threads] *)
let create_thread_in_domain () = let create_thread_in_domain () =
@ -305,15 +310,12 @@ let create ?(on_init_thread = default_thread_init_exit_)
(* send the thread from the domain back to us *) (* send the thread from the domain back to us *)
Bb_queue.push receive_threads (idx, thread) Bb_queue.push receive_threads (idx, thread)
in in
Domain_pool_.run_on st.dom_id create_thread_in_domain
Domain_pool_.run_on dom_id create_thread_in_domain;
st
in in
(* start all worker threads, placing them on the domains (* start all worker threads, placing them on the domains
according to their index and [offset] in a round-robin fashion. *) according to their index and [offset] in a round-robin fashion. *)
pool.workers <- Array.init num_threads start_thread_with_idx; Array.iteri start_thread_with_idx pool.workers;
(* receive the newly created threads back from domains *) (* receive the newly created threads back from domains *)
for _j = 1 to num_threads do for _j = 1 to num_threads do
@ -322,7 +324,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
worker_state.thread <- th worker_state.thread <- th
done; done;
Lazy.force pool.as_runner pool.as_runner
let with_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads let with_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads
?name () f = ?name () f =

View file

@ -2,7 +2,7 @@
(name moonpool_fib) (name moonpool_fib)
(public_name moonpool.fib) (public_name moonpool.fib)
(synopsis "Fibers and structured concurrency for Moonpool") (synopsis "Fibers and structured concurrency for Moonpool")
(libraries moonpool) (libraries moonpool picos)
(enabled_if (enabled_if
(>= %{ocaml_version} 5.0)) (>= %{ocaml_version} 5.0))
(flags :standard -open Moonpool_private -open Moonpool) (flags :standard -open Moonpool_private -open Moonpool)

View file

@ -1,6 +1,9 @@
open Moonpool.Private.Types_
module A = Atomic module A = Atomic
module FM = Handle.Map module FM = Handle.Map
module Int_map = Map.Make (Int) module Int_map = Map.Make (Int)
module PF = Picos.Fiber
module FLS = Picos.Fiber.FLS
type 'a callback = 'a Exn_bt.result -> unit type 'a callback = 'a Exn_bt.result -> unit
(** Callbacks that are called when a fiber is done. *) (** Callbacks that are called when a fiber is done. *)
@ -10,13 +13,16 @@ type cancel_callback = Exn_bt.t -> unit
let prom_of_fut : 'a Fut.t -> 'a Fut.promise = let prom_of_fut : 'a Fut.t -> 'a Fut.promise =
Fut.Private_.unsafe_promise_of_fut Fut.Private_.unsafe_promise_of_fut
(* TODO: replace with picos structured at some point? *)
module Private_ = struct module Private_ = struct
type pfiber = PF.t
type 'a t = { type 'a t = {
id: Handle.t; (** unique identifier for this fiber *) id: Handle.t; (** unique identifier for this fiber *)
state: 'a state A.t; (** Current state in the lifetime of the fiber *) state: 'a state A.t; (** Current state in the lifetime of the fiber *)
res: 'a Fut.t; res: 'a Fut.t;
runner: Runner.t; runner: Runner.t;
ls: Task_local_storage.t; pfiber: pfiber; (** Associated picos fiber *)
} }
and 'a state = and 'a state =
@ -30,11 +36,18 @@ module Private_ = struct
and children = any FM.t and children = any FM.t
and any = Any : _ t -> any [@@unboxed] and any = Any : _ t -> any [@@unboxed]
(** Key to access the current fiber. *) (** Key to access the current moonpool.fiber. *)
let k_current_fiber : any option Task_local_storage.key = let k_current_fiber : any FLS.t = FLS.create ()
Task_local_storage.new_key ~init:(fun () -> None) ()
let[@inline] get_cur () : any option = Task_local_storage.get k_current_fiber exception Not_set = FLS.Not_set
let[@inline] get_cur_from_exn (pfiber : pfiber) : any =
FLS.get_exn pfiber k_current_fiber
let[@inline] get_cur_exn () : any =
get_cur_from_exn @@ get_current_fiber_exn ()
let[@inline] get_cur_opt () = try Some (get_cur_exn ()) with _ -> None
let[@inline] is_closed (self : _ t) = let[@inline] is_closed (self : _ t) =
match A.get self.state with match A.get self.state with
@ -44,9 +57,9 @@ end
include Private_ include Private_
let create_ ~ls ~runner () : 'a t = let create_ ~pfiber ~runner () : 'a t =
let id = Handle.generate_fresh () in let id = Handle.generate_fresh () in
let res, _promise = Fut.make () in let res, _ = Fut.make () in
{ {
state = state =
A.make A.make
@ -54,7 +67,7 @@ let create_ ~ls ~runner () : 'a t =
id; id;
res; res;
runner; runner;
ls; pfiber;
} }
let create_done_ ~res () : _ t = let create_done_ ~res () : _ t =
@ -66,7 +79,7 @@ let create_done_ ~res () : _ t =
id; id;
res; res;
runner = Runner.dummy; runner = Runner.dummy;
ls = Task_local_storage.dummy; pfiber = Moonpool.Private.Types_._dummy_fiber;
} }
let[@inline] return x = create_done_ ~res:(Fut.return x) () let[@inline] return x = create_done_ ~res:(Fut.return x) ()
@ -175,7 +188,8 @@ let with_on_cancel (self : _ t) cb (k : unit -> 'a) : 'a =
let h = add_on_cancel self cb in let h = add_on_cancel self cb in
Fun.protect k ~finally:(fun () -> remove_on_cancel self h) Fun.protect k ~finally:(fun () -> remove_on_cancel self h)
(** Successfully resolve the fiber *) (** Successfully resolve the fiber. This might still fail if
some children failed. *)
let resolve_ok_ (self : 'a t) (r : 'a) : unit = let resolve_ok_ (self : 'a t) (r : 'a) : unit =
let r = A.make @@ Ok r in let r = A.make @@ Ok r in
let promise = prom_of_fut self.res in let promise = prom_of_fut self.res in
@ -239,15 +253,21 @@ let add_child_ ~protect (self : _ t) (child : _ t) =
() ()
done done
let spawn_ ~ls ~parent ~runner (f : unit -> 'a) : 'a t = let spawn_ ~parent ~runner (f : unit -> 'a) : 'a t =
let comp = Picos.Computation.create () in
let pfiber = PF.create ~forbid:false comp in
(* inherit FLS from parent, if present *)
Option.iter (fun (p : _ t) -> PF.copy_fls p.pfiber pfiber) parent;
(match parent with (match parent with
| Some p when is_closed p -> failwith "spawn: nursery is closed" | Some p when is_closed p -> failwith "spawn: nursery is closed"
| _ -> ()); | _ -> ());
let fib = create_ ~ls ~runner () in let fib = create_ ~pfiber ~runner () in
let run () = let run () =
(* make sure the fiber is accessible from inside itself *) (* make sure the fiber is accessible from inside itself *)
Task_local_storage.set k_current_fiber (Some (Any fib)); FLS.set pfiber k_current_fiber (Any fib);
try try
let res = f () in let res = f () in
resolve_ok_ fib res resolve_ok_ fib res
@ -257,63 +277,54 @@ let spawn_ ~ls ~parent ~runner (f : unit -> 'a) : 'a t =
resolve_as_failed_ fib ebt resolve_as_failed_ fib ebt
in in
Runner.run_async ~ls runner run; Runner.run_async ~fiber:pfiber runner run;
fib fib
let spawn_top ~on f : _ t = let spawn_top ~on f : _ t = spawn_ ~runner:on ~parent:None f
let ls = Task_local_storage.Direct.create () in
spawn_ ~ls ~runner:on ~parent:None f
let spawn ?on ?(protect = true) f : _ t = let spawn ?on ?(protect = true) f : _ t =
(* spawn [f()] with a copy of our local storage *) (* spawn [f()] with a copy of our local storage *)
let (Any p) = let (Any p) =
match get_cur () with try get_cur_exn ()
| None -> failwith "Fiber.spawn: must be run from within another fiber." with Not_set ->
| Some p -> p failwith "Fiber.spawn: must be run from within another fiber."
in in
let ls = Task_local_storage.Direct.copy p.ls in
let runner = let runner =
match on with match on with
| Some r -> r | Some r -> r
| None -> p.runner | None -> p.runner
in in
let child = spawn_ ~ls ~parent:(Some p) ~runner f in let child = spawn_ ~parent:(Some p) ~runner f in
add_child_ ~protect p child; add_child_ ~protect p child;
child child
let[@inline] spawn_ignore ?protect f : unit = ignore (spawn ?protect f : _ t) let[@inline] spawn_ignore ?protect f : unit = ignore (spawn ?protect f : _ t)
let[@inline] self () : any = let[@inline] self () : any =
match Task_local_storage.get k_current_fiber with match get_cur_exn () with
| None -> failwith "Fiber.self: must be run from inside a fiber." | exception Not_set -> failwith "Fiber.self: must be run from inside a fiber."
| Some f -> f | f -> f
let with_on_self_cancel cb (k : unit -> 'a) : 'a = let with_on_self_cancel cb (k : unit -> 'a) : 'a =
let (Any self) = self () in let (Any self) = self () in
let h = add_on_cancel self cb in let h = add_on_cancel self cb in
Fun.protect k ~finally:(fun () -> remove_on_cancel self h) Fun.protect k ~finally:(fun () -> remove_on_cancel self h)
module Suspend_ = Moonpool.Private.Suspend_ let[@inline] check_if_cancelled_ (self : _ t) = PF.check self.pfiber
let check_if_cancelled_ (self : _ t) =
match A.get self.state with
| Terminating_or_done r ->
(match A.get r with
| Error ebt -> Exn_bt.raise ebt
| _ -> ())
| _ -> ()
let check_if_cancelled () = let check_if_cancelled () =
match Task_local_storage.get k_current_fiber with match get_cur_exn () with
| None -> | exception Not_set ->
failwith "Fiber.check_if_cancelled: must be run from inside a fiber." failwith "Fiber.check_if_cancelled: must be run from inside a fiber."
| Some (Any self) -> check_if_cancelled_ self | Any self -> check_if_cancelled_ self
let yield () : unit = let yield () : unit =
match Task_local_storage.get k_current_fiber with match get_cur_exn () with
| None -> failwith "Fiber.yield: must be run from inside a fiber." | exception Not_set ->
| Some (Any self) -> failwith "Fiber.yield: must be run from inside a fiber."
| Any self ->
check_if_cancelled_ self; check_if_cancelled_ self;
Suspend_.yield (); PF.yield ();
check_if_cancelled_ self check_if_cancelled_ self

View file

@ -17,20 +17,27 @@ type cancel_callback = Exn_bt.t -> unit
(** Do not rely on this, it is internal implementation details. *) (** Do not rely on this, it is internal implementation details. *)
module Private_ : sig module Private_ : sig
type 'a state type 'a state
type pfiber
type 'a t = private { type 'a t = private {
id: Handle.t; (** unique identifier for this fiber *) id: Handle.t; (** unique identifier for this fiber *)
state: 'a state Atomic.t; (** Current state in the lifetime of the fiber *) state: 'a state Atomic.t; (** Current state in the lifetime of the fiber *)
res: 'a Fut.t; res: 'a Fut.t;
runner: Runner.t; runner: Runner.t;
ls: Task_local_storage.t; pfiber: pfiber;
} }
(** Type definition, exposed so that {!any} can be unboxed. (** Type definition, exposed so that {!any} can be unboxed.
Please do not rely on that. *) Please do not rely on that. *)
type any = Any : _ t -> any [@@unboxed] type any = Any : _ t -> any [@@unboxed]
val get_cur : unit -> any option exception Not_set
val get_cur_exn : unit -> any
(** [get_cur_exn ()] either returns the current fiber, or
@raise Not_set if run outside a fiber. *)
val get_cur_opt : unit -> any option
end end
(**/**) (**/**)

View file

@ -1,14 +1,20 @@
exception Oh_no of Exn_bt.t exception Oh_no of Exn_bt.t
let main (f : Runner.t -> 'a) : 'a = let main (f : Runner.t -> 'a) : 'a =
let st = Fifo_pool.Private_.create_state ~threads:[| Thread.self () |] () in let worker_st =
let runner = Fifo_pool.Private_.runner_of_state st in Fifo_pool.Private_.create_single_threaded_state ~thread:(Thread.self ())
~on_exn:(fun e bt -> raise (Oh_no (Exn_bt.make e bt)))
()
in
let runner = Fifo_pool.Private_.runner_of_state worker_st in
try try
let fiber = Fiber.spawn_top ~on:runner (fun () -> f runner) in let fiber = Fiber.spawn_top ~on:runner (fun () -> f runner) in
Fiber.on_result fiber (fun _ -> Runner.shutdown_without_waiting runner); Fiber.on_result fiber (fun _ -> Runner.shutdown_without_waiting runner);
(* run the main thread *) (* run the main thread *)
Fifo_pool.Private_.run_thread st runner ~on_exn:(fun e bt -> Moonpool.Private.Worker_loop_.worker_loop worker_st
raise (Oh_no (Exn_bt.make e bt))); ~ops:Fifo_pool.Private_.worker_ops;
match Fiber.peek fiber with match Fiber.peek fiber with
| Some (Ok x) -> x | Some (Ok x) -> x
| Some (Error ebt) -> Exn_bt.raise ebt | Some (Error ebt) -> Exn_bt.raise ebt

View file

@ -6,4 +6,4 @@
(optional) (optional)
(enabled_if (enabled_if
(>= %{ocaml_version} 5.0)) (>= %{ocaml_version} 5.0))
(libraries moonpool moonpool.private)) (libraries moonpool moonpool.private picos))

View file

@ -1,5 +1,4 @@
module A = Moonpool.Atomic module A = Moonpool.Atomic
module Suspend_ = Moonpool.Private.Suspend_
module Domain_ = Moonpool_private.Domain_ module Domain_ = Moonpool_private.Domain_
module State_ = struct module State_ = struct
@ -9,7 +8,7 @@ module State_ = struct
type ('a, 'b) t = type ('a, 'b) t =
| Init | Init
| Left_solved of 'a or_error | Left_solved of 'a or_error
| Right_solved of 'b or_error * Suspend_.suspension | Right_solved of 'b or_error * Trigger.t
| Both_solved of 'a or_error * 'b or_error | Both_solved of 'a or_error * 'b or_error
let get_exn_ (self : _ t A.t) = let get_exn_ (self : _ t A.t) =
@ -28,13 +27,13 @@ module State_ = struct
Domain_.relax (); Domain_.relax ();
set_left_ self left set_left_ self left
) )
| Right_solved (right, cont) -> | Right_solved (right, tr) ->
let new_st = Both_solved (left, right) in let new_st = Both_solved (left, right) in
if not (A.compare_and_set self old_st new_st) then ( if not (A.compare_and_set self old_st new_st) then (
Domain_.relax (); Domain_.relax ();
set_left_ self left set_left_ self left
) else ) else
cont (Ok ()) Trigger.signal tr
| Left_solved _ | Both_solved _ -> assert false | Left_solved _ | Both_solved _ -> assert false
let rec set_right_ (self : _ t A.t) (right : _ or_error) : unit = let rec set_right_ (self : _ t A.t) (right : _ or_error) : unit =
@ -45,27 +44,27 @@ module State_ = struct
if not (A.compare_and_set self old_st new_st) then set_right_ self right if not (A.compare_and_set self old_st new_st) then set_right_ self right
| Init -> | Init ->
(* we are first arrived, we suspend until the left computation is done *) (* we are first arrived, we suspend until the left computation is done *)
Suspend_.suspend let trigger = Trigger.create () in
{ let must_await = ref true in
Suspend_.handle =
(fun ~run:_ ~resume suspension ->
while while
let old_st = A.get self in let old_st = A.get self in
match old_st with match old_st with
| Init -> | Init ->
not (* setup trigger so that left computation will wake us up *)
(A.compare_and_set self old_st not (A.compare_and_set self old_st (Right_solved (right, trigger)))
(Right_solved (right, suspension)))
| Left_solved left -> | Left_solved left ->
(* other thread is done, no risk of race condition *) (* other thread is done, no risk of race condition *)
A.set self (Both_solved (left, right)); A.set self (Both_solved (left, right));
resume suspension (Ok ()); must_await := false;
false false
| Right_solved _ | Both_solved _ -> assert false | Right_solved _ | Both_solved _ -> assert false
do do
() ()
done); done;
}
(* wait for the other computation to be done *)
if !must_await then Trigger.await trigger |> Option.iter Exn_bt.raise
| Right_solved _ | Both_solved _ -> assert false | Right_solved _ | Both_solved _ -> assert false
end end
@ -102,7 +101,12 @@ let both_ignore f g = ignore (both f g : _ * _)
let for_ ?chunk_size n (f : int -> int -> unit) : unit = let for_ ?chunk_size n (f : int -> int -> unit) : unit =
if n > 0 then ( if n > 0 then (
let has_failed = A.make false in let runner =
match Runner.get_current_runner () with
| None -> failwith "forkjoin.for_: must be run inside a moonpool runner."
| Some r -> r
in
let failure = A.make None in
let missing = A.make n in let missing = A.make n in
let chunk_size = let chunk_size =
@ -113,19 +117,20 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit =
max 1 (1 + (n / Moonpool.Private.num_domains ())) max 1 (1 + (n / Moonpool.Private.num_domains ()))
in in
let start_tasks ~run ~resume (suspension : Suspend_.suspension) = let trigger = Trigger.create () in
let task_for ~offset ~len_range = let task_for ~offset ~len_range =
match f offset (offset + len_range - 1) with match f offset (offset + len_range - 1) with
| () -> | () ->
if A.fetch_and_add missing (-len_range) = len_range then if A.fetch_and_add missing (-len_range) = len_range then
(* all tasks done successfully *) (* all tasks done successfully *)
resume suspension (Ok ()) Trigger.signal trigger
| exception exn -> | exception exn ->
let bt = Printexc.get_raw_backtrace () in let bt = Printexc.get_raw_backtrace () in
if not (A.exchange has_failed true) then if Option.is_none (A.exchange failure (Some { Exn_bt.exn; bt })) then
(* first one to fail, and [missing] must be >= 2 (* first one to fail, and [missing] must be >= 2
because we're not decreasing it. *) because we're not decreasing it. *)
resume suspension (Error { Exn_bt.exn; bt }) Trigger.signal trigger
in in
let i = ref 0 in let i = ref 0 in
@ -135,18 +140,13 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit =
let len_range = min chunk_size (n - offset) in let len_range = min chunk_size (n - offset) in
assert (offset + len_range <= n); assert (offset + len_range <= n);
run (fun () -> task_for ~offset ~len_range); Runner.run_async runner (fun () -> task_for ~offset ~len_range);
i := !i + len_range i := !i + len_range
done done;
in
Suspend_.suspend Trigger.await trigger |> Option.iter Exn_bt.raise;
{ Option.iter Exn_bt.raise @@ A.get failure;
Suspend_.handle = ()
(fun ~run ~resume suspension ->
(* run tasks, then we'll resume [suspension] *)
start_tasks ~run ~resume suspension);
}
) )
let all_array ?chunk_size (fs : _ array) : _ array = let all_array ?chunk_size (fs : _ array) : _ array =

View file

@ -1,17 +1,14 @@
open Base open Base
let await_readable fd : unit = let await_readable fd : unit =
Moonpool.Private.Suspend_.suspend let trigger = Trigger.create () in
{
handle =
(fun ~run:_ ~resume sus ->
Perform_action_in_lwt.schedule Perform_action_in_lwt.schedule
@@ Action.Wait_readable @@ Action.Wait_readable
( fd, ( fd,
fun cancel -> fun cancel ->
resume sus @@ Ok (); Trigger.signal trigger;
Lwt_engine.stop_event cancel )); Lwt_engine.stop_event cancel );
} Trigger.await_exn trigger
let rec read fd buf i len : int = let rec read fd buf i len : int =
if len = 0 then if len = 0 then
@ -25,17 +22,14 @@ let rec read fd buf i len : int =
) )
let await_writable fd = let await_writable fd =
Moonpool.Private.Suspend_.suspend let trigger = Trigger.create () in
{
handle =
(fun ~run:_ ~resume sus ->
Perform_action_in_lwt.schedule Perform_action_in_lwt.schedule
@@ Action.Wait_writable @@ Action.Wait_writable
( fd, ( fd,
fun cancel -> fun cancel ->
resume sus @@ Ok (); Trigger.signal trigger;
Lwt_engine.stop_event cancel )); Lwt_engine.stop_event cancel );
} Trigger.await_exn trigger
let rec write_once fd buf i len : int = let rec write_once fd buf i len : int =
if len = 0 then if len = 0 then
@ -59,16 +53,14 @@ let write fd buf i len : unit =
(** Sleep for the given amount of seconds *) (** Sleep for the given amount of seconds *)
let sleep_s (f : float) : unit = let sleep_s (f : float) : unit =
if f > 0. then if f > 0. then (
Moonpool.Private.Suspend_.suspend let trigger = Trigger.create () in
{
handle =
(fun ~run:_ ~resume sus ->
Perform_action_in_lwt.schedule Perform_action_in_lwt.schedule
@@ Action.Sleep @@ Action.Sleep
( f, ( f,
false, false,
fun cancel -> fun cancel ->
resume sus @@ Ok (); Trigger.signal trigger;
Lwt_engine.stop_event cancel )); Lwt_engine.stop_event cancel );
} Trigger.await_exn trigger
)

View file

@ -1,4 +1,5 @@
open Common_ open Common_
module Trigger = M.Trigger
module Fiber = Moonpool_fib.Fiber module Fiber = Moonpool_fib.Fiber
module FLS = Moonpool_fib.Fls module FLS = Moonpool_fib.Fls
@ -14,7 +15,7 @@ module Action = struct
| Sleep of float * bool * cb | Sleep of float * bool * cb
(* TODO: provide actions with cancellation, alongside a "select" operation *) (* TODO: provide actions with cancellation, alongside a "select" operation *)
(* | Cancel of event *) (* | Cancel of event *)
| On_termination : 'a Lwt.t * ('a Exn_bt.result -> unit) -> t | On_termination : 'a Lwt.t * 'a Exn_bt.result ref * Trigger.t -> t
| Wakeup : 'a Lwt.u * 'a -> t | Wakeup : 'a Lwt.u * 'a -> t
| Wakeup_exn : _ Lwt.u * exn -> t | Wakeup_exn : _ Lwt.u * exn -> t
| Other of (unit -> unit) | Other of (unit -> unit)
@ -26,10 +27,14 @@ module Action = struct
| Wait_writable (fd, cb) -> ignore (Lwt_engine.on_writable fd cb : event) | Wait_writable (fd, cb) -> ignore (Lwt_engine.on_writable fd cb : event)
| Sleep (f, repeat, cb) -> ignore (Lwt_engine.on_timer f repeat cb : event) | Sleep (f, repeat, cb) -> ignore (Lwt_engine.on_timer f repeat cb : event)
(* | Cancel ev -> Lwt_engine.stop_event ev *) (* | Cancel ev -> Lwt_engine.stop_event ev *)
| On_termination (fut, f) -> | On_termination (fut, res, trigger) ->
Lwt.on_any fut Lwt.on_any fut
(fun x -> f @@ Ok x) (fun x ->
(fun exn -> f @@ Error (Exn_bt.get_callstack 10 exn)) res := Ok x;
Trigger.signal trigger)
(fun exn ->
res := Error (Exn_bt.get_callstack 10 exn);
Trigger.signal trigger)
| Wakeup (prom, x) -> Lwt.wakeup prom x | Wakeup (prom, x) -> Lwt.wakeup prom x
| Wakeup_exn (prom, e) -> Lwt.wakeup_exn prom e | Wakeup_exn (prom, e) -> Lwt.wakeup_exn prom e
| Other f -> f () | Other f -> f ()
@ -106,23 +111,19 @@ let fut_of_lwt (lwt_fut : _ Lwt.t) : _ M.Fut.t =
M.Fut.fulfill prom (Error { Exn_bt.exn; bt })); M.Fut.fulfill prom (Error { Exn_bt.exn; bt }));
fut fut
let _dummy_exn_bt : Exn_bt.t =
Exn_bt.get_callstack 0 (Failure "dummy Exn_bt from moonpool-lwt")
let await_lwt (fut : _ Lwt.t) = let await_lwt (fut : _ Lwt.t) =
match Lwt.poll fut with match Lwt.poll fut with
| Some x -> x | Some x -> x
| None -> | None ->
(* suspend fiber, wake it up when [fut] resolves *) (* suspend fiber, wake it up when [fut] resolves *)
M.Private.Suspend_.suspend let trigger = M.Trigger.create () in
{ let res = ref (Error _dummy_exn_bt) in
handle = Perform_action_in_lwt.(schedule Action.(On_termination (fut, res, trigger)));
(fun ~run:_ ~resume sus -> Trigger.await trigger |> Option.iter Exn_bt.raise;
let on_lwt_done _ = resume sus @@ Ok () in Exn_bt.unwrap !res
Perform_action_in_lwt.(
schedule Action.(On_termination (fut, on_lwt_done))));
};
(match Lwt.poll fut with
| Some x -> x
| None -> assert false)
let run_in_lwt f : _ M.Fut.t = let run_in_lwt f : _ M.Fut.t =
let fut, prom = M.Fut.make () in let fut, prom = M.Fut.make () in

View file

@ -4,4 +4,9 @@
(private_modules common_) (private_modules common_)
(enabled_if (enabled_if
(>= %{ocaml_version} 5.0)) (>= %{ocaml_version} 5.0))
(libraries moonpool moonpool.fib lwt lwt.unix)) (libraries
(re_export moonpool)
(re_export moonpool.fib)
picos
(re_export lwt)
lwt.unix))