refactor: streamline suspend, make most of it 5.0-dependent

This commit is contained in:
Simon Cruanes 2024-02-09 20:56:11 -05:00
parent f7449416e4
commit 712a030206
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
6 changed files with 89 additions and 59 deletions

View file

@ -43,8 +43,9 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
!cur_ls
in
let run_another_task ~name task' =
schedule_ self { f = task'; name; ls = [||] }
let run_another_task ls ~name task' =
let ls' = Array.copy ls in
schedule_ self { f = task'; name; ls = ls' }
in
let run_task (task : task_full) : unit =
@ -52,12 +53,21 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
let _ctx = before_task runner in
cur_span := Tracing_.enter_span task.name;
let resume ~ls k res =
let resume ls k res =
schedule_ self { f = (fun () -> k res); name = task.name; ls }
in
(* run the task now, catching errors, handling effects *)
(try Suspend_.with_suspend task.f ~run:run_another_task ~resume ~on_suspend
(try
[@@@ifge 5.0]
Suspend_.with_suspend (WSH {
run=run_another_task;
resume;
on_suspend;
}) task.f
[@@@else_]
task.f()
[@@@endif]
with e ->
let bt = Printexc.get_raw_backtrace () in
on_exn e bt);

View file

@ -437,11 +437,11 @@ let await (fut : 'a t) : 'a =
Suspend_.suspend
{
Suspend_.handle =
(fun ~ls ~run:_ ~resume k ->
(fun ~run:_ ~resume k ->
on_result fut (function
| Ok _ ->
(* schedule continuation with the same name *)
resume ~ls k (Ok ())
resume k (Ok ())
| Error (exn, bt) ->
(* fail continuation immediately *)
k (Error (exn, bt))));

View file

@ -4,17 +4,17 @@ module A = Atomic_
type suspension = unit Exn_bt.result -> unit
type task = unit -> unit
[@@@ifge 5.0]
type suspension_handler = {
handle:
ls:task_ls ->
run:(name:string -> task -> unit) ->
resume:(ls:task_ls -> suspension -> unit Exn_bt.result -> unit) ->
resume:(suspension -> unit Exn_bt.result -> unit) ->
suspension ->
unit;
}
[@@unboxed]
[@@@ifge 5.0]
[@@@ocaml.alert "-unstable"]
type _ Effect.t +=
@ -24,9 +24,18 @@ type _ Effect.t +=
let[@inline] yield () = Effect.perform Yield
let[@inline] suspend h = Effect.perform (Suspend h)
let with_suspend ~on_suspend ~(run : name:string -> task -> unit)
~(resume : ls:task_ls -> suspension -> unit Exn_bt.result -> unit)
(f : unit -> unit) : unit =
(** Callbacks given to [with_suspend]. The ['state] type is existential:
it is produced by [on_suspend] and handed back to [run] and [resume]. *)
type with_suspend_handler =
| WSH : {
on_suspend: unit -> 'state;
(** Called when [f()] suspends itself. The returned state is passed to [run] and [resume]. *)
run: 'state -> name:string -> task -> unit;
(** Used to schedule new tasks. *)
resume: 'state -> suspension -> unit Exn_bt.result -> unit;
(** Used to resume the suspension. Must be called exactly once. *)
}
-> with_suspend_handler
let with_suspend (WSH { on_suspend; run; resume }) (f : unit -> unit) : unit =
let module E = Effect.Deep in
(* effect handler *)
let effc : type e. e Effect.t -> ((e, _) E.continuation -> _) option =
@ -35,22 +44,22 @@ let with_suspend ~on_suspend ~(run : name:string -> task -> unit)
(* TODO: discontinue [k] if current fiber (if any) is cancelled? *)
Some
(fun k ->
let ls = on_suspend () in
let state = on_suspend () in
let k' : suspension = function
| Ok () -> E.continue k ()
| Error (exn, bt) -> E.discontinue_with_backtrace k exn bt
in
h.handle ~ls ~run ~resume k')
h.handle ~run:(run state) ~resume:(resume state) k')
| Yield ->
(* TODO: discontinue [k] if current fiber (if any) is cancelled? *)
Some
(fun k ->
let ls = on_suspend () in
let state = on_suspend () in
let k' : suspension = function
| Ok () -> E.continue k ()
| Error (exn, bt) -> E.discontinue_with_backtrace k exn bt
in
resume ~ls k' (Ok ()))
resume state k' @@ Ok ())
| _ -> None
in
@ -59,15 +68,14 @@ let with_suspend ~on_suspend ~(run : name:string -> task -> unit)
(* DLA interop *)
let prepare_for_await () : Dla_.t =
(* current state *)
let st : (_ * _ * suspension) option A.t = A.make None in
let st : (_ * suspension) option A.t = A.make None in
let release () : unit =
match A.exchange st None with
| None -> ()
| Some (ls, resume, k) -> resume ~ls k @@ Ok ()
| Some (resume, k) -> resume k @@ Ok ()
and await () : unit =
suspend
{ handle = (fun ~ls ~run:_ ~resume k -> A.set st (Some (ls, resume, k))) }
suspend { handle = (fun ~run:_ ~resume k -> A.set st (Some (resume, k))) }
in
let t = { Dla_.release; await } in

View file

@ -8,13 +8,14 @@ open Types_
type suspension = unit Exn_bt.result -> unit
(** A suspended computation *)
[@@@ifge 5.0]
type task = unit -> unit
type suspension_handler = {
handle:
ls:task_ls ->
run:(name:string -> task -> unit) ->
resume:(ls:task_ls -> suspension -> unit Exn_bt.result -> unit) ->
resume:(suspension -> unit Exn_bt.result -> unit) ->
suspension ->
unit;
}
@ -28,6 +29,8 @@ type suspension_handler = {
eventually);
- a [run] function that can be used to start tasks to perform some
computation.
- a [resume] function to resume the suspended computation. This
must be called exactly once, in all situations.
This means that a fork-join primitive, for example, can use a single call
to {!suspend} to:
@ -37,9 +40,9 @@ type suspension_handler = {
runs in parallel with the other calls. The calls must coordinate so
that, once they are all done, the suspended caller is resumed with the
aggregated result of the computation.
- use [resume] exactly once to resume the suspended caller.
*)
[@@@ifge 5.0]
[@@@ocaml.alert "-unstable"]
type _ Effect.t +=
@ -63,30 +66,29 @@ val suspend : suspension_handler -> unit
and a task runner function.
*)
(** Callbacks given to {!with_suspend}. The ['state] type is existential:
it is produced by [on_suspend] and handed back to [run] and [resume]. *)
type with_suspend_handler =
| WSH : {
on_suspend: unit -> 'state;
(** Called when [f()] suspends itself. The returned state is passed to [run] and [resume]. *)
run: 'state -> name:string -> task -> unit;
(** Used to schedule new tasks. *)
resume: 'state -> suspension -> unit Exn_bt.result -> unit;
(** Used to resume the suspension. Must be called exactly once. *)
}
-> with_suspend_handler
val with_suspend : with_suspend_handler -> (unit -> unit) -> unit
(** [with_suspend wsh f]
runs [f()] in an environment where [suspend] will work.
If [f()] suspends with suspension handler [h],
this calls [wsh.on_suspend()] to capture the current state [st].
Then [h.handle ~run:(wsh.run st) ~resume:(wsh.resume st) k] is called,
where [k] is the suspension.
The suspension should always be passed exactly once to
[resume]. [run] should be used to start other tasks.
*)
[@@@endif]
val prepare_for_await : unit -> Dla_.t
(** Our stub for DLA. Unstable. *)
val with_suspend :
on_suspend:(unit -> task_ls) ->
run:(name:string -> task -> unit) ->
resume:(ls:task_ls -> suspension -> unit Exn_bt.result -> unit) ->
(unit -> unit) ->
unit
(** [with_suspend ~name ~on_suspend ~run ~resume f]
runs [f()] in an environment where [suspend]
will work (on OCaml 5) or do nothing (on OCaml 4.xx).
If [f()] suspends with suspension handler [h],
this calls [h ~run ~resume k] where [k] is the suspension.
The suspension should always be passed exactly once to
[resume]. [run] should be used to start other tasks.
@param on_suspend called when [f()] suspends itself.
@param name used for tracing, if not [""].
@param run used to schedule new tasks
@param resume run the suspension. Must be called exactly once.
This will not do anything on OCaml 4.x.
*)

View file

@ -20,6 +20,8 @@ type task_full = {
ls: task_ls;
}
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
type worker_state = {
pool_id_: Id.t; (** Unique per pool *)
mutable thread: Thread.t;
@ -32,8 +34,6 @@ type worker_state = {
allowed to push into the queue, but other workers
can come and steal from it if they're idle. *)
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
type state = {
id_: Id.t;
active: bool A.t; (** Becomes [false] when the pool is shutdown. *)
@ -125,12 +125,13 @@ let run_task_now_ (self : state) ~runner (w : worker_state) ~name ~ls task :
!(w.cur_ls)
in
let run_another_task ~name task' =
let run_another_task ls ~name task' =
let w = find_current_worker_ () in
schedule_task_ self w ~name ~ls:[||] task'
let ls' = Array.copy ls in
schedule_task_ self w ~name ~ls:ls' task'
in
let resume ~ls k r =
let resume ls k r =
let w = find_current_worker_ () in
schedule_task_ self w ~name ~ls (fun () -> k r)
in
@ -138,10 +139,19 @@ let run_task_now_ (self : state) ~runner (w : worker_state) ~name ~ls task :
(* run the task now, catching errors *)
(try
(* run [task()] and handle [suspend] in it *)
Suspend_.with_suspend task ~run:run_another_task ~resume ~on_suspend
[@@@ifge 5.0]
Suspend_.with_suspend (WSH {
on_suspend;
run=run_another_task;
resume;
}) task
[@@@else_]
task ()
[@@@endif]
with e ->
let bt = Printexc.get_raw_backtrace () in
self.on_exn e bt);
exit_span_ ();
after_task runner _ctx;
w.cur_ls := [||]

View file

@ -48,7 +48,7 @@ module State_ = struct
Suspend_.suspend
{
Suspend_.handle =
(fun ~ls ~run:_ ~resume suspension ->
(fun ~run:_ ~resume suspension ->
while
let old_st = A.get self in
match old_st with
@ -59,7 +59,7 @@ module State_ = struct
| Left_solved left ->
(* other thread is done, no risk of race condition *)
A.set self (Both_solved (left, right));
resume ~ls suspension (Ok ());
resume suspension (Ok ());
false
| Right_solved _ | Both_solved _ -> assert false
do
@ -113,19 +113,19 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit =
max 1 (1 + (n / Moonpool.Private.num_domains ()))
in
let start_tasks ~ls ~run ~resume (suspension : Suspend_.suspension) =
let start_tasks ~run ~resume (suspension : Suspend_.suspension) =
let task_for ~offset ~len_range =
match f offset (offset + len_range - 1) with
| () ->
if A.fetch_and_add missing (-len_range) = len_range then
(* all tasks done successfully *)
resume ~ls suspension (Ok ())
resume suspension (Ok ())
| exception exn ->
let bt = Printexc.get_raw_backtrace () in
if not (A.exchange has_failed true) then
(* first one to fail, and [missing] must be >= 2
because we're not decreasing it. *)
resume ~ls suspension (Error (exn, bt))
resume suspension (Error (exn, bt))
in
let i = ref 0 in
@ -143,9 +143,9 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit =
Suspend_.suspend
{
Suspend_.handle =
(fun ~ls ~run ~resume suspension ->
(fun ~run ~resume suspension ->
(* run tasks, then we'll resume [suspension] *)
start_tasks ~run ~ls ~resume suspension);
start_tasks ~run ~resume suspension);
}
)