From 712a030206f9b0913de3582d4740726f7591815a Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Fri, 9 Feb 2024 20:56:11 -0500 Subject: [PATCH] refactor: streamline `suspend`, make most of it 5.0-dependent --- src/core/fifo_pool.ml | 18 ++++++++--- src/core/fut.ml | 4 +-- src/core/suspend_.ml | 36 +++++++++++++-------- src/core/suspend_.mli | 54 ++++++++++++++++--------------- src/core/ws_pool.ml | 22 +++++++++---- src/forkjoin/moonpool_forkjoin.ml | 14 ++++---- 6 files changed, 89 insertions(+), 59 deletions(-) diff --git a/src/core/fifo_pool.ml b/src/core/fifo_pool.ml index 6f7b3700..1095d16e 100644 --- a/src/core/fifo_pool.ml +++ b/src/core/fifo_pool.ml @@ -43,8 +43,9 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit = !cur_ls in - let run_another_task ~name task' = - schedule_ self { f = task'; name; ls = [||] } + let run_another_task ls ~name task' = + let ls' = Array.copy ls in + schedule_ self { f = task'; name; ls = ls' } in let run_task (task : task_full) : unit = @@ -52,12 +53,21 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit = let _ctx = before_task runner in cur_span := Tracing_.enter_span task.name; - let resume ~ls k res = + let resume ls k res = schedule_ self { f = (fun () -> k res); name = task.name; ls } in (* run the task now, catching errors, handling effects *) - (try Suspend_.with_suspend task.f ~run:run_another_task ~resume ~on_suspend + (try +[@@@ifge 5.0] + Suspend_.with_suspend (WSH { + run=run_another_task; + resume; + on_suspend; + }) task.f +[@@@else_] + task.f() +[@@@endif] with e -> let bt = Printexc.get_raw_backtrace () in on_exn e bt); diff --git a/src/core/fut.ml b/src/core/fut.ml index afe7dc39..6d8d264d 100644 --- a/src/core/fut.ml +++ b/src/core/fut.ml @@ -437,11 +437,11 @@ let await (fut : 'a t) : 'a = Suspend_.suspend { Suspend_.handle = - (fun ~ls ~run:_ ~resume k -> + (fun ~run:_ ~resume k -> on_result fut (function | Ok _ -> (* schedule continuation with the same name *) - resume ~ls k (Ok ()) + resume k (Ok ()) | Error (exn, bt) -> (* fail continuation immediately *) k (Error (exn, bt)))); diff --git a/src/core/suspend_.ml b/src/core/suspend_.ml index fb02bc3a..cb4293a8 100644 --- a/src/core/suspend_.ml +++ b/src/core/suspend_.ml @@ -4,17 +4,17 @@ module A = Atomic_ type suspension = unit Exn_bt.result -> unit type task = unit -> unit +[@@@ifge 5.0] + type suspension_handler = { handle: - ls:task_ls -> run:(name:string -> task -> unit) -> - resume:(ls:task_ls -> suspension -> unit Exn_bt.result -> unit) -> + resume:(suspension -> unit Exn_bt.result -> unit) -> suspension -> unit; } [@@unboxed] -[@@@ifge 5.0] [@@@ocaml.alert "-unstable"] type _ Effect.t += @@ -24,9 +24,18 @@ type _ Effect.t += let[@inline] yield () = Effect.perform Yield let[@inline] suspend h = Effect.perform (Suspend h) -let with_suspend ~on_suspend ~(run : name:string -> task -> unit) - ~(resume : ls:task_ls -> suspension -> unit Exn_bt.result -> unit) - (f : unit -> unit) : unit = +type with_suspend_handler = + | WSH : { + on_suspend: unit -> 'state; + (** on_suspend called when [f()] suspends itself. *) + run: 'state -> name:string -> task -> unit; + (** run used to schedule new tasks *) + resume: 'state -> suspension -> unit Exn_bt.result -> unit; + (** resume run the suspension. Must be called exactly once. *) + } + -> with_suspend_handler + +let with_suspend (WSH { on_suspend; run; resume }) (f : unit -> unit) : unit = let module E = Effect.Deep in (* effect handler *) let effc : type e. e Effect.t -> ((e, _) E.continuation -> _) option = @@ -35,22 +44,22 @@ let with_suspend ~on_suspend ~(run : name:string -> task -> unit) (* TODO: discontinue [k] if current fiber (if any) is cancelled? *) Some (fun k -> - let ls = on_suspend () in + let state = on_suspend () in let k' : suspension = function | Ok () -> E.continue k () | Error (exn, bt) -> E.discontinue_with_backtrace k exn bt in - h.handle ~ls ~run ~resume k') + h.handle ~run:(run state) ~resume:(resume state) k') | Yield -> (* TODO: discontinue [k] if current fiber (if any) is cancelled? *) Some (fun k -> - let ls = on_suspend () in + let state = on_suspend () in let k' : suspension = function | Ok () -> E.continue k () | Error (exn, bt) -> E.discontinue_with_backtrace k exn bt in - resume ~ls k' (Ok ())) + resume state k' @@ Ok ()) | _ -> None in @@ -59,15 +68,14 @@ let with_suspend ~on_suspend ~(run : name:string -> task -> unit) (* DLA interop *) let prepare_for_await () : Dla_.t = (* current state *) - let st : (_ * _ * suspension) option A.t = A.make None in + let st : (_ * suspension) option A.t = A.make None in let release () : unit = match A.exchange st None with | None -> () - | Some (ls, resume, k) -> resume ~ls k @@ Ok () + | Some (resume, k) -> resume k @@ Ok () and await () : unit = - suspend - { handle = (fun ~ls ~run:_ ~resume k -> A.set st (Some (ls, resume, k))) } + suspend { handle = (fun ~run:_ ~resume k -> A.set st (Some (resume, k))) } in let t = { Dla_.release; await } in diff --git a/src/core/suspend_.mli b/src/core/suspend_.mli index bd922f41..1fff43ac 100644 --- a/src/core/suspend_.mli +++ b/src/core/suspend_.mli @@ -8,13 +8,14 @@ open Types_ type suspension = unit Exn_bt.result -> unit (** A suspended computation *) +[@@@ifge 5.0] + type task = unit -> unit type suspension_handler = { handle: - ls:task_ls -> run:(name:string -> task -> unit) -> - resume:(ls:task_ls -> suspension -> unit Exn_bt.result -> unit) -> + resume:(suspension -> unit Exn_bt.result -> unit) -> suspension -> unit; } @@ -28,6 +29,8 @@ type suspension_handler = { eventually); - a [run] function that can be used to start tasks to perform some computation. + - a [resume] function to resume the suspended computation. This + must be called exactly once, in all situations. This means that a fork-join primitive, for example, can use a single call to {!suspend} to: @@ -37,9 +40,9 @@ type suspension_handler = { runs in parallel with the other calls. The calls must coordinate so that, once they are all done, the suspended caller is resumed with the aggregated result of the computation. + - use [resume] exactly *) -[@@@ifge 5.0] [@@@ocaml.alert "-unstable"] type _ Effect.t += @@ -63,30 +66,29 @@ val suspend : suspension_handler -> unit and a task runner function. *) +type with_suspend_handler = + | WSH : { + on_suspend: unit -> 'state; + (** on_suspend called when [f()] suspends itself. *) + run: 'state -> name:string -> task -> unit; + (** run used to schedule new tasks *) + resume: 'state -> suspension -> unit Exn_bt.result -> unit; + (** resume run the suspension. Must be called exactly once. *) + } + -> with_suspend_handler + +val with_suspend : with_suspend_handler -> (unit -> unit) -> unit +(** [with_suspend wsh f] + runs [f()] in an environment where [suspend] will work. + + If [f()] suspends with suspension handler [h], + this calls [wsh.on_suspend()] to capture the current state [st]. + Then [h.handle ~st ~run ~resume k] is called, where [k] is the suspension. + The suspension should always be passed exactly once to + [resume]. [run] should be used to start other tasks. +*) + [@@@endif] val prepare_for_await : unit -> Dla_.t (** Our stub for DLA. Unstable. *) - -val with_suspend : - on_suspend:(unit -> task_ls) -> - run:(name:string -> task -> unit) -> - resume:(ls:task_ls -> suspension -> unit Exn_bt.result -> unit) -> - (unit -> unit) -> - unit -(** [with_suspend ~name ~on_suspend ~run ~resume f] - runs [f()] in an environment where [suspend] - will work (on OCaml 5) or do nothing (on OCaml 4.xx). - - If [f()] suspends with suspension handler [h], - this calls [h ~run ~resume k] where [k] is the suspension. - The suspension should always be passed exactly once to - [resume]. [run] should be used to start other tasks. - - @param on_suspend called when [f()] suspends itself. - @param name used for tracing, if not [""]. - @param run used to schedule new tasks - @param resume run the suspension. Must be called exactly once. - - This will not do anything on OCaml 4.x. -*) diff --git a/src/core/ws_pool.ml b/src/core/ws_pool.ml index 6da8e31a..bb84fe75 100644 --- a/src/core/ws_pool.ml +++ b/src/core/ws_pool.ml @@ -20,6 +20,8 @@ type task_full = { ls: task_ls; } +type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task + type worker_state = { pool_id_: Id.t; (** Unique per pool *) mutable thread: Thread.t; @@ -32,8 +34,6 @@ type worker_state = { allowed to push into the queue, but other workers can come and steal from it if they're idle. *) -type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task - type state = { id_: Id.t; active: bool A.t; (** Becomes [false] when the pool is shutdown. *) @@ -125,12 +125,13 @@ let run_task_now_ (self : state) ~runner (w : worker_state) ~name ~ls task : !(w.cur_ls) in - let run_another_task ~name task' = + let run_another_task ls ~name task' = let w = find_current_worker_ () in - schedule_task_ self w ~name ~ls:[||] task' + let ls' = Array.copy ls in + schedule_task_ self w ~name ~ls:ls' task' in - let resume ~ls k r = + let resume ls k r = let w = find_current_worker_ () in schedule_task_ self w ~name ~ls (fun () -> k r) in @@ -138,10 +139,19 @@ let run_task_now_ (self : state) ~runner (w : worker_state) ~name ~ls task : (* run the task now, catching errors *) (try (* run [task()] and handle [suspend] in it *) - Suspend_.with_suspend task ~run:run_another_task ~resume ~on_suspend +[@@@ifge 5.0] + Suspend_.with_suspend (WSH { + on_suspend; + run=run_another_task; + resume; + }) task +[@@@else_] + task () +[@@@endif] with e -> let bt = Printexc.get_raw_backtrace () in self.on_exn e bt); + exit_span_ (); after_task runner _ctx; w.cur_ls := [||] diff --git a/src/forkjoin/moonpool_forkjoin.ml b/src/forkjoin/moonpool_forkjoin.ml index 01b7a7c2..27aa1984 100644 --- a/src/forkjoin/moonpool_forkjoin.ml +++ b/src/forkjoin/moonpool_forkjoin.ml @@ -48,7 +48,7 @@ module State_ = struct Suspend_.suspend { Suspend_.handle = - (fun ~ls ~run:_ ~resume suspension -> + (fun ~run:_ ~resume suspension -> while let old_st = A.get self in match old_st with @@ -59,7 +59,7 @@ module State_ = struct | Left_solved left -> (* other thread is done, no risk of race condition *) A.set self (Both_solved (left, right)); - resume ~ls suspension (Ok ()); + resume suspension (Ok ()); false | Right_solved _ | Both_solved _ -> assert false do @@ -113,19 +113,19 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit = max 1 (1 + (n / Moonpool.Private.num_domains ())) in - let start_tasks ~ls ~run ~resume (suspension : Suspend_.suspension) = + let start_tasks ~run ~resume (suspension : Suspend_.suspension) = let task_for ~offset ~len_range = match f offset (offset + len_range - 1) with | () -> if A.fetch_and_add missing (-len_range) = len_range then (* all tasks done successfully *) - resume ~ls suspension (Ok ()) + resume suspension (Ok ()) | exception exn -> let bt = Printexc.get_raw_backtrace () in if not (A.exchange has_failed true) then (* first one to fail, and [missing] must be >= 2 because we're not decreasing it. *) - resume ~ls suspension (Error (exn, bt)) + resume suspension (Error (exn, bt)) in let i = ref 0 in @@ -143,9 +143,9 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit = Suspend_.suspend { Suspend_.handle = - (fun ~ls ~run ~resume suspension -> + (fun ~run ~resume suspension -> (* run tasks, then we'll resume [suspension] *) - start_tasks ~run ~ls ~resume suspension); + start_tasks ~run ~resume suspension); } )