moonpool/src/fib/fiber.ml

332 lines
9.3 KiB
OCaml

open Moonpool.Private.Types_
module A = Atomic
module FM = Handle.Map
module Int_map = Map.Make (Int)
module PF = Picos.Fiber
module FLS = Picos.Fiber.FLS
type 'a callback = 'a Exn_bt.result -> unit
(** Callbacks that are called when a fiber is done. *)

type cancel_callback = Exn_bt.t -> unit
(** Callbacks that are called with the failure (exception and backtrace)
    when a fiber is resolved as failed. *)
(** [prom_of_fut fut] recovers the promise side of [fut] through
    [Fut.Private_.unsafe_promise_of_fut]. It is used to resolve a fiber's
    result future, for which only the future itself is stored. *)
let prom_of_fut (fut : 'a Fut.t) : 'a Fut.promise =
  Fut.Private_.unsafe_promise_of_fut fut
(* TODO: replace with picos structured at some point? *)
(** Concrete representation of fibers, plus access to the "current fiber"
    stored in the fiber-local storage of the underlying picos fiber. *)
module Private_ = struct
  type pfiber = PF.t
  (** A picos fiber. *)

  type 'a t = {
    id: Handle.t;  (** unique identifier for this fiber *)
    state: 'a state A.t;  (** Current state in the lifetime of the fiber *)
    res: 'a Fut.t;  (** Future that holds the fiber's final result *)
    runner: Runner.t;  (** Runner the fiber was created with *)
    pfiber: pfiber;  (** Associated picos fiber *)
  }

  (** Lifecycle of a fiber. *)
  and 'a state =
    | Alive of {
        children: children;  (** children currently attached, by id *)
        on_cancel: cancel_callback Int_map.t;
            (** registered cancellation callbacks, by handle *)
        cancel_id: int;  (** handle given to the next registered callback *)
      }
    | Terminating_or_done of 'a Exn_bt.result A.t
        (** The fiber's own computation has an outcome; its future may
            still be waiting for children. *)

  and children = any FM.t

  (** A fiber whose result type is hidden. *)
  and any = Any : _ t -> any [@@unboxed]

  (** Key to access the current moonpool.fiber. *)
  let k_current_fiber : any FLS.t = FLS.create ()

  exception Not_set = FLS.Not_set

  (** Fiber stored under {!k_current_fiber} in [pfiber]'s local storage. *)
  let[@inline] get_cur_from_exn (pfiber : pfiber) : any =
    FLS.get_exn pfiber k_current_fiber

  (** Fiber attached to the picos fiber that is currently running. *)
  let[@inline] get_cur_exn () : any =
    get_current_fiber_exn () |> get_cur_from_exn

  (** Same as {!get_cur_exn}, but returns [None] if any exception is raised.

      NOTE(review): the handler is a catch-all, not just [Not_set]; the
      exceptions raised by [get_current_fiber_exn] are not visible from
      this file, so it is left as is. *)
  let[@inline] get_cur_opt () =
    match get_cur_exn () with
    | cur -> Some cur
    | exception _ -> None

  (** [is_closed self] is [true] once [self] has left the [Alive] state. *)
  let[@inline] is_closed (self : _ t) =
    match A.get self.state with
    | Terminating_or_done _ -> true
    | Alive _ -> false
end
include Private_
(** [create_ ~pfiber ~runner ()] allocates a new fiber record in the [Alive]
    state, with a fresh id, no children and no cancellation callbacks.

    The promise paired with the result future is dropped here; resolvers
    get it back via [prom_of_fut]. *)
let create_ ~pfiber ~runner () : 'a t =
  let id = Handle.generate_fresh () in
  let res, _promise = Fut.make () in
  let initial =
    Alive { children = FM.empty; on_cancel = Int_map.empty; cancel_id = 0 }
  in
  { id; state = A.make initial; res; runner; pfiber }
(** [create_done_ ~res ()] wraps the future [res] into a fiber that is not
    attached to any real runner or picos fiber (it uses [Runner.dummy] and
    the dummy picos fiber).

    If [res] is already resolved, the fiber starts in the
    [Terminating_or_done] state carrying that result. This keeps it
    consistent with fibers that finished by running: [is_closed] is [true],
    and [add_on_cancel] on a failed fiber invokes the callback immediately
    instead of storing it in a map that nothing will ever walk. It also
    means a later [resolve_as_failed_]/[resolve_ok_] is a no-op and does
    not attempt to fulfill the future again.

    If [res] is still pending, the fiber starts [Alive], as before. *)
let create_done_ ~res () : _ t =
  let id = Handle.generate_fresh () in
  let state =
    match Fut.peek res with
    | Some outcome -> Terminating_or_done (A.make outcome)
    | None ->
      Alive { children = FM.empty; on_cancel = Int_map.empty; cancel_id = 0 }
  in
  {
    state = A.make state;
    id;
    res;
    runner = Runner.dummy;
    pfiber = Moonpool.Private.Types_._dummy_fiber;
  }
(** [return x] is a fiber whose result future is [Fut.return x]. *)
let[@inline] return x =
  let res = Fut.return x in
  create_done_ ~res ()

(** [fail ebt] is a fiber whose result future is [Fut.fail_exn_bt ebt]. *)
let[@inline] fail ebt =
  let res = Fut.fail_exn_bt ebt in
  create_done_ ~res ()
(** Future that holds the result of the fiber. *)
let[@inline] res ({ res = fut; _ } : _ t) = fut

(** [Fut.peek] on the fiber's result future. *)
let[@inline] peek ({ res; _ } : _ t) = Fut.peek res

(** [Fut.is_done] on the fiber's result future. *)
let[@inline] is_done ({ res; _ } : _ t) = Fut.is_done res

(** [Fut.is_success] on the fiber's result future. *)
let[@inline] is_success ({ res; _ } : _ t) = Fut.is_success res

(** [Fut.is_failed] on the fiber's result future: any failure of the
    future counts as "cancelled" here. *)
let[@inline] is_cancelled ({ res; _ } : _ t) = Fut.is_failed res

(** [on_result self f] registers [f] on the fiber's result future. *)
let[@inline] on_result ({ res; _ } : _ t) f = Fut.on_result res f

(** [Fut.await] on the fiber's result future. *)
let[@inline] await ({ res; _ } : _ t) = Fut.await res

(** [Fut.wait_block] on the fiber's result future. *)
let[@inline] wait_block ({ res; _ } : _ t) = Fut.wait_block res

(** [Fut.wait_block_exn] on the fiber's result future. *)
let[@inline] wait_block_exn ({ res; _ } : _ t) = Fut.wait_block_exn res
(** Resolve [promise] once [children] are all done.

    [res] holds the outcome that will be used to fulfill [promise]. With no
    children, [promise] is fulfilled right away. Otherwise a callback is
    registered on every child's result future: a failing child overwrites
    [res] with its error, and the child that brings the pending counter
    down to zero fulfills [promise] with the current content of [res]. *)
let resolve_once_children_are_done_ ~children ~promise
    (res : 'a Exn_bt.result A.t) : unit =
  match FM.cardinal children with
  | 0 -> Fut.fulfill promise (A.get res)
  | n_children ->
    (* number of children that have not finished yet *)
    let pending = A.make n_children in
    let on_child_finish (r : _ result) =
      (* make sure the parent fails if any child fails *)
      (match r with
      | Error ebt -> A.set res (Error ebt)
      | Ok _ -> ());
      (* [fetch_and_add] returns the previous value: 1 means we were the
         last child still pending, so we resolve the parent *)
      let before = A.fetch_and_add pending (-1) in
      if before = 1 then Fut.fulfill promise (A.get res)
    in
    FM.iter
      (fun _ (Any child) -> Fut.on_result child.res on_child_finish)
      children
(** [resolve_as_failed_ self ebt] moves [self] from [Alive] to
    [Terminating_or_done (Error ebt)]. The caller that wins the
    compare-and-set then, in this order: fails all children with [ebt],
    runs the registered cancellation callbacks with [ebt], and arranges for
    the result future to be fulfilled once the children are done.

    Does nothing if [self] is already [Terminating_or_done]. *)
let rec resolve_as_failed_ : type a. a t -> Exn_bt.t -> unit =
 fun self ebt ->
  let promise = prom_of_fut self.res in
  let rec attempt () =
    match A.get self.state with
    | Terminating_or_done _ -> ()
    | Alive { children; on_cancel; cancel_id = _ } as old ->
      let new_st = Terminating_or_done (A.make (Error ebt)) in
      if A.compare_and_set self.state old new_st then (
        (* here, unlike in [resolve_ok_], we immediately cancel children *)
        cancel_children_ ~children ebt;
        Int_map.iter (fun _ cb -> cb ebt) on_cancel;
        resolve_once_children_are_done_ ~children ~promise
          (A.make (Error ebt))
      ) else
        (* the state changed under us: look at it again *)
        attempt ()
  in
  attempt ()

(** Cancel eagerly all children *)
and cancel_children_ ebt ~children : unit =
  FM.iter (fun _ (Any child) -> resolve_as_failed_ child ebt) children
type cancel_handle = int
(** Handle identifying a registered cancellation callback. *)

(** [add_on_cancel self cb] registers [cb] to be called if [self] is
    resolved as failed, and returns a handle for {!remove_on_cancel}.

    If [self] is no longer alive, nothing is registered: [cb] is called
    immediately when the recorded outcome is an error, and the returned
    handle is [0]. *)
let add_on_cancel (self : _ t) cb : cancel_handle =
  let rec register () =
    match A.get self.state with
    | Terminating_or_done outcome ->
      (match A.get outcome with
      | Error ebt -> cb ebt
      | Ok _ -> ());
      0
    | Alive { children; cancel_id; on_cancel } as old ->
      let new_st =
        Alive
          {
            children;
            cancel_id = cancel_id + 1;
            on_cancel = Int_map.add cancel_id cb on_cancel;
          }
      in
      if A.compare_and_set self.state old new_st then
        cancel_id
      else
        register ()
  in
  register ()
(** [remove_on_cancel self h] removes the callback registered under handle
    [h], retrying until the update is applied. Does nothing if [self] is
    no longer alive. *)
let remove_on_cancel (self : _ t) h =
  let rec attempt () =
    match A.get self.state with
    | Terminating_or_done _ -> ()
    | Alive ({ on_cancel; _ } as alive) as old ->
      let new_st =
        Alive { alive with on_cancel = Int_map.remove h on_cancel }
      in
      if not (A.compare_and_set self.state old new_st) then attempt ()
  in
  attempt ()
(** [with_on_cancel self cb k] runs [k ()] with [cb] registered as a
    cancellation callback on [self]. The callback is unregistered when [k]
    returns or raises. *)
let with_on_cancel (self : _ t) cb (k : unit -> 'a) : 'a =
  let handle = add_on_cancel self cb in
  let finally () = remove_on_cancel self handle in
  Fun.protect ~finally k
(** Successfully resolve the fiber with [r]. This might still fail if
    some children failed: the result future is only fulfilled once all
    children are done, and a failing child replaces the outcome with its
    error.

    Does nothing if the fiber is already [Terminating_or_done]. *)
let resolve_ok_ (self : 'a t) (r : 'a) : unit =
  let outcome = A.make (Ok r) in
  let promise = prom_of_fut self.res in
  let rec attempt () =
    match A.get self.state with
    | Terminating_or_done _ -> ()
    | Alive { children; _ } as old ->
      if A.compare_and_set self.state old (Terminating_or_done outcome) then
        resolve_once_children_are_done_ ~children ~promise outcome
      else
        attempt ()
  in
  attempt ()
(** [remove_child_ self child] removes [child] (looked up by id) from the
    children of [self], retrying until the update is applied. Does nothing
    if [self] is no longer alive. *)
let remove_child_ (self : _ t) (child : _ t) =
  let rec attempt () =
    match A.get self.state with
    | Terminating_or_done _ -> ()
    | Alive ({ children; _ } as alive) as old ->
      let new_st =
        Alive { alive with children = FM.remove child.id children }
      in
      if not (A.compare_and_set self.state old new_st) then attempt ()
  in
  attempt ()
(** Add a child to [self].

    If [self] is no longer alive, [child] is not attached; it is failed
    immediately when [self]'s recorded outcome is an error, and left alone
    otherwise.

    @param protect if true, the child's failure will not affect [self]. *)
let add_child_ ~protect (self : _ t) (child : _ t) =
  (* make sure to remove [child] from [self.children] once it's done;
     fail [self] if [child] failed and [protect=false] *)
  let on_child_result = function
    | Ok _ -> remove_child_ self child
    | Error ebt ->
      remove_child_ self child;
      (* child failed, we must fail too *)
      if not protect then resolve_as_failed_ self ebt
  in
  let rec attempt () =
    match A.get self.state with
    | Terminating_or_done outcome ->
      (match A.get outcome with
      | Ok _ -> ()
      | Error ebt ->
        (* cancel child immediately *)
        resolve_as_failed_ child ebt)
    | Alive ({ children; _ } as alive) as old ->
      let new_st =
        Alive { alive with children = FM.add child.id (Any child) children }
      in
      if A.compare_and_set self.state old new_st then
        Fut.on_result child.res on_child_result
      else
        attempt ()
  in
  attempt ()
(** [spawn_ ~parent ~runner f] creates a fiber that runs [f ()] on [runner]
    and returns it. The fiber is resolved with the value of [f ()], or as
    failed with the exception (and backtrace) that escaped.

    When [parent] is given, its fiber-local storage is copied into the new
    picos fiber. This function does not register the new fiber as a child;
    the caller does that.

    @raise Failure if [parent] is given and is already closed. *)
let spawn_ ~parent ~runner (f : unit -> 'a) : 'a t =
  let pfiber = PF.create ~forbid:false (Picos.Computation.create ()) in
  (match parent with
  | None -> ()
  | Some (p : _ t) ->
    (* copy local hmap from parent *)
    Fls.Private_hmap_ls_.copy_fls p.pfiber pfiber;
    if is_closed p then failwith "spawn: nursery is closed");
  let fib = create_ ~pfiber ~runner () in
  let run () =
    (* make sure the fiber is accessible from inside itself *)
    FLS.set pfiber k_current_fiber (Any fib);
    (* the handler deliberately covers [resolve_ok_] as well as [f] *)
    try resolve_ok_ fib (f ())
    with exn ->
      let bt = Printexc.get_raw_backtrace () in
      resolve_as_failed_ fib (Exn_bt.make exn bt)
  in
  Runner.run_async ~fiber:pfiber runner run;
  fib
(** [spawn_top ~on f] spawns [f] as a fiber without a parent, on runner
    [on]. *)
let spawn_top ~on f : _ t = spawn_ ~parent:None ~runner:on f
(** [spawn ?on ?protect f] spawns [f ()] as a child of the current fiber,
    with a copy of the current fiber's local storage.

    @param on runner for the child; defaults to the current fiber's runner.
    @param protect if true (the default), a failure of the child does not
      fail the current fiber.
    @raise Failure if the current fiber cannot be found or is closed. *)
let spawn ?on ?(protect = true) f : _ t =
  let (Any parent) =
    match get_cur_exn () with
    | cur -> cur
    | exception Not_set ->
      failwith "Fiber.spawn: must be run from within another fiber."
  in
  let runner = Option.value on ~default:parent.runner in
  let child = spawn_ ~parent:(Some parent) ~runner f in
  add_child_ ~protect parent child;
  child
(** Like {!spawn}, but discards the child fiber. *)
let[@inline] spawn_ignore ?protect f : unit =
  let (_ : _ t) = spawn ?protect f in
  ()
(** [self ()] is the fiber currently running.

    @raise Failure if no current fiber is set. *)
let[@inline] self () : any =
  try get_cur_exn ()
  with Not_set -> failwith "Fiber.self: must be run from inside a fiber."
(** [with_on_self_cancel cb k] is {!with_on_cancel} applied to the current
    fiber: [k ()] runs with [cb] registered as a cancellation callback,
    and [cb] is unregistered when [k] returns or raises.

    @raise Failure if no current fiber is set. *)
let with_on_self_cancel cb (k : unit -> 'a) : 'a =
  let (Any fib) = self () in
  with_on_cancel fib cb k
(** Run [PF.check] on the picos fiber attached to the given fiber.

    NOTE(review): presumably this raises when the picos fiber's computation
    has been cancelled. Nothing in the visible part of this file cancels
    that computation, so confirm where cancellation is propagated to it. *)
let[@inline] check_if_cancelled_ ({ pfiber; _ } : _ t) = PF.check pfiber
(** Run the cancellation check on the current fiber.

    @raise Failure if no current fiber is set. *)
let check_if_cancelled () =
  let (Any fib) =
    try get_cur_exn ()
    with Not_set ->
      failwith "Fiber.check_if_cancelled: must be run from inside a fiber."
  in
  check_if_cancelled_ fib
(** Yield control through [PF.yield], running the cancellation check on
    the current fiber both before and after yielding.

    @raise Failure if no current fiber is set. *)
let yield () : unit =
  let (Any fib) =
    try get_cur_exn ()
    with Not_set -> failwith "Fiber.yield: must be run from inside a fiber."
  in
  check_if_cancelled_ fib;
  PF.yield ();
  check_if_cancelled_ fib