From 39cdc376134f04665c8db30a0ea799e13ce5c10f Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Wed, 28 Feb 2024 22:50:19 -0500 Subject: [PATCH] feat fiber: expose add_on_cancel/remove_on_cancel also make it more robust by using a map --- src/fib/fiber.ml | 65 +++++++++++++++++++++++++++++++---------------- src/fib/fiber.mli | 11 ++++++++ src/fib/fls.mli | 8 ++++++ 3 files changed, 62 insertions(+), 22 deletions(-) diff --git a/src/fib/fiber.ml b/src/fib/fiber.ml index 6d888c8b..ddc47651 100644 --- a/src/fib/fiber.ml +++ b/src/fib/fiber.ml @@ -1,5 +1,6 @@ module A = Atomic module FM = Handle.Map +module Int_map = Map.Make (Int) type 'a callback = 'a Exn_bt.result -> unit (** Callbacks that are called when a fiber is done. *) @@ -20,7 +21,8 @@ module Private_ = struct and 'a state = | Alive of { children: children; - on_cancel: cancel_callback list; + on_cancel: cancel_callback Int_map.t; + cancel_id: int; } | Terminating_or_done of 'a Exn_bt.result A.t @@ -71,12 +73,12 @@ let rec resolve_as_failed_ : type a. a t -> Exn_bt.t -> unit = let promise = prom_of_fut self.res in while match A.get self.state with - | Alive { children; on_cancel } as old -> + | Alive { children; cancel_id = _; on_cancel } as old -> let new_st = Terminating_or_done (A.make @@ Error ebt) in if A.compare_and_set self.state old new_st then ( (* here, unlike in {!resolve_fiber}, we immediately cancel children *) cancel_children_ ~children ebt; - List.iter (fun cb -> cb ebt) on_cancel; + Int_map.iter (fun _ cb -> cb ebt) on_cancel; resolve_once_children_are_done_ ~children ~promise (A.make @@ Error ebt); false ) else @@ -96,7 +98,7 @@ let resolve_ok_ (self : 'a t) (r : 'a) : unit = let promise = prom_of_fut self.res in while match A.get self.state with - | Alive { children; on_cancel = _ } as old -> + | Alive { children; _ } as old -> let new_st = Terminating_or_done r in if A.compare_and_set self.state old new_st then ( resolve_once_children_are_done_ ~children ~promise r; @@ -111,9 +113,9 @@ let resolve_ok_ (self : 'a t) (r : 'a) : unit = let remove_child_ (self : _ t) (child : _ t) = while match A.get self.state with - | Alive { children; on_cancel } as old -> + | Alive ({ children; _ } as alive) as old -> let new_st = - Alive { children = FM.remove child.id children; on_cancel } + Alive { alive with children = FM.remove child.id children } in not (A.compare_and_set self.state old new_st) | _ -> false @@ -126,9 +128,9 @@ let remove_child_ (self : _ t) (child : _ t) = let add_child_ ~protect (self : _ t) (child : _ t) = while match A.get self.state with - | Alive { children; on_cancel } as old -> + | Alive ({ children; _ } as alive) as old -> let new_st = - Alive { children = FM.add child.id (Any child) children; on_cancel } + Alive { alive with children = FM.add child.id (Any child) children } in if A.compare_and_set self.state old new_st then ( @@ -159,7 +161,10 @@ let spawn_ ~ls ~on (f : _ -> 'a) : 'a t = let res, _promise = Fut.make () in let fib = { - state = A.make @@ Alive { children = FM.empty; on_cancel = [] }; + state = + A.make + @@ Alive + { children = FM.empty; on_cancel = Int_map.empty; cancel_id = 0 }; id; res; runner = on; @@ -199,12 +204,26 @@ let spawn_link ~protect f : _ t = add_child_ ~protect parent child; child -let add_cancel_cb_ (self : _ t) cb = +type cancel_handle = int + +let add_on_cancel (self : _ t) cb : cancel_handle = + let h = ref 0 in while match A.get self.state with - | Alive { children; on_cancel } as old -> - let new_st = Alive { children; on_cancel = cb :: on_cancel } in - not (A.compare_and_set self.state old new_st) + | Alive { children; cancel_id; on_cancel } as old -> + let new_st = + Alive + { + children; + cancel_id = cancel_id + 1; + on_cancel = Int_map.add cancel_id cb on_cancel; + } + in + if A.compare_and_set self.state old new_st then ( + h := cancel_id; + false + ) else + true | Terminating_or_done r -> (match A.get r with | Error ebt -> cb ebt @@ -212,14 +231,16 @@ let add_cancel_cb_ (self : _ t) cb = false do () - done + done; + !h -let remove_top_cancel_cb_ (self : _ t) = +let remove_on_cancel (self : _ t) h = while match A.get self.state with - | Alive { on_cancel = []; _ } -> assert false - | Alive { children; on_cancel = _ :: tl } as old -> - let new_st = Alive { children; on_cancel = tl } in + | Alive ({ on_cancel; _ } as alive) as old -> + let new_st = + Alive { alive with on_cancel = Int_map.remove h on_cancel } + in not (A.compare_and_set self.state old new_st) | Terminating_or_done _ -> false do @@ -227,13 +248,13 @@ let remove_top_cancel_cb_ (self : _ t) = done let with_cancel_callback (self : _ t) cb (k : unit -> 'a) : 'a = - add_cancel_cb_ self cb; - Fun.protect k ~finally:(fun () -> remove_top_cancel_cb_ self) + let h = add_on_cancel self cb in + Fun.protect k ~finally:(fun () -> remove_on_cancel self h) let with_self_cancel_callback cb (k : unit -> 'a) : 'a = let (Any self) = self () in - add_cancel_cb_ self cb; - Fun.protect k ~finally:(fun () -> remove_top_cancel_cb_ self) + let h = add_on_cancel self cb in + Fun.protect k ~finally:(fun () -> remove_on_cancel self h) let[@inline] await self = Fut.await self.res let[@inline] wait_block self = Fut.wait_block self.res diff --git a/src/fib/fiber.mli b/src/fib/fiber.mli index ebbe5e96..c1fbdce7 100644 --- a/src/fib/fiber.mli +++ b/src/fib/fiber.mli @@ -81,6 +81,17 @@ val yield : unit -> unit (** Yield control to the scheduler from the current fiber. @raise Failure if not run from inside a fiber. *) +type cancel_handle +(** An opaque handle for a single cancel callback in a fiber *) + +val add_on_cancel : _ t -> cancel_callback -> cancel_handle +(** [add_on_cancel fib cb] adds [cb] to the list of cancel callbacks + for [fib]. If [fib] is already cancelled, [cb] is called immediately. *) + +val remove_on_cancel : _ t -> cancel_handle -> unit +(** [remove_on_cancel fib h] removes the cancel callback + associated with handle [h]. *) + val with_cancel_callback : _ t -> cancel_callback -> (unit -> 'a) -> 'a (** [with_cancel_callback fib cb (fun () -> )] evaluates [e] in a scope in which, if the fiber [fib] is cancelled, diff --git a/src/fib/fls.mli b/src/fib/fls.mli index ccd0d2ee..97bb450f 100644 --- a/src/fib/fls.mli +++ b/src/fib/fls.mli @@ -3,6 +3,14 @@ This storage is associated to the current fiber, just like thread-local storage is associated with the current thread. + + See {!Moonpool.Task_local_storage} for more general information, as + this is based on it. + + {b NOTE}: it's important to note that, while each fiber + has its own storage, spawning a sub-fiber [f2] from a fiber [f1] + will only do a shallow copy of the storage. + Values inside [f1]'s storage will be physically shared with [f2]. *) include module type of struct