mirror of
https://github.com/c-cube/moonpool.git
synced 2025-12-06 03:05:30 -05:00
wip: port to picos
This commit is contained in:
parent
a0068b09b3
commit
07a7fc3a1c
17 changed files with 314 additions and 316 deletions
|
|
@ -7,3 +7,7 @@ let show self = Printexc.to_string (exn self)
|
|||
let pp out self = Format.pp_print_string out (show self)
|
||||
|
||||
type nonrec 'a result = ('a, t) result
|
||||
|
||||
let[@inline] unwrap = function
|
||||
| Ok x -> x
|
||||
| Error ebt -> raise ebt
|
||||
|
|
|
|||
|
|
@ -21,3 +21,7 @@ val show : t -> string
|
|||
val pp : Format.formatter -> t -> unit
|
||||
|
||||
type nonrec 'a result = ('a, t) result
|
||||
|
||||
val unwrap : 'a result -> 'a
|
||||
(** [unwrap (Ok x)] is [x], [unwrap (Error ebt)] re-raises [ebt].
|
||||
@since NEXT_RELEASE *)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ type state = {
|
|||
threads: Thread.t array;
|
||||
q: task_full Bb_queue.t; (** Queue for tasks. *)
|
||||
around_task: WL.around_task;
|
||||
as_runner: t lazy_t;
|
||||
mutable as_runner: t;
|
||||
(* init options *)
|
||||
name: string option;
|
||||
on_init_thread: dom_id:int -> t_id:int -> unit -> unit;
|
||||
|
|
@ -24,7 +24,6 @@ type worker_state = {
|
|||
idx: int;
|
||||
dom_idx: int;
|
||||
st: state;
|
||||
mutable current: fiber;
|
||||
}
|
||||
|
||||
let[@inline] size_ (self : state) = Array.length self.threads
|
||||
|
|
@ -95,7 +94,7 @@ let cleanup (self : worker_state) : unit =
|
|||
self.st.on_exit_thread ~dom_id:self.dom_idx ~t_id ()
|
||||
|
||||
let worker_ops : worker_state WL.ops =
|
||||
let runner (st : worker_state) = Lazy.force st.st.as_runner in
|
||||
let runner (st : worker_state) = st.st.as_runner in
|
||||
let around_task st = st.st.around_task in
|
||||
let on_exn (st : worker_state) (ebt : Exn_bt.t) =
|
||||
st.st.on_exn ebt.exn ebt.bt
|
||||
|
|
@ -111,9 +110,9 @@ let worker_ops : worker_state WL.ops =
|
|||
cleanup;
|
||||
}
|
||||
|
||||
let create ?(on_init_thread = default_thread_init_exit_)
|
||||
let create_ ?(on_init_thread = default_thread_init_exit_)
|
||||
?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ())
|
||||
?around_task ?num_threads ?name () : t =
|
||||
?around_task ~threads ?name () : state =
|
||||
(* wrapper *)
|
||||
let around_task =
|
||||
match around_task with
|
||||
|
|
@ -121,6 +120,23 @@ let create ?(on_init_thread = default_thread_init_exit_)
|
|||
| None -> default_around_task_
|
||||
in
|
||||
|
||||
let self =
|
||||
{
|
||||
threads;
|
||||
q = Bb_queue.create ();
|
||||
around_task;
|
||||
as_runner = Runner.dummy;
|
||||
name;
|
||||
on_init_thread;
|
||||
on_exit_thread;
|
||||
on_exn;
|
||||
}
|
||||
in
|
||||
self.as_runner <- runner_of_state self;
|
||||
self
|
||||
|
||||
let create ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads
|
||||
?name () : t =
|
||||
let num_domains = Domain_pool_.max_number_of_domains () in
|
||||
|
||||
(* number of threads to run *)
|
||||
|
|
@ -129,20 +145,12 @@ let create ?(on_init_thread = default_thread_init_exit_)
|
|||
(* make sure we don't bias towards the first domain(s) in {!D_pool_} *)
|
||||
let offset = Random.int num_domains in
|
||||
|
||||
let rec pool =
|
||||
let pool =
|
||||
let dummy_thread = Thread.self () in
|
||||
{
|
||||
threads = Array.make num_threads dummy_thread;
|
||||
q = Bb_queue.create ();
|
||||
around_task;
|
||||
as_runner = lazy (runner_of_state pool);
|
||||
name;
|
||||
on_init_thread;
|
||||
on_exit_thread;
|
||||
on_exn;
|
||||
}
|
||||
let threads = Array.make num_threads dummy_thread in
|
||||
create_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ~threads ?name
|
||||
()
|
||||
in
|
||||
|
||||
let runner = runner_of_state pool in
|
||||
|
||||
(* temporary queue used to obtain thread handles from domains
|
||||
|
|
@ -156,7 +164,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
|
|||
(* function called in domain with index [i], to
|
||||
create the thread and push it into [receive_threads] *)
|
||||
let create_thread_in_domain () =
|
||||
let st = { idx = i; dom_idx; st = pool; current = _dummy_fiber } in
|
||||
let st = { idx = i; dom_idx; st = pool } in
|
||||
let thread = Thread.create (WL.worker_loop ~ops:worker_ops) st in
|
||||
(* send the thread from the domain back to us *)
|
||||
Bb_queue.push receive_threads (i, thread)
|
||||
|
|
@ -187,3 +195,14 @@ let with_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads
|
|||
in
|
||||
let@ () = Fun.protect ~finally:(fun () -> shutdown pool) in
|
||||
f pool
|
||||
|
||||
module Private_ = struct
|
||||
type nonrec worker_state = worker_state
|
||||
|
||||
let worker_ops = worker_ops
|
||||
let runner_of_state (self : worker_state) = worker_ops.runner self
|
||||
|
||||
let create_single_threaded_state ~thread ?on_exn () : worker_state =
|
||||
let st : state = create_ ?on_exn ~threads:[| thread |] () in
|
||||
{ idx = 0; dom_idx = 0; st }
|
||||
end
|
||||
|
|
|
|||
|
|
@ -44,3 +44,21 @@ val with_ : (unit -> (t -> 'a) -> 'a, _) create_args
|
|||
When [f pool] returns or fails, [pool] is shutdown and its resources
|
||||
are released.
|
||||
Most parameters are the same as in {!create}. *)
|
||||
|
||||
(**/**)
|
||||
|
||||
module Private_ : sig
|
||||
type worker_state
|
||||
|
||||
val worker_ops : worker_state Worker_loop_.ops
|
||||
|
||||
val create_single_threaded_state :
|
||||
thread:Thread.t ->
|
||||
?on_exn:(exn -> Printexc.raw_backtrace -> unit) ->
|
||||
unit ->
|
||||
worker_state
|
||||
|
||||
val runner_of_state : worker_state -> Runner.t
|
||||
end
|
||||
|
||||
(**/**)
|
||||
|
|
|
|||
|
|
@ -1,83 +1,41 @@
|
|||
open Types_
|
||||
(*
|
||||
module A = Atomic_
|
||||
module PF = Picos.Fiber
|
||||
|
||||
type 'a key = 'a ls_key
|
||||
type 'a t = 'a PF.FLS.t
|
||||
|
||||
let key_count_ = A.make 0
|
||||
exception Not_set = PF.FLS.Not_set
|
||||
|
||||
type t = local_storage
|
||||
type ls_value += Dummy
|
||||
let create = PF.FLS.create
|
||||
|
||||
let dummy : t = _dummy_ls
|
||||
let[@inline] get_exn k =
|
||||
let fiber = get_current_fiber_exn () in
|
||||
PF.FLS.get_exn fiber k
|
||||
|
||||
(** Resize array of TLS values *)
|
||||
let[@inline never] resize_ (cur : ls_value array ref) n =
|
||||
if n > Sys.max_array_length then failwith "too many task local storage keys";
|
||||
let len = Array.length !cur in
|
||||
let new_ls =
|
||||
Array.make (min Sys.max_array_length (max n ((len * 2) + 2))) Dummy
|
||||
in
|
||||
Array.blit !cur 0 new_ls 0 len;
|
||||
cur := new_ls
|
||||
|
||||
module Direct = struct
|
||||
type nonrec t = t
|
||||
|
||||
let create = create_local_storage
|
||||
let[@inline] copy (self : t) = ref (Array.copy !self)
|
||||
|
||||
let get (type a) (self : t) ((module K) : a key) : a =
|
||||
if K.offset >= Array.length !self then resize_ self (K.offset + 1);
|
||||
match !self.(K.offset) with
|
||||
| K.V x -> (* common case first *) x
|
||||
| Dummy ->
|
||||
(* first time we access this *)
|
||||
let v = K.init () in
|
||||
!self.(K.offset) <- K.V v;
|
||||
v
|
||||
| _ -> assert false
|
||||
|
||||
let set (type a) (self : t) ((module K) : a key) (v : a) : unit =
|
||||
assert (self != dummy);
|
||||
if K.offset >= Array.length !self then resize_ self (K.offset + 1);
|
||||
!self.(K.offset) <- K.V v;
|
||||
()
|
||||
end
|
||||
|
||||
let new_key (type t) ~init () : t key =
|
||||
let offset = A.fetch_and_add key_count_ 1 in
|
||||
(module struct
|
||||
type nonrec t = t
|
||||
type ls_value += V of t
|
||||
|
||||
let offset = offset
|
||||
let init = init
|
||||
end : LS_KEY
|
||||
with type t = t)
|
||||
|
||||
let[@inline] get_cur_ () : ls_value array ref =
|
||||
match get_current_storage () with
|
||||
| Some r when r != dummy -> r
|
||||
| _ -> failwith "Task local storage must be accessed from within a runner."
|
||||
|
||||
let[@inline] get (key : 'a key) : 'a =
|
||||
let cur = get_cur_ () in
|
||||
Direct.get cur key
|
||||
|
||||
let[@inline] get_opt key =
|
||||
match get_current_storage () with
|
||||
let get_opt k =
|
||||
match get_current_fiber () with
|
||||
| None -> None
|
||||
| Some cur -> Some (Direct.get cur key)
|
||||
| Some fiber ->
|
||||
(match PF.FLS.get_exn fiber k with
|
||||
| x -> Some x
|
||||
| exception Not_set -> None)
|
||||
|
||||
let[@inline] set key v : unit =
|
||||
let cur = get_cur_ () in
|
||||
Direct.set cur key v
|
||||
let[@inline] get k ~default =
|
||||
let fiber = get_current_fiber_exn () in
|
||||
PF.FLS.get fiber ~default k
|
||||
|
||||
let with_value key x f =
|
||||
let old = get key in
|
||||
set key x;
|
||||
Fun.protect ~finally:(fun () -> set key old) f
|
||||
let[@inline] set k v : unit =
|
||||
let fiber = get_current_fiber_exn () in
|
||||
PF.FLS.set fiber k v
|
||||
|
||||
let get_current = get_current_storage
|
||||
*)
|
||||
let with_value k v (f : _ -> 'b) : 'b =
|
||||
let fiber = get_current_fiber_exn () in
|
||||
|
||||
match PF.FLS.get_exn fiber k with
|
||||
| exception Not_set ->
|
||||
PF.FLS.set fiber k v;
|
||||
(* nothing to restore back to, just call [f] *)
|
||||
f ()
|
||||
| old_v ->
|
||||
PF.FLS.set fiber k v;
|
||||
let finally () = PF.FLS.set fiber k old_v in
|
||||
Fun.protect f ~finally
|
||||
|
|
|
|||
|
|
@ -8,62 +8,31 @@
|
|||
@since 0.6
|
||||
*)
|
||||
|
||||
(*
|
||||
type t = Types_.local_storage
|
||||
(** Underlying storage for a task. This is mutable and
|
||||
not thread-safe. *)
|
||||
type 'a t = 'a Picos.Fiber.FLS.t
|
||||
|
||||
val dummy : t
|
||||
val create : unit -> 'a t
|
||||
(** [create ()] makes a new key. Keys are expensive and
|
||||
should never be allocated dynamically or in a loop. *)
|
||||
|
||||
type 'a key
|
||||
(** A key used to access a particular (typed) storage slot on every task. *)
|
||||
exception Not_set
|
||||
|
||||
val new_key : init:(unit -> 'a) -> unit -> 'a key
|
||||
(** [new_key ~init ()] makes a new key. Keys are expensive and
|
||||
should never be allocated dynamically or in a loop.
|
||||
The correct pattern is, at toplevel:
|
||||
|
||||
{[
|
||||
let k_foo : foo Task_ocal_storage.key =
|
||||
Task_local_storage.new_key ~init:(fun () -> make_foo ()) ()
|
||||
|
||||
(* … *)
|
||||
|
||||
(* use it: *)
|
||||
let … = Task_local_storage.get k_foo
|
||||
]}
|
||||
*)
|
||||
|
||||
val get : 'a key -> 'a
|
||||
val get_exn : 'a t -> 'a
|
||||
(** [get k] gets the value for the current task for key [k].
|
||||
Must be run from inside a task running on a runner.
|
||||
@raise Failure otherwise *)
|
||||
@raise Not_set otherwise *)
|
||||
|
||||
val get_opt : 'a key -> 'a option
|
||||
val get_opt : 'a t -> 'a option
|
||||
(** [get_opt k] gets the current task's value for key [k],
|
||||
or [None] if not run from inside the task. *)
|
||||
|
||||
val set : 'a key -> 'a -> unit
|
||||
val get : 'a t -> default:'a -> 'a
|
||||
|
||||
val set : 'a t -> 'a -> unit
|
||||
(** [set k v] sets the storage for [k] to [v].
|
||||
Must be run from inside a task running on a runner.
|
||||
@raise Failure otherwise *)
|
||||
|
||||
val with_value : 'a key -> 'a -> (unit -> 'b) -> 'b
|
||||
val with_value : 'a t -> 'a -> (unit -> 'b) -> 'b
|
||||
(** [with_value k v f] sets [k] to [v] for the duration of the call
|
||||
to [f()]. When [f()] returns (or fails), [k] is restored
|
||||
to its old value. *)
|
||||
|
||||
val get_current : unit -> t option
|
||||
(** Access the current storage, or [None] if not run from
|
||||
within a task. *)
|
||||
|
||||
(** Direct access to values from a storage handle *)
|
||||
module Direct : sig
|
||||
val get : t -> 'a key -> 'a
|
||||
(** Access a key *)
|
||||
|
||||
val set : t -> 'a key -> 'a -> unit
|
||||
val create : unit -> t
|
||||
val copy : t -> t
|
||||
end
|
||||
*)
|
||||
|
|
|
|||
|
|
@ -2,3 +2,5 @@
|
|||
@since NEXT_RELEASE *)
|
||||
|
||||
include Picos.Trigger
|
||||
|
||||
let[@inline] await_exn (self : t) = await self |> Option.iter Exn_bt.raise
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ type state = {
|
|||
mutable n_waiting_nonzero: bool; (** [n_waiting > 0] *)
|
||||
mutex: Mutex.t;
|
||||
cond: Condition.t;
|
||||
as_runner: t lazy_t;
|
||||
mutable as_runner: t;
|
||||
(* init options *)
|
||||
around_task: WL.around_task;
|
||||
name: string option;
|
||||
|
|
@ -167,7 +167,9 @@ and wait_on_worker (self : worker_state) : WL.task_full =
|
|||
| task ->
|
||||
Mutex.unlock self.st.mutex;
|
||||
task
|
||||
| exception Queue.Empty -> try_steal_from_other_workers_ self
|
||||
| exception Queue.Empty ->
|
||||
Mutex.unlock self.st.mutex;
|
||||
try_steal_from_other_workers_ self
|
||||
) else (
|
||||
(* do nothing more: no task in main queue, and we are shutting
|
||||
down so no new task should arrive.
|
||||
|
|
@ -183,8 +185,7 @@ let before_start (self : worker_state) : unit =
|
|||
let t_id = Thread.id @@ Thread.self () in
|
||||
self.st.on_init_thread ~dom_id:self.dom_id ~t_id ();
|
||||
TLS.set k_cur_fiber _dummy_fiber;
|
||||
TLS.set Runner.For_runner_implementors.k_cur_runner
|
||||
(Lazy.force self.st.as_runner);
|
||||
TLS.set Runner.For_runner_implementors.k_cur_runner self.st.as_runner;
|
||||
TLS.set k_worker_state self;
|
||||
|
||||
(* set thread name *)
|
||||
|
|
@ -200,7 +201,7 @@ let cleanup (self : worker_state) : unit =
|
|||
self.st.on_exit_thread ~dom_id:self.dom_id ~t_id ()
|
||||
|
||||
let worker_ops : worker_state WL.ops =
|
||||
let runner (st : worker_state) = Lazy.force st.st.as_runner in
|
||||
let runner (st : worker_state) = st.st.as_runner in
|
||||
let around_task st = st.st.around_task in
|
||||
let on_exn (st : worker_state) (ebt : Exn_bt.t) =
|
||||
st.st.on_exn ebt.exn ebt.bt
|
||||
|
|
@ -261,7 +262,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
|
|||
(* make sure we don't bias towards the first domain(s) in {!D_pool_} *)
|
||||
let offset = Random.int num_domains in
|
||||
|
||||
let rec pool =
|
||||
let pool =
|
||||
{
|
||||
id_ = pool_id_;
|
||||
active = A.make true;
|
||||
|
|
@ -276,28 +277,32 @@ let create ?(on_init_thread = default_thread_init_exit_)
|
|||
on_init_thread;
|
||||
on_exit_thread;
|
||||
name;
|
||||
as_runner = lazy (as_runner_ pool);
|
||||
as_runner = Runner.dummy;
|
||||
}
|
||||
in
|
||||
pool.as_runner <- as_runner_ pool;
|
||||
|
||||
(* temporary queue used to obtain thread handles from domains
|
||||
on which the thread are started. *)
|
||||
let receive_threads = Bb_queue.create () in
|
||||
|
||||
(* start the thread with index [i] *)
|
||||
let start_thread_with_idx idx =
|
||||
let create_worker_state idx =
|
||||
let dom_id = (offset + idx) mod num_domains in
|
||||
let st =
|
||||
{
|
||||
st = pool;
|
||||
thread = (* dummy *) Thread.self ();
|
||||
q = WSQ.create ~dummy:WL._dummy_task ();
|
||||
rng = Random.State.make [| idx |];
|
||||
dom_id;
|
||||
idx;
|
||||
}
|
||||
in
|
||||
{
|
||||
st = pool;
|
||||
thread = (* dummy *) Thread.self ();
|
||||
q = WSQ.create ~dummy:WL._dummy_task ();
|
||||
rng = Random.State.make [| idx |];
|
||||
dom_id;
|
||||
idx;
|
||||
}
|
||||
in
|
||||
|
||||
pool.workers <- Array.init num_threads create_worker_state;
|
||||
|
||||
(* start the thread with index [i] *)
|
||||
let start_thread_with_idx idx (st : worker_state) =
|
||||
(* function called in domain with index [i], to
|
||||
create the thread and push it into [receive_threads] *)
|
||||
let create_thread_in_domain () =
|
||||
|
|
@ -305,15 +310,12 @@ let create ?(on_init_thread = default_thread_init_exit_)
|
|||
(* send the thread from the domain back to us *)
|
||||
Bb_queue.push receive_threads (idx, thread)
|
||||
in
|
||||
|
||||
Domain_pool_.run_on dom_id create_thread_in_domain;
|
||||
|
||||
st
|
||||
Domain_pool_.run_on st.dom_id create_thread_in_domain
|
||||
in
|
||||
|
||||
(* start all worker threads, placing them on the domains
|
||||
according to their index and [offset] in a round-robin fashion. *)
|
||||
pool.workers <- Array.init num_threads start_thread_with_idx;
|
||||
Array.iteri start_thread_with_idx pool.workers;
|
||||
|
||||
(* receive the newly created threads back from domains *)
|
||||
for _j = 1 to num_threads do
|
||||
|
|
@ -322,7 +324,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
|
|||
worker_state.thread <- th
|
||||
done;
|
||||
|
||||
Lazy.force pool.as_runner
|
||||
pool.as_runner
|
||||
|
||||
let with_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads
|
||||
?name () f =
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
(name moonpool_fib)
|
||||
(public_name moonpool.fib)
|
||||
(synopsis "Fibers and structured concurrency for Moonpool")
|
||||
(libraries moonpool)
|
||||
(libraries moonpool picos)
|
||||
(enabled_if
|
||||
(>= %{ocaml_version} 5.0))
|
||||
(flags :standard -open Moonpool_private -open Moonpool)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
open Moonpool.Private.Types_
|
||||
module A = Atomic
|
||||
module FM = Handle.Map
|
||||
module Int_map = Map.Make (Int)
|
||||
module PF = Picos.Fiber
|
||||
module FLS = Picos.Fiber.FLS
|
||||
|
||||
type 'a callback = 'a Exn_bt.result -> unit
|
||||
(** Callbacks that are called when a fiber is done. *)
|
||||
|
|
@ -10,13 +13,16 @@ type cancel_callback = Exn_bt.t -> unit
|
|||
let prom_of_fut : 'a Fut.t -> 'a Fut.promise =
|
||||
Fut.Private_.unsafe_promise_of_fut
|
||||
|
||||
(* TODO: replace with picos structured at some point? *)
|
||||
module Private_ = struct
|
||||
type pfiber = PF.t
|
||||
|
||||
type 'a t = {
|
||||
id: Handle.t; (** unique identifier for this fiber *)
|
||||
state: 'a state A.t; (** Current state in the lifetime of the fiber *)
|
||||
res: 'a Fut.t;
|
||||
runner: Runner.t;
|
||||
ls: Task_local_storage.t;
|
||||
pfiber: pfiber; (** Associated picos fiber *)
|
||||
}
|
||||
|
||||
and 'a state =
|
||||
|
|
@ -30,11 +36,18 @@ module Private_ = struct
|
|||
and children = any FM.t
|
||||
and any = Any : _ t -> any [@@unboxed]
|
||||
|
||||
(** Key to access the current fiber. *)
|
||||
let k_current_fiber : any option Task_local_storage.key =
|
||||
Task_local_storage.new_key ~init:(fun () -> None) ()
|
||||
(** Key to access the current moonpool.fiber. *)
|
||||
let k_current_fiber : any FLS.t = FLS.create ()
|
||||
|
||||
let[@inline] get_cur () : any option = Task_local_storage.get k_current_fiber
|
||||
exception Not_set = FLS.Not_set
|
||||
|
||||
let[@inline] get_cur_from_exn (pfiber : pfiber) : any =
|
||||
FLS.get_exn pfiber k_current_fiber
|
||||
|
||||
let[@inline] get_cur_exn () : any =
|
||||
get_cur_from_exn @@ get_current_fiber_exn ()
|
||||
|
||||
let[@inline] get_cur_opt () = try Some (get_cur_exn ()) with _ -> None
|
||||
|
||||
let[@inline] is_closed (self : _ t) =
|
||||
match A.get self.state with
|
||||
|
|
@ -44,9 +57,9 @@ end
|
|||
|
||||
include Private_
|
||||
|
||||
let create_ ~ls ~runner () : 'a t =
|
||||
let create_ ~pfiber ~runner () : 'a t =
|
||||
let id = Handle.generate_fresh () in
|
||||
let res, _promise = Fut.make () in
|
||||
let res, _ = Fut.make () in
|
||||
{
|
||||
state =
|
||||
A.make
|
||||
|
|
@ -54,7 +67,7 @@ let create_ ~ls ~runner () : 'a t =
|
|||
id;
|
||||
res;
|
||||
runner;
|
||||
ls;
|
||||
pfiber;
|
||||
}
|
||||
|
||||
let create_done_ ~res () : _ t =
|
||||
|
|
@ -66,7 +79,7 @@ let create_done_ ~res () : _ t =
|
|||
id;
|
||||
res;
|
||||
runner = Runner.dummy;
|
||||
ls = Task_local_storage.dummy;
|
||||
pfiber = Moonpool.Private.Types_._dummy_fiber;
|
||||
}
|
||||
|
||||
let[@inline] return x = create_done_ ~res:(Fut.return x) ()
|
||||
|
|
@ -175,7 +188,8 @@ let with_on_cancel (self : _ t) cb (k : unit -> 'a) : 'a =
|
|||
let h = add_on_cancel self cb in
|
||||
Fun.protect k ~finally:(fun () -> remove_on_cancel self h)
|
||||
|
||||
(** Successfully resolve the fiber *)
|
||||
(** Successfully resolve the fiber. This might still fail if
|
||||
some children failed. *)
|
||||
let resolve_ok_ (self : 'a t) (r : 'a) : unit =
|
||||
let r = A.make @@ Ok r in
|
||||
let promise = prom_of_fut self.res in
|
||||
|
|
@ -239,15 +253,21 @@ let add_child_ ~protect (self : _ t) (child : _ t) =
|
|||
()
|
||||
done
|
||||
|
||||
let spawn_ ~ls ~parent ~runner (f : unit -> 'a) : 'a t =
|
||||
let spawn_ ~parent ~runner (f : unit -> 'a) : 'a t =
|
||||
let comp = Picos.Computation.create () in
|
||||
let pfiber = PF.create ~forbid:false comp in
|
||||
|
||||
(* inherit FLS from parent, if present *)
|
||||
Option.iter (fun (p : _ t) -> PF.copy_fls p.pfiber pfiber) parent;
|
||||
|
||||
(match parent with
|
||||
| Some p when is_closed p -> failwith "spawn: nursery is closed"
|
||||
| _ -> ());
|
||||
let fib = create_ ~ls ~runner () in
|
||||
let fib = create_ ~pfiber ~runner () in
|
||||
|
||||
let run () =
|
||||
(* make sure the fiber is accessible from inside itself *)
|
||||
Task_local_storage.set k_current_fiber (Some (Any fib));
|
||||
FLS.set pfiber k_current_fiber (Any fib);
|
||||
try
|
||||
let res = f () in
|
||||
resolve_ok_ fib res
|
||||
|
|
@ -257,63 +277,54 @@ let spawn_ ~ls ~parent ~runner (f : unit -> 'a) : 'a t =
|
|||
resolve_as_failed_ fib ebt
|
||||
in
|
||||
|
||||
Runner.run_async ~ls runner run;
|
||||
Runner.run_async ~fiber:pfiber runner run;
|
||||
|
||||
fib
|
||||
|
||||
let spawn_top ~on f : _ t =
|
||||
let ls = Task_local_storage.Direct.create () in
|
||||
spawn_ ~ls ~runner:on ~parent:None f
|
||||
let spawn_top ~on f : _ t = spawn_ ~runner:on ~parent:None f
|
||||
|
||||
let spawn ?on ?(protect = true) f : _ t =
|
||||
(* spawn [f()] with a copy of our local storage *)
|
||||
let (Any p) =
|
||||
match get_cur () with
|
||||
| None -> failwith "Fiber.spawn: must be run from within another fiber."
|
||||
| Some p -> p
|
||||
try get_cur_exn ()
|
||||
with Not_set ->
|
||||
failwith "Fiber.spawn: must be run from within another fiber."
|
||||
in
|
||||
let ls = Task_local_storage.Direct.copy p.ls in
|
||||
|
||||
let runner =
|
||||
match on with
|
||||
| Some r -> r
|
||||
| None -> p.runner
|
||||
in
|
||||
let child = spawn_ ~ls ~parent:(Some p) ~runner f in
|
||||
let child = spawn_ ~parent:(Some p) ~runner f in
|
||||
add_child_ ~protect p child;
|
||||
child
|
||||
|
||||
let[@inline] spawn_ignore ?protect f : unit = ignore (spawn ?protect f : _ t)
|
||||
|
||||
let[@inline] self () : any =
|
||||
match Task_local_storage.get k_current_fiber with
|
||||
| None -> failwith "Fiber.self: must be run from inside a fiber."
|
||||
| Some f -> f
|
||||
match get_cur_exn () with
|
||||
| exception Not_set -> failwith "Fiber.self: must be run from inside a fiber."
|
||||
| f -> f
|
||||
|
||||
let with_on_self_cancel cb (k : unit -> 'a) : 'a =
|
||||
let (Any self) = self () in
|
||||
let h = add_on_cancel self cb in
|
||||
Fun.protect k ~finally:(fun () -> remove_on_cancel self h)
|
||||
|
||||
module Suspend_ = Moonpool.Private.Suspend_
|
||||
|
||||
let check_if_cancelled_ (self : _ t) =
|
||||
match A.get self.state with
|
||||
| Terminating_or_done r ->
|
||||
(match A.get r with
|
||||
| Error ebt -> Exn_bt.raise ebt
|
||||
| _ -> ())
|
||||
| _ -> ()
|
||||
let[@inline] check_if_cancelled_ (self : _ t) = PF.check self.pfiber
|
||||
|
||||
let check_if_cancelled () =
|
||||
match Task_local_storage.get k_current_fiber with
|
||||
| None ->
|
||||
match get_cur_exn () with
|
||||
| exception Not_set ->
|
||||
failwith "Fiber.check_if_cancelled: must be run from inside a fiber."
|
||||
| Some (Any self) -> check_if_cancelled_ self
|
||||
| Any self -> check_if_cancelled_ self
|
||||
|
||||
let yield () : unit =
|
||||
match Task_local_storage.get k_current_fiber with
|
||||
| None -> failwith "Fiber.yield: must be run from inside a fiber."
|
||||
| Some (Any self) ->
|
||||
match get_cur_exn () with
|
||||
| exception Not_set ->
|
||||
failwith "Fiber.yield: must be run from inside a fiber."
|
||||
| Any self ->
|
||||
check_if_cancelled_ self;
|
||||
Suspend_.yield ();
|
||||
PF.yield ();
|
||||
check_if_cancelled_ self
|
||||
|
|
|
|||
|
|
@ -17,20 +17,27 @@ type cancel_callback = Exn_bt.t -> unit
|
|||
(** Do not rely on this, it is internal implementation details. *)
|
||||
module Private_ : sig
|
||||
type 'a state
|
||||
type pfiber
|
||||
|
||||
type 'a t = private {
|
||||
id: Handle.t; (** unique identifier for this fiber *)
|
||||
state: 'a state Atomic.t; (** Current state in the lifetime of the fiber *)
|
||||
res: 'a Fut.t;
|
||||
runner: Runner.t;
|
||||
ls: Task_local_storage.t;
|
||||
pfiber: pfiber;
|
||||
}
|
||||
(** Type definition, exposed so that {!any} can be unboxed.
|
||||
Please do not rely on that. *)
|
||||
|
||||
type any = Any : _ t -> any [@@unboxed]
|
||||
|
||||
val get_cur : unit -> any option
|
||||
exception Not_set
|
||||
|
||||
val get_cur_exn : unit -> any
|
||||
(** [get_cur_exn ()] either returns the current fiber, or
|
||||
@raise Not_set if run outside a fiber. *)
|
||||
|
||||
val get_cur_opt : unit -> any option
|
||||
end
|
||||
|
||||
(**/**)
|
||||
|
|
|
|||
|
|
@ -1,14 +1,20 @@
|
|||
exception Oh_no of Exn_bt.t
|
||||
|
||||
let main (f : Runner.t -> 'a) : 'a =
|
||||
let st = Fifo_pool.Private_.create_state ~threads:[| Thread.self () |] () in
|
||||
let runner = Fifo_pool.Private_.runner_of_state st in
|
||||
let worker_st =
|
||||
Fifo_pool.Private_.create_single_threaded_state ~thread:(Thread.self ())
|
||||
~on_exn:(fun e bt -> raise (Oh_no (Exn_bt.make e bt)))
|
||||
()
|
||||
in
|
||||
let runner = Fifo_pool.Private_.runner_of_state worker_st in
|
||||
try
|
||||
let fiber = Fiber.spawn_top ~on:runner (fun () -> f runner) in
|
||||
Fiber.on_result fiber (fun _ -> Runner.shutdown_without_waiting runner);
|
||||
|
||||
(* run the main thread *)
|
||||
Fifo_pool.Private_.run_thread st runner ~on_exn:(fun e bt ->
|
||||
raise (Oh_no (Exn_bt.make e bt)));
|
||||
Moonpool.Private.Worker_loop_.worker_loop worker_st
|
||||
~ops:Fifo_pool.Private_.worker_ops;
|
||||
|
||||
match Fiber.peek fiber with
|
||||
| Some (Ok x) -> x
|
||||
| Some (Error ebt) -> Exn_bt.raise ebt
|
||||
|
|
|
|||
|
|
@ -6,4 +6,4 @@
|
|||
(optional)
|
||||
(enabled_if
|
||||
(>= %{ocaml_version} 5.0))
|
||||
(libraries moonpool moonpool.private))
|
||||
(libraries moonpool moonpool.private picos))
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
module A = Moonpool.Atomic
|
||||
module Suspend_ = Moonpool.Private.Suspend_
|
||||
module Domain_ = Moonpool_private.Domain_
|
||||
|
||||
module State_ = struct
|
||||
|
|
@ -9,7 +8,7 @@ module State_ = struct
|
|||
type ('a, 'b) t =
|
||||
| Init
|
||||
| Left_solved of 'a or_error
|
||||
| Right_solved of 'b or_error * Suspend_.suspension
|
||||
| Right_solved of 'b or_error * Trigger.t
|
||||
| Both_solved of 'a or_error * 'b or_error
|
||||
|
||||
let get_exn_ (self : _ t A.t) =
|
||||
|
|
@ -28,13 +27,13 @@ module State_ = struct
|
|||
Domain_.relax ();
|
||||
set_left_ self left
|
||||
)
|
||||
| Right_solved (right, cont) ->
|
||||
| Right_solved (right, tr) ->
|
||||
let new_st = Both_solved (left, right) in
|
||||
if not (A.compare_and_set self old_st new_st) then (
|
||||
Domain_.relax ();
|
||||
set_left_ self left
|
||||
) else
|
||||
cont (Ok ())
|
||||
Trigger.signal tr
|
||||
| Left_solved _ | Both_solved _ -> assert false
|
||||
|
||||
let rec set_right_ (self : _ t A.t) (right : _ or_error) : unit =
|
||||
|
|
@ -45,27 +44,27 @@ module State_ = struct
|
|||
if not (A.compare_and_set self old_st new_st) then set_right_ self right
|
||||
| Init ->
|
||||
(* we are first arrived, we suspend until the left computation is done *)
|
||||
Suspend_.suspend
|
||||
{
|
||||
Suspend_.handle =
|
||||
(fun ~run:_ ~resume suspension ->
|
||||
while
|
||||
let old_st = A.get self in
|
||||
match old_st with
|
||||
| Init ->
|
||||
not
|
||||
(A.compare_and_set self old_st
|
||||
(Right_solved (right, suspension)))
|
||||
| Left_solved left ->
|
||||
(* other thread is done, no risk of race condition *)
|
||||
A.set self (Both_solved (left, right));
|
||||
resume suspension (Ok ());
|
||||
false
|
||||
| Right_solved _ | Both_solved _ -> assert false
|
||||
do
|
||||
()
|
||||
done);
|
||||
}
|
||||
let trigger = Trigger.create () in
|
||||
let must_await = ref true in
|
||||
|
||||
while
|
||||
let old_st = A.get self in
|
||||
match old_st with
|
||||
| Init ->
|
||||
(* setup trigger so that left computation will wake us up *)
|
||||
not (A.compare_and_set self old_st (Right_solved (right, trigger)))
|
||||
| Left_solved left ->
|
||||
(* other thread is done, no risk of race condition *)
|
||||
A.set self (Both_solved (left, right));
|
||||
must_await := false;
|
||||
false
|
||||
| Right_solved _ | Both_solved _ -> assert false
|
||||
do
|
||||
()
|
||||
done;
|
||||
|
||||
(* wait for the other computation to be done *)
|
||||
if !must_await then Trigger.await trigger |> Option.iter Exn_bt.raise
|
||||
| Right_solved _ | Both_solved _ -> assert false
|
||||
end
|
||||
|
||||
|
|
@ -102,7 +101,12 @@ let both_ignore f g = ignore (both f g : _ * _)
|
|||
|
||||
let for_ ?chunk_size n (f : int -> int -> unit) : unit =
|
||||
if n > 0 then (
|
||||
let has_failed = A.make false in
|
||||
let runner =
|
||||
match Runner.get_current_runner () with
|
||||
| None -> failwith "forkjoin.for_: must be run inside a moonpool runner."
|
||||
| Some r -> r
|
||||
in
|
||||
let failure = A.make None in
|
||||
let missing = A.make n in
|
||||
|
||||
let chunk_size =
|
||||
|
|
@ -113,40 +117,36 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit =
|
|||
max 1 (1 + (n / Moonpool.Private.num_domains ()))
|
||||
in
|
||||
|
||||
let start_tasks ~run ~resume (suspension : Suspend_.suspension) =
|
||||
let task_for ~offset ~len_range =
|
||||
match f offset (offset + len_range - 1) with
|
||||
| () ->
|
||||
if A.fetch_and_add missing (-len_range) = len_range then
|
||||
(* all tasks done successfully *)
|
||||
resume suspension (Ok ())
|
||||
| exception exn ->
|
||||
let bt = Printexc.get_raw_backtrace () in
|
||||
if not (A.exchange has_failed true) then
|
||||
(* first one to fail, and [missing] must be >= 2
|
||||
because we're not decreasing it. *)
|
||||
resume suspension (Error { Exn_bt.exn; bt })
|
||||
in
|
||||
let trigger = Trigger.create () in
|
||||
|
||||
let i = ref 0 in
|
||||
while !i < n do
|
||||
let offset = !i in
|
||||
|
||||
let len_range = min chunk_size (n - offset) in
|
||||
assert (offset + len_range <= n);
|
||||
|
||||
run (fun () -> task_for ~offset ~len_range);
|
||||
i := !i + len_range
|
||||
done
|
||||
let task_for ~offset ~len_range =
|
||||
match f offset (offset + len_range - 1) with
|
||||
| () ->
|
||||
if A.fetch_and_add missing (-len_range) = len_range then
|
||||
(* all tasks done successfully *)
|
||||
Trigger.signal trigger
|
||||
| exception exn ->
|
||||
let bt = Printexc.get_raw_backtrace () in
|
||||
if Option.is_none (A.exchange failure (Some { Exn_bt.exn; bt })) then
|
||||
(* first one to fail, and [missing] must be >= 2
|
||||
because we're not decreasing it. *)
|
||||
Trigger.signal trigger
|
||||
in
|
||||
|
||||
Suspend_.suspend
|
||||
{
|
||||
Suspend_.handle =
|
||||
(fun ~run ~resume suspension ->
|
||||
(* run tasks, then we'll resume [suspension] *)
|
||||
start_tasks ~run ~resume suspension);
|
||||
}
|
||||
let i = ref 0 in
|
||||
while !i < n do
|
||||
let offset = !i in
|
||||
|
||||
let len_range = min chunk_size (n - offset) in
|
||||
assert (offset + len_range <= n);
|
||||
|
||||
Runner.run_async runner (fun () -> task_for ~offset ~len_range);
|
||||
i := !i + len_range
|
||||
done;
|
||||
|
||||
Trigger.await trigger |> Option.iter Exn_bt.raise;
|
||||
Option.iter Exn_bt.raise @@ A.get failure;
|
||||
()
|
||||
)
|
||||
|
||||
let all_array ?chunk_size (fs : _ array) : _ array =
|
||||
|
|
|
|||
|
|
@ -1,17 +1,14 @@
|
|||
open Base
|
||||
|
||||
let await_readable fd : unit =
|
||||
Moonpool.Private.Suspend_.suspend
|
||||
{
|
||||
handle =
|
||||
(fun ~run:_ ~resume sus ->
|
||||
Perform_action_in_lwt.schedule
|
||||
@@ Action.Wait_readable
|
||||
( fd,
|
||||
fun cancel ->
|
||||
resume sus @@ Ok ();
|
||||
Lwt_engine.stop_event cancel ));
|
||||
}
|
||||
let trigger = Trigger.create () in
|
||||
Perform_action_in_lwt.schedule
|
||||
@@ Action.Wait_readable
|
||||
( fd,
|
||||
fun cancel ->
|
||||
Trigger.signal trigger;
|
||||
Lwt_engine.stop_event cancel );
|
||||
Trigger.await_exn trigger
|
||||
|
||||
let rec read fd buf i len : int =
|
||||
if len = 0 then
|
||||
|
|
@ -25,17 +22,14 @@ let rec read fd buf i len : int =
|
|||
)
|
||||
|
||||
let await_writable fd =
|
||||
Moonpool.Private.Suspend_.suspend
|
||||
{
|
||||
handle =
|
||||
(fun ~run:_ ~resume sus ->
|
||||
Perform_action_in_lwt.schedule
|
||||
@@ Action.Wait_writable
|
||||
( fd,
|
||||
fun cancel ->
|
||||
resume sus @@ Ok ();
|
||||
Lwt_engine.stop_event cancel ));
|
||||
}
|
||||
let trigger = Trigger.create () in
|
||||
Perform_action_in_lwt.schedule
|
||||
@@ Action.Wait_writable
|
||||
( fd,
|
||||
fun cancel ->
|
||||
Trigger.signal trigger;
|
||||
Lwt_engine.stop_event cancel );
|
||||
Trigger.await_exn trigger
|
||||
|
||||
let rec write_once fd buf i len : int =
|
||||
if len = 0 then
|
||||
|
|
@ -59,16 +53,14 @@ let write fd buf i len : unit =
|
|||
|
||||
(** Sleep for the given amount of seconds *)
|
||||
let sleep_s (f : float) : unit =
|
||||
if f > 0. then
|
||||
Moonpool.Private.Suspend_.suspend
|
||||
{
|
||||
handle =
|
||||
(fun ~run:_ ~resume sus ->
|
||||
Perform_action_in_lwt.schedule
|
||||
@@ Action.Sleep
|
||||
( f,
|
||||
false,
|
||||
fun cancel ->
|
||||
resume sus @@ Ok ();
|
||||
Lwt_engine.stop_event cancel ));
|
||||
}
|
||||
if f > 0. then (
|
||||
let trigger = Trigger.create () in
|
||||
Perform_action_in_lwt.schedule
|
||||
@@ Action.Sleep
|
||||
( f,
|
||||
false,
|
||||
fun cancel ->
|
||||
Trigger.signal trigger;
|
||||
Lwt_engine.stop_event cancel );
|
||||
Trigger.await_exn trigger
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
open Common_
|
||||
module Trigger = M.Trigger
|
||||
module Fiber = Moonpool_fib.Fiber
|
||||
module FLS = Moonpool_fib.Fls
|
||||
|
||||
|
|
@ -14,7 +15,7 @@ module Action = struct
|
|||
| Sleep of float * bool * cb
|
||||
(* TODO: provide actions with cancellation, alongside a "select" operation *)
|
||||
(* | Cancel of event *)
|
||||
| On_termination : 'a Lwt.t * ('a Exn_bt.result -> unit) -> t
|
||||
| On_termination : 'a Lwt.t * 'a Exn_bt.result ref * Trigger.t -> t
|
||||
| Wakeup : 'a Lwt.u * 'a -> t
|
||||
| Wakeup_exn : _ Lwt.u * exn -> t
|
||||
| Other of (unit -> unit)
|
||||
|
|
@ -26,10 +27,14 @@ module Action = struct
|
|||
| Wait_writable (fd, cb) -> ignore (Lwt_engine.on_writable fd cb : event)
|
||||
| Sleep (f, repeat, cb) -> ignore (Lwt_engine.on_timer f repeat cb : event)
|
||||
(* | Cancel ev -> Lwt_engine.stop_event ev *)
|
||||
| On_termination (fut, f) ->
|
||||
| On_termination (fut, res, trigger) ->
|
||||
Lwt.on_any fut
|
||||
(fun x -> f @@ Ok x)
|
||||
(fun exn -> f @@ Error (Exn_bt.get_callstack 10 exn))
|
||||
(fun x ->
|
||||
res := Ok x;
|
||||
Trigger.signal trigger)
|
||||
(fun exn ->
|
||||
res := Error (Exn_bt.get_callstack 10 exn);
|
||||
Trigger.signal trigger)
|
||||
| Wakeup (prom, x) -> Lwt.wakeup prom x
|
||||
| Wakeup_exn (prom, e) -> Lwt.wakeup_exn prom e
|
||||
| Other f -> f ()
|
||||
|
|
@ -106,23 +111,19 @@ let fut_of_lwt (lwt_fut : _ Lwt.t) : _ M.Fut.t =
|
|||
M.Fut.fulfill prom (Error { Exn_bt.exn; bt }));
|
||||
fut
|
||||
|
||||
let _dummy_exn_bt : Exn_bt.t =
|
||||
Exn_bt.get_callstack 0 (Failure "dummy Exn_bt from moonpool-lwt")
|
||||
|
||||
let await_lwt (fut : _ Lwt.t) =
|
||||
match Lwt.poll fut with
|
||||
| Some x -> x
|
||||
| None ->
|
||||
(* suspend fiber, wake it up when [fut] resolves *)
|
||||
M.Private.Suspend_.suspend
|
||||
{
|
||||
handle =
|
||||
(fun ~run:_ ~resume sus ->
|
||||
let on_lwt_done _ = resume sus @@ Ok () in
|
||||
Perform_action_in_lwt.(
|
||||
schedule Action.(On_termination (fut, on_lwt_done))));
|
||||
};
|
||||
|
||||
(match Lwt.poll fut with
|
||||
| Some x -> x
|
||||
| None -> assert false)
|
||||
let trigger = M.Trigger.create () in
|
||||
let res = ref (Error _dummy_exn_bt) in
|
||||
Perform_action_in_lwt.(schedule Action.(On_termination (fut, res, trigger)));
|
||||
Trigger.await trigger |> Option.iter Exn_bt.raise;
|
||||
Exn_bt.unwrap !res
|
||||
|
||||
let run_in_lwt f : _ M.Fut.t =
|
||||
let fut, prom = M.Fut.make () in
|
||||
|
|
|
|||
|
|
@ -4,4 +4,9 @@
|
|||
(private_modules common_)
|
||||
(enabled_if
|
||||
(>= %{ocaml_version} 5.0))
|
||||
(libraries moonpool moonpool.fib lwt lwt.unix))
|
||||
(libraries
|
||||
(re_export moonpool)
|
||||
(re_export moonpool.fib)
|
||||
picos
|
||||
(re_export lwt)
|
||||
lwt.unix))
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue