wip: port to picos

This commit is contained in:
Simon Cruanes 2024-08-28 16:09:45 -04:00
parent a0068b09b3
commit 07a7fc3a1c
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
17 changed files with 314 additions and 316 deletions

View file

@ -7,3 +7,7 @@ let show self = Printexc.to_string (exn self)
let pp out self = Format.pp_print_string out (show self)
type nonrec 'a result = ('a, t) result
let[@inline] unwrap = function
| Ok x -> x
| Error ebt -> raise ebt

View file

@ -21,3 +21,7 @@ val show : t -> string
val pp : Format.formatter -> t -> unit
type nonrec 'a result = ('a, t) result
val unwrap : 'a result -> 'a
(** [unwrap (Ok x)] is [x], [unwrap (Error ebt)] re-raises [ebt].
@since NEXT_RELEASE *)

View file

@ -11,7 +11,7 @@ type state = {
threads: Thread.t array;
q: task_full Bb_queue.t; (** Queue for tasks. *)
around_task: WL.around_task;
as_runner: t lazy_t;
mutable as_runner: t;
(* init options *)
name: string option;
on_init_thread: dom_id:int -> t_id:int -> unit -> unit;
@ -24,7 +24,6 @@ type worker_state = {
idx: int;
dom_idx: int;
st: state;
mutable current: fiber;
}
let[@inline] size_ (self : state) = Array.length self.threads
@ -95,7 +94,7 @@ let cleanup (self : worker_state) : unit =
self.st.on_exit_thread ~dom_id:self.dom_idx ~t_id ()
let worker_ops : worker_state WL.ops =
let runner (st : worker_state) = Lazy.force st.st.as_runner in
let runner (st : worker_state) = st.st.as_runner in
let around_task st = st.st.around_task in
let on_exn (st : worker_state) (ebt : Exn_bt.t) =
st.st.on_exn ebt.exn ebt.bt
@ -111,9 +110,9 @@ let worker_ops : worker_state WL.ops =
cleanup;
}
let create ?(on_init_thread = default_thread_init_exit_)
let create_ ?(on_init_thread = default_thread_init_exit_)
?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ())
?around_task ?num_threads ?name () : t =
?around_task ~threads ?name () : state =
(* wrapper *)
let around_task =
match around_task with
@ -121,6 +120,23 @@ let create ?(on_init_thread = default_thread_init_exit_)
| None -> default_around_task_
in
let self =
{
threads;
q = Bb_queue.create ();
around_task;
as_runner = Runner.dummy;
name;
on_init_thread;
on_exit_thread;
on_exn;
}
in
self.as_runner <- runner_of_state self;
self
let create ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads
?name () : t =
let num_domains = Domain_pool_.max_number_of_domains () in
(* number of threads to run *)
@ -129,20 +145,12 @@ let create ?(on_init_thread = default_thread_init_exit_)
(* make sure we don't bias towards the first domain(s) in {!D_pool_} *)
let offset = Random.int num_domains in
let rec pool =
let pool =
let dummy_thread = Thread.self () in
{
threads = Array.make num_threads dummy_thread;
q = Bb_queue.create ();
around_task;
as_runner = lazy (runner_of_state pool);
name;
on_init_thread;
on_exit_thread;
on_exn;
}
let threads = Array.make num_threads dummy_thread in
create_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ~threads ?name
()
in
let runner = runner_of_state pool in
(* temporary queue used to obtain thread handles from domains
@ -156,7 +164,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
(* function called in domain with index [i], to
create the thread and push it into [receive_threads] *)
let create_thread_in_domain () =
let st = { idx = i; dom_idx; st = pool; current = _dummy_fiber } in
let st = { idx = i; dom_idx; st = pool } in
let thread = Thread.create (WL.worker_loop ~ops:worker_ops) st in
(* send the thread from the domain back to us *)
Bb_queue.push receive_threads (i, thread)
@ -187,3 +195,14 @@ let with_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads
in
let@ () = Fun.protect ~finally:(fun () -> shutdown pool) in
f pool
module Private_ = struct
type nonrec worker_state = worker_state
let worker_ops = worker_ops
let runner_of_state (self : worker_state) = worker_ops.runner self
let create_single_threaded_state ~thread ?on_exn () : worker_state =
let st : state = create_ ?on_exn ~threads:[| thread |] () in
{ idx = 0; dom_idx = 0; st }
end

View file

@ -44,3 +44,21 @@ val with_ : (unit -> (t -> 'a) -> 'a, _) create_args
When [f pool] returns or fails, [pool] is shutdown and its resources
are released.
Most parameters are the same as in {!create}. *)
(**/**)
module Private_ : sig
type worker_state
val worker_ops : worker_state Worker_loop_.ops
val create_single_threaded_state :
thread:Thread.t ->
?on_exn:(exn -> Printexc.raw_backtrace -> unit) ->
unit ->
worker_state
val runner_of_state : worker_state -> Runner.t
end
(**/**)

View file

@ -1,83 +1,41 @@
open Types_
(*
module A = Atomic_
module PF = Picos.Fiber
type 'a key = 'a ls_key
type 'a t = 'a PF.FLS.t
let key_count_ = A.make 0
exception Not_set = PF.FLS.Not_set
type t = local_storage
type ls_value += Dummy
let create = PF.FLS.create
let dummy : t = _dummy_ls
let[@inline] get_exn k =
let fiber = get_current_fiber_exn () in
PF.FLS.get_exn fiber k
(** Resize array of TLS values *)
let[@inline never] resize_ (cur : ls_value array ref) n =
if n > Sys.max_array_length then failwith "too many task local storage keys";
let len = Array.length !cur in
let new_ls =
Array.make (min Sys.max_array_length (max n ((len * 2) + 2))) Dummy
in
Array.blit !cur 0 new_ls 0 len;
cur := new_ls
module Direct = struct
type nonrec t = t
let create = create_local_storage
let[@inline] copy (self : t) = ref (Array.copy !self)
let get (type a) (self : t) ((module K) : a key) : a =
if K.offset >= Array.length !self then resize_ self (K.offset + 1);
match !self.(K.offset) with
| K.V x -> (* common case first *) x
| Dummy ->
(* first time we access this *)
let v = K.init () in
!self.(K.offset) <- K.V v;
v
| _ -> assert false
let set (type a) (self : t) ((module K) : a key) (v : a) : unit =
assert (self != dummy);
if K.offset >= Array.length !self then resize_ self (K.offset + 1);
!self.(K.offset) <- K.V v;
()
end
let new_key (type t) ~init () : t key =
let offset = A.fetch_and_add key_count_ 1 in
(module struct
type nonrec t = t
type ls_value += V of t
let offset = offset
let init = init
end : LS_KEY
with type t = t)
let[@inline] get_cur_ () : ls_value array ref =
match get_current_storage () with
| Some r when r != dummy -> r
| _ -> failwith "Task local storage must be accessed from within a runner."
let[@inline] get (key : 'a key) : 'a =
let cur = get_cur_ () in
Direct.get cur key
let[@inline] get_opt key =
match get_current_storage () with
let get_opt k =
match get_current_fiber () with
| None -> None
| Some cur -> Some (Direct.get cur key)
| Some fiber ->
(match PF.FLS.get_exn fiber k with
| x -> Some x
| exception Not_set -> None)
let[@inline] set key v : unit =
let cur = get_cur_ () in
Direct.set cur key v
let[@inline] get k ~default =
let fiber = get_current_fiber_exn () in
PF.FLS.get fiber ~default k
let with_value key x f =
let old = get key in
set key x;
Fun.protect ~finally:(fun () -> set key old) f
let[@inline] set k v : unit =
let fiber = get_current_fiber_exn () in
PF.FLS.set fiber k v
let get_current = get_current_storage
*)
let with_value k v (f : _ -> 'b) : 'b =
let fiber = get_current_fiber_exn () in
match PF.FLS.get_exn fiber k with
| exception Not_set ->
PF.FLS.set fiber k v;
(* nothing to restore back to, just call [f] *)
f ()
| old_v ->
PF.FLS.set fiber k v;
let finally () = PF.FLS.set fiber k old_v in
Fun.protect f ~finally

View file

@ -8,62 +8,31 @@
@since 0.6
*)
(*
type t = Types_.local_storage
(** Underlying storage for a task. This is mutable and
not thread-safe. *)
type 'a t = 'a Picos.Fiber.FLS.t
val dummy : t
val create : unit -> 'a t
(** [create ()] makes a new key. Keys are expensive and
should never be allocated dynamically or in a loop. *)
type 'a key
(** A key used to access a particular (typed) storage slot on every task. *)
exception Not_set
val new_key : init:(unit -> 'a) -> unit -> 'a key
(** [new_key ~init ()] makes a new key. Keys are expensive and
should never be allocated dynamically or in a loop.
The correct pattern is, at toplevel:
{[
let k_foo : foo Task_ocal_storage.key =
Task_local_storage.new_key ~init:(fun () -> make_foo ()) ()
(**)
(* use it: *)
let = Task_local_storage.get k_foo
]}
*)
val get : 'a key -> 'a
val get_exn : 'a t -> 'a
(** [get k] gets the value for the current task for key [k].
Must be run from inside a task running on a runner.
@raise Failure otherwise *)
@raise Not_set otherwise *)
val get_opt : 'a key -> 'a option
val get_opt : 'a t -> 'a option
(** [get_opt k] gets the current task's value for key [k],
or [None] if not run from inside the task. *)
val set : 'a key -> 'a -> unit
val get : 'a t -> default:'a -> 'a
val set : 'a t -> 'a -> unit
(** [set k v] sets the storage for [k] to [v].
Must be run from inside a task running on a runner.
@raise Failure otherwise *)
val with_value : 'a key -> 'a -> (unit -> 'b) -> 'b
val with_value : 'a t -> 'a -> (unit -> 'b) -> 'b
(** [with_value k v f] sets [k] to [v] for the duration of the call
to [f()]. When [f()] returns (or fails), [k] is restored
to its old value. *)
val get_current : unit -> t option
(** Access the current storage, or [None] if not run from
within a task. *)
(** Direct access to values from a storage handle *)
module Direct : sig
val get : t -> 'a key -> 'a
(** Access a key *)
val set : t -> 'a key -> 'a -> unit
val create : unit -> t
val copy : t -> t
end
*)

View file

@ -2,3 +2,5 @@
@since NEXT_RELEASE *)
include Picos.Trigger
let[@inline] await_exn (self : t) = await self |> Option.iter Exn_bt.raise

View file

@ -25,7 +25,7 @@ type state = {
mutable n_waiting_nonzero: bool; (** [n_waiting > 0] *)
mutex: Mutex.t;
cond: Condition.t;
as_runner: t lazy_t;
mutable as_runner: t;
(* init options *)
around_task: WL.around_task;
name: string option;
@ -167,7 +167,9 @@ and wait_on_worker (self : worker_state) : WL.task_full =
| task ->
Mutex.unlock self.st.mutex;
task
| exception Queue.Empty -> try_steal_from_other_workers_ self
| exception Queue.Empty ->
Mutex.unlock self.st.mutex;
try_steal_from_other_workers_ self
) else (
(* do nothing more: no task in main queue, and we are shutting
down so no new task should arrive.
@ -183,8 +185,7 @@ let before_start (self : worker_state) : unit =
let t_id = Thread.id @@ Thread.self () in
self.st.on_init_thread ~dom_id:self.dom_id ~t_id ();
TLS.set k_cur_fiber _dummy_fiber;
TLS.set Runner.For_runner_implementors.k_cur_runner
(Lazy.force self.st.as_runner);
TLS.set Runner.For_runner_implementors.k_cur_runner self.st.as_runner;
TLS.set k_worker_state self;
(* set thread name *)
@ -200,7 +201,7 @@ let cleanup (self : worker_state) : unit =
self.st.on_exit_thread ~dom_id:self.dom_id ~t_id ()
let worker_ops : worker_state WL.ops =
let runner (st : worker_state) = Lazy.force st.st.as_runner in
let runner (st : worker_state) = st.st.as_runner in
let around_task st = st.st.around_task in
let on_exn (st : worker_state) (ebt : Exn_bt.t) =
st.st.on_exn ebt.exn ebt.bt
@ -261,7 +262,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
(* make sure we don't bias towards the first domain(s) in {!D_pool_} *)
let offset = Random.int num_domains in
let rec pool =
let pool =
{
id_ = pool_id_;
active = A.make true;
@ -276,18 +277,18 @@ let create ?(on_init_thread = default_thread_init_exit_)
on_init_thread;
on_exit_thread;
name;
as_runner = lazy (as_runner_ pool);
as_runner = Runner.dummy;
}
in
pool.as_runner <- as_runner_ pool;
(* temporary queue used to obtain thread handles from domains
on which the thread are started. *)
let receive_threads = Bb_queue.create () in
(* start the thread with index [i] *)
let start_thread_with_idx idx =
let create_worker_state idx =
let dom_id = (offset + idx) mod num_domains in
let st =
{
st = pool;
thread = (* dummy *) Thread.self ();
@ -298,6 +299,10 @@ let create ?(on_init_thread = default_thread_init_exit_)
}
in
pool.workers <- Array.init num_threads create_worker_state;
(* start the thread with index [i] *)
let start_thread_with_idx idx (st : worker_state) =
(* function called in domain with index [i], to
create the thread and push it into [receive_threads] *)
let create_thread_in_domain () =
@ -305,15 +310,12 @@ let create ?(on_init_thread = default_thread_init_exit_)
(* send the thread from the domain back to us *)
Bb_queue.push receive_threads (idx, thread)
in
Domain_pool_.run_on dom_id create_thread_in_domain;
st
Domain_pool_.run_on st.dom_id create_thread_in_domain
in
(* start all worker threads, placing them on the domains
according to their index and [offset] in a round-robin fashion. *)
pool.workers <- Array.init num_threads start_thread_with_idx;
Array.iteri start_thread_with_idx pool.workers;
(* receive the newly created threads back from domains *)
for _j = 1 to num_threads do
@ -322,7 +324,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
worker_state.thread <- th
done;
Lazy.force pool.as_runner
pool.as_runner
let with_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads
?name () f =

View file

@ -2,7 +2,7 @@
(name moonpool_fib)
(public_name moonpool.fib)
(synopsis "Fibers and structured concurrency for Moonpool")
(libraries moonpool)
(libraries moonpool picos)
(enabled_if
(>= %{ocaml_version} 5.0))
(flags :standard -open Moonpool_private -open Moonpool)

View file

@ -1,6 +1,9 @@
open Moonpool.Private.Types_
module A = Atomic
module FM = Handle.Map
module Int_map = Map.Make (Int)
module PF = Picos.Fiber
module FLS = Picos.Fiber.FLS
type 'a callback = 'a Exn_bt.result -> unit
(** Callbacks that are called when a fiber is done. *)
@ -10,13 +13,16 @@ type cancel_callback = Exn_bt.t -> unit
let prom_of_fut : 'a Fut.t -> 'a Fut.promise =
Fut.Private_.unsafe_promise_of_fut
(* TODO: replace with picos structured at some point? *)
module Private_ = struct
type pfiber = PF.t
type 'a t = {
id: Handle.t; (** unique identifier for this fiber *)
state: 'a state A.t; (** Current state in the lifetime of the fiber *)
res: 'a Fut.t;
runner: Runner.t;
ls: Task_local_storage.t;
pfiber: pfiber; (** Associated picos fiber *)
}
and 'a state =
@ -30,11 +36,18 @@ module Private_ = struct
and children = any FM.t
and any = Any : _ t -> any [@@unboxed]
(** Key to access the current fiber. *)
let k_current_fiber : any option Task_local_storage.key =
Task_local_storage.new_key ~init:(fun () -> None) ()
(** Key to access the current moonpool.fiber. *)
let k_current_fiber : any FLS.t = FLS.create ()
let[@inline] get_cur () : any option = Task_local_storage.get k_current_fiber
exception Not_set = FLS.Not_set
let[@inline] get_cur_from_exn (pfiber : pfiber) : any =
FLS.get_exn pfiber k_current_fiber
let[@inline] get_cur_exn () : any =
get_cur_from_exn @@ get_current_fiber_exn ()
let[@inline] get_cur_opt () = try Some (get_cur_exn ()) with _ -> None
let[@inline] is_closed (self : _ t) =
match A.get self.state with
@ -44,9 +57,9 @@ end
include Private_
let create_ ~ls ~runner () : 'a t =
let create_ ~pfiber ~runner () : 'a t =
let id = Handle.generate_fresh () in
let res, _promise = Fut.make () in
let res, _ = Fut.make () in
{
state =
A.make
@ -54,7 +67,7 @@ let create_ ~ls ~runner () : 'a t =
id;
res;
runner;
ls;
pfiber;
}
let create_done_ ~res () : _ t =
@ -66,7 +79,7 @@ let create_done_ ~res () : _ t =
id;
res;
runner = Runner.dummy;
ls = Task_local_storage.dummy;
pfiber = Moonpool.Private.Types_._dummy_fiber;
}
let[@inline] return x = create_done_ ~res:(Fut.return x) ()
@ -175,7 +188,8 @@ let with_on_cancel (self : _ t) cb (k : unit -> 'a) : 'a =
let h = add_on_cancel self cb in
Fun.protect k ~finally:(fun () -> remove_on_cancel self h)
(** Successfully resolve the fiber *)
(** Successfully resolve the fiber. This might still fail if
some children failed. *)
let resolve_ok_ (self : 'a t) (r : 'a) : unit =
let r = A.make @@ Ok r in
let promise = prom_of_fut self.res in
@ -239,15 +253,21 @@ let add_child_ ~protect (self : _ t) (child : _ t) =
()
done
let spawn_ ~ls ~parent ~runner (f : unit -> 'a) : 'a t =
let spawn_ ~parent ~runner (f : unit -> 'a) : 'a t =
let comp = Picos.Computation.create () in
let pfiber = PF.create ~forbid:false comp in
(* inherit FLS from parent, if present *)
Option.iter (fun (p : _ t) -> PF.copy_fls p.pfiber pfiber) parent;
(match parent with
| Some p when is_closed p -> failwith "spawn: nursery is closed"
| _ -> ());
let fib = create_ ~ls ~runner () in
let fib = create_ ~pfiber ~runner () in
let run () =
(* make sure the fiber is accessible from inside itself *)
Task_local_storage.set k_current_fiber (Some (Any fib));
FLS.set pfiber k_current_fiber (Any fib);
try
let res = f () in
resolve_ok_ fib res
@ -257,63 +277,54 @@ let spawn_ ~ls ~parent ~runner (f : unit -> 'a) : 'a t =
resolve_as_failed_ fib ebt
in
Runner.run_async ~ls runner run;
Runner.run_async ~fiber:pfiber runner run;
fib
let spawn_top ~on f : _ t =
let ls = Task_local_storage.Direct.create () in
spawn_ ~ls ~runner:on ~parent:None f
let spawn_top ~on f : _ t = spawn_ ~runner:on ~parent:None f
let spawn ?on ?(protect = true) f : _ t =
(* spawn [f()] with a copy of our local storage *)
let (Any p) =
match get_cur () with
| None -> failwith "Fiber.spawn: must be run from within another fiber."
| Some p -> p
try get_cur_exn ()
with Not_set ->
failwith "Fiber.spawn: must be run from within another fiber."
in
let ls = Task_local_storage.Direct.copy p.ls in
let runner =
match on with
| Some r -> r
| None -> p.runner
in
let child = spawn_ ~ls ~parent:(Some p) ~runner f in
let child = spawn_ ~parent:(Some p) ~runner f in
add_child_ ~protect p child;
child
let[@inline] spawn_ignore ?protect f : unit = ignore (spawn ?protect f : _ t)
let[@inline] self () : any =
match Task_local_storage.get k_current_fiber with
| None -> failwith "Fiber.self: must be run from inside a fiber."
| Some f -> f
match get_cur_exn () with
| exception Not_set -> failwith "Fiber.self: must be run from inside a fiber."
| f -> f
let with_on_self_cancel cb (k : unit -> 'a) : 'a =
let (Any self) = self () in
let h = add_on_cancel self cb in
Fun.protect k ~finally:(fun () -> remove_on_cancel self h)
module Suspend_ = Moonpool.Private.Suspend_
let check_if_cancelled_ (self : _ t) =
match A.get self.state with
| Terminating_or_done r ->
(match A.get r with
| Error ebt -> Exn_bt.raise ebt
| _ -> ())
| _ -> ()
let[@inline] check_if_cancelled_ (self : _ t) = PF.check self.pfiber
let check_if_cancelled () =
match Task_local_storage.get k_current_fiber with
| None ->
match get_cur_exn () with
| exception Not_set ->
failwith "Fiber.check_if_cancelled: must be run from inside a fiber."
| Some (Any self) -> check_if_cancelled_ self
| Any self -> check_if_cancelled_ self
let yield () : unit =
match Task_local_storage.get k_current_fiber with
| None -> failwith "Fiber.yield: must be run from inside a fiber."
| Some (Any self) ->
match get_cur_exn () with
| exception Not_set ->
failwith "Fiber.yield: must be run from inside a fiber."
| Any self ->
check_if_cancelled_ self;
Suspend_.yield ();
PF.yield ();
check_if_cancelled_ self

View file

@ -17,20 +17,27 @@ type cancel_callback = Exn_bt.t -> unit
(** Do not rely on this, it is internal implementation details. *)
module Private_ : sig
type 'a state
type pfiber
type 'a t = private {
id: Handle.t; (** unique identifier for this fiber *)
state: 'a state Atomic.t; (** Current state in the lifetime of the fiber *)
res: 'a Fut.t;
runner: Runner.t;
ls: Task_local_storage.t;
pfiber: pfiber;
}
(** Type definition, exposed so that {!any} can be unboxed.
Please do not rely on that. *)
type any = Any : _ t -> any [@@unboxed]
val get_cur : unit -> any option
exception Not_set
val get_cur_exn : unit -> any
(** [get_cur_exn ()] either returns the current fiber, or
@raise Not_set if run outside a fiber. *)
val get_cur_opt : unit -> any option
end
(**/**)

View file

@ -1,14 +1,20 @@
exception Oh_no of Exn_bt.t
let main (f : Runner.t -> 'a) : 'a =
let st = Fifo_pool.Private_.create_state ~threads:[| Thread.self () |] () in
let runner = Fifo_pool.Private_.runner_of_state st in
let worker_st =
Fifo_pool.Private_.create_single_threaded_state ~thread:(Thread.self ())
~on_exn:(fun e bt -> raise (Oh_no (Exn_bt.make e bt)))
()
in
let runner = Fifo_pool.Private_.runner_of_state worker_st in
try
let fiber = Fiber.spawn_top ~on:runner (fun () -> f runner) in
Fiber.on_result fiber (fun _ -> Runner.shutdown_without_waiting runner);
(* run the main thread *)
Fifo_pool.Private_.run_thread st runner ~on_exn:(fun e bt ->
raise (Oh_no (Exn_bt.make e bt)));
Moonpool.Private.Worker_loop_.worker_loop worker_st
~ops:Fifo_pool.Private_.worker_ops;
match Fiber.peek fiber with
| Some (Ok x) -> x
| Some (Error ebt) -> Exn_bt.raise ebt

View file

@ -6,4 +6,4 @@
(optional)
(enabled_if
(>= %{ocaml_version} 5.0))
(libraries moonpool moonpool.private))
(libraries moonpool moonpool.private picos))

View file

@ -1,5 +1,4 @@
module A = Moonpool.Atomic
module Suspend_ = Moonpool.Private.Suspend_
module Domain_ = Moonpool_private.Domain_
module State_ = struct
@ -9,7 +8,7 @@ module State_ = struct
type ('a, 'b) t =
| Init
| Left_solved of 'a or_error
| Right_solved of 'b or_error * Suspend_.suspension
| Right_solved of 'b or_error * Trigger.t
| Both_solved of 'a or_error * 'b or_error
let get_exn_ (self : _ t A.t) =
@ -28,13 +27,13 @@ module State_ = struct
Domain_.relax ();
set_left_ self left
)
| Right_solved (right, cont) ->
| Right_solved (right, tr) ->
let new_st = Both_solved (left, right) in
if not (A.compare_and_set self old_st new_st) then (
Domain_.relax ();
set_left_ self left
) else
cont (Ok ())
Trigger.signal tr
| Left_solved _ | Both_solved _ -> assert false
let rec set_right_ (self : _ t A.t) (right : _ or_error) : unit =
@ -45,27 +44,27 @@ module State_ = struct
if not (A.compare_and_set self old_st new_st) then set_right_ self right
| Init ->
(* we are first arrived, we suspend until the left computation is done *)
Suspend_.suspend
{
Suspend_.handle =
(fun ~run:_ ~resume suspension ->
let trigger = Trigger.create () in
let must_await = ref true in
while
let old_st = A.get self in
match old_st with
| Init ->
not
(A.compare_and_set self old_st
(Right_solved (right, suspension)))
(* setup trigger so that left computation will wake us up *)
not (A.compare_and_set self old_st (Right_solved (right, trigger)))
| Left_solved left ->
(* other thread is done, no risk of race condition *)
A.set self (Both_solved (left, right));
resume suspension (Ok ());
must_await := false;
false
| Right_solved _ | Both_solved _ -> assert false
do
()
done);
}
done;
(* wait for the other computation to be done *)
if !must_await then Trigger.await trigger |> Option.iter Exn_bt.raise
| Right_solved _ | Both_solved _ -> assert false
end
@ -102,7 +101,12 @@ let both_ignore f g = ignore (both f g : _ * _)
let for_ ?chunk_size n (f : int -> int -> unit) : unit =
if n > 0 then (
let has_failed = A.make false in
let runner =
match Runner.get_current_runner () with
| None -> failwith "forkjoin.for_: must be run inside a moonpool runner."
| Some r -> r
in
let failure = A.make None in
let missing = A.make n in
let chunk_size =
@ -113,19 +117,20 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit =
max 1 (1 + (n / Moonpool.Private.num_domains ()))
in
let start_tasks ~run ~resume (suspension : Suspend_.suspension) =
let trigger = Trigger.create () in
let task_for ~offset ~len_range =
match f offset (offset + len_range - 1) with
| () ->
if A.fetch_and_add missing (-len_range) = len_range then
(* all tasks done successfully *)
resume suspension (Ok ())
Trigger.signal trigger
| exception exn ->
let bt = Printexc.get_raw_backtrace () in
if not (A.exchange has_failed true) then
if Option.is_none (A.exchange failure (Some { Exn_bt.exn; bt })) then
(* first one to fail, and [missing] must be >= 2
because we're not decreasing it. *)
resume suspension (Error { Exn_bt.exn; bt })
Trigger.signal trigger
in
let i = ref 0 in
@ -135,18 +140,13 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit =
let len_range = min chunk_size (n - offset) in
assert (offset + len_range <= n);
run (fun () -> task_for ~offset ~len_range);
Runner.run_async runner (fun () -> task_for ~offset ~len_range);
i := !i + len_range
done
in
done;
Suspend_.suspend
{
Suspend_.handle =
(fun ~run ~resume suspension ->
(* run tasks, then we'll resume [suspension] *)
start_tasks ~run ~resume suspension);
}
Trigger.await trigger |> Option.iter Exn_bt.raise;
Option.iter Exn_bt.raise @@ A.get failure;
()
)
let all_array ?chunk_size (fs : _ array) : _ array =

View file

@ -1,17 +1,14 @@
open Base
let await_readable fd : unit =
Moonpool.Private.Suspend_.suspend
{
handle =
(fun ~run:_ ~resume sus ->
let trigger = Trigger.create () in
Perform_action_in_lwt.schedule
@@ Action.Wait_readable
( fd,
fun cancel ->
resume sus @@ Ok ();
Lwt_engine.stop_event cancel ));
}
Trigger.signal trigger;
Lwt_engine.stop_event cancel );
Trigger.await_exn trigger
let rec read fd buf i len : int =
if len = 0 then
@ -25,17 +22,14 @@ let rec read fd buf i len : int =
)
let await_writable fd =
Moonpool.Private.Suspend_.suspend
{
handle =
(fun ~run:_ ~resume sus ->
let trigger = Trigger.create () in
Perform_action_in_lwt.schedule
@@ Action.Wait_writable
( fd,
fun cancel ->
resume sus @@ Ok ();
Lwt_engine.stop_event cancel ));
}
Trigger.signal trigger;
Lwt_engine.stop_event cancel );
Trigger.await_exn trigger
let rec write_once fd buf i len : int =
if len = 0 then
@ -59,16 +53,14 @@ let write fd buf i len : unit =
(** Sleep for the given amount of seconds *)
let sleep_s (f : float) : unit =
if f > 0. then
Moonpool.Private.Suspend_.suspend
{
handle =
(fun ~run:_ ~resume sus ->
if f > 0. then (
let trigger = Trigger.create () in
Perform_action_in_lwt.schedule
@@ Action.Sleep
( f,
false,
fun cancel ->
resume sus @@ Ok ();
Lwt_engine.stop_event cancel ));
}
Trigger.signal trigger;
Lwt_engine.stop_event cancel );
Trigger.await_exn trigger
)

View file

@ -1,4 +1,5 @@
open Common_
module Trigger = M.Trigger
module Fiber = Moonpool_fib.Fiber
module FLS = Moonpool_fib.Fls
@ -14,7 +15,7 @@ module Action = struct
| Sleep of float * bool * cb
(* TODO: provide actions with cancellation, alongside a "select" operation *)
(* | Cancel of event *)
| On_termination : 'a Lwt.t * ('a Exn_bt.result -> unit) -> t
| On_termination : 'a Lwt.t * 'a Exn_bt.result ref * Trigger.t -> t
| Wakeup : 'a Lwt.u * 'a -> t
| Wakeup_exn : _ Lwt.u * exn -> t
| Other of (unit -> unit)
@ -26,10 +27,14 @@ module Action = struct
| Wait_writable (fd, cb) -> ignore (Lwt_engine.on_writable fd cb : event)
| Sleep (f, repeat, cb) -> ignore (Lwt_engine.on_timer f repeat cb : event)
(* | Cancel ev -> Lwt_engine.stop_event ev *)
| On_termination (fut, f) ->
| On_termination (fut, res, trigger) ->
Lwt.on_any fut
(fun x -> f @@ Ok x)
(fun exn -> f @@ Error (Exn_bt.get_callstack 10 exn))
(fun x ->
res := Ok x;
Trigger.signal trigger)
(fun exn ->
res := Error (Exn_bt.get_callstack 10 exn);
Trigger.signal trigger)
| Wakeup (prom, x) -> Lwt.wakeup prom x
| Wakeup_exn (prom, e) -> Lwt.wakeup_exn prom e
| Other f -> f ()
@ -106,23 +111,19 @@ let fut_of_lwt (lwt_fut : _ Lwt.t) : _ M.Fut.t =
M.Fut.fulfill prom (Error { Exn_bt.exn; bt }));
fut
let _dummy_exn_bt : Exn_bt.t =
Exn_bt.get_callstack 0 (Failure "dummy Exn_bt from moonpool-lwt")
let await_lwt (fut : _ Lwt.t) =
match Lwt.poll fut with
| Some x -> x
| None ->
(* suspend fiber, wake it up when [fut] resolves *)
M.Private.Suspend_.suspend
{
handle =
(fun ~run:_ ~resume sus ->
let on_lwt_done _ = resume sus @@ Ok () in
Perform_action_in_lwt.(
schedule Action.(On_termination (fut, on_lwt_done))));
};
(match Lwt.poll fut with
| Some x -> x
| None -> assert false)
let trigger = M.Trigger.create () in
let res = ref (Error _dummy_exn_bt) in
Perform_action_in_lwt.(schedule Action.(On_termination (fut, res, trigger)));
Trigger.await trigger |> Option.iter Exn_bt.raise;
Exn_bt.unwrap !res
let run_in_lwt f : _ M.Fut.t =
let fut, prom = M.Fut.make () in

View file

@ -4,4 +4,9 @@
(private_modules common_)
(enabled_if
(>= %{ocaml_version} 5.0))
(libraries moonpool moonpool.fib lwt lwt.unix))
(libraries
(re_export moonpool)
(re_export moonpool.fib)
picos
(re_export lwt)
lwt.unix))