refactor: streamline suspend, make most of it 5.0-dependent

This commit is contained in:
Simon Cruanes 2024-02-09 20:56:11 -05:00
parent f7449416e4
commit 712a030206
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
6 changed files with 89 additions and 59 deletions

View file

@ -43,8 +43,9 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
!cur_ls !cur_ls
in in
let run_another_task ~name task' = let run_another_task ls ~name task' =
schedule_ self { f = task'; name; ls = [||] } let ls' = Array.copy ls in
schedule_ self { f = task'; name; ls = ls' }
in in
let run_task (task : task_full) : unit = let run_task (task : task_full) : unit =
@ -52,12 +53,21 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
let _ctx = before_task runner in let _ctx = before_task runner in
cur_span := Tracing_.enter_span task.name; cur_span := Tracing_.enter_span task.name;
let resume ~ls k res = let resume ls k res =
schedule_ self { f = (fun () -> k res); name = task.name; ls } schedule_ self { f = (fun () -> k res); name = task.name; ls }
in in
(* run the task now, catching errors, handling effects *) (* run the task now, catching errors, handling effects *)
(try Suspend_.with_suspend task.f ~run:run_another_task ~resume ~on_suspend (try
[@@@ifge 5.0]
Suspend_.with_suspend (WSH {
run=run_another_task;
resume;
on_suspend;
}) task.f
[@@@else_]
task.f()
[@@@endif]
with e -> with e ->
let bt = Printexc.get_raw_backtrace () in let bt = Printexc.get_raw_backtrace () in
on_exn e bt); on_exn e bt);

View file

@ -437,11 +437,11 @@ let await (fut : 'a t) : 'a =
Suspend_.suspend Suspend_.suspend
{ {
Suspend_.handle = Suspend_.handle =
(fun ~ls ~run:_ ~resume k -> (fun ~run:_ ~resume k ->
on_result fut (function on_result fut (function
| Ok _ -> | Ok _ ->
(* schedule continuation with the same name *) (* schedule continuation with the same name *)
resume ~ls k (Ok ()) resume k (Ok ())
| Error (exn, bt) -> | Error (exn, bt) ->
(* fail continuation immediately *) (* fail continuation immediately *)
k (Error (exn, bt)))); k (Error (exn, bt))));

View file

@ -4,17 +4,17 @@ module A = Atomic_
type suspension = unit Exn_bt.result -> unit type suspension = unit Exn_bt.result -> unit
type task = unit -> unit type task = unit -> unit
[@@@ifge 5.0]
type suspension_handler = { type suspension_handler = {
handle: handle:
ls:task_ls ->
run:(name:string -> task -> unit) -> run:(name:string -> task -> unit) ->
resume:(ls:task_ls -> suspension -> unit Exn_bt.result -> unit) -> resume:(suspension -> unit Exn_bt.result -> unit) ->
suspension -> suspension ->
unit; unit;
} }
[@@unboxed] [@@unboxed]
[@@@ifge 5.0]
[@@@ocaml.alert "-unstable"] [@@@ocaml.alert "-unstable"]
type _ Effect.t += type _ Effect.t +=
@ -24,9 +24,18 @@ type _ Effect.t +=
let[@inline] yield () = Effect.perform Yield let[@inline] yield () = Effect.perform Yield
let[@inline] suspend h = Effect.perform (Suspend h) let[@inline] suspend h = Effect.perform (Suspend h)
let with_suspend ~on_suspend ~(run : name:string -> task -> unit) type with_suspend_handler =
~(resume : ls:task_ls -> suspension -> unit Exn_bt.result -> unit) | WSH : {
(f : unit -> unit) : unit = on_suspend: unit -> 'state;
(** on_suspend called when [f()] suspends itself. *)
run: 'state -> name:string -> task -> unit;
(** run used to schedule new tasks *)
resume: 'state -> suspension -> unit Exn_bt.result -> unit;
(** resume run the suspension. Must be called exactly once. *)
}
-> with_suspend_handler
let with_suspend (WSH { on_suspend; run; resume }) (f : unit -> unit) : unit =
let module E = Effect.Deep in let module E = Effect.Deep in
(* effect handler *) (* effect handler *)
let effc : type e. e Effect.t -> ((e, _) E.continuation -> _) option = let effc : type e. e Effect.t -> ((e, _) E.continuation -> _) option =
@ -35,22 +44,22 @@ let with_suspend ~on_suspend ~(run : name:string -> task -> unit)
(* TODO: discontinue [k] if current fiber (if any) is cancelled? *) (* TODO: discontinue [k] if current fiber (if any) is cancelled? *)
Some Some
(fun k -> (fun k ->
let ls = on_suspend () in let state = on_suspend () in
let k' : suspension = function let k' : suspension = function
| Ok () -> E.continue k () | Ok () -> E.continue k ()
| Error (exn, bt) -> E.discontinue_with_backtrace k exn bt | Error (exn, bt) -> E.discontinue_with_backtrace k exn bt
in in
h.handle ~ls ~run ~resume k') h.handle ~run:(run state) ~resume:(resume state) k')
| Yield -> | Yield ->
(* TODO: discontinue [k] if current fiber (if any) is cancelled? *) (* TODO: discontinue [k] if current fiber (if any) is cancelled? *)
Some Some
(fun k -> (fun k ->
let ls = on_suspend () in let state = on_suspend () in
let k' : suspension = function let k' : suspension = function
| Ok () -> E.continue k () | Ok () -> E.continue k ()
| Error (exn, bt) -> E.discontinue_with_backtrace k exn bt | Error (exn, bt) -> E.discontinue_with_backtrace k exn bt
in in
resume ~ls k' (Ok ())) resume state k' @@ Ok ())
| _ -> None | _ -> None
in in
@ -59,15 +68,14 @@ let with_suspend ~on_suspend ~(run : name:string -> task -> unit)
(* DLA interop *) (* DLA interop *)
let prepare_for_await () : Dla_.t = let prepare_for_await () : Dla_.t =
(* current state *) (* current state *)
let st : (_ * _ * suspension) option A.t = A.make None in let st : (_ * suspension) option A.t = A.make None in
let release () : unit = let release () : unit =
match A.exchange st None with match A.exchange st None with
| None -> () | None -> ()
| Some (ls, resume, k) -> resume ~ls k @@ Ok () | Some (resume, k) -> resume k @@ Ok ()
and await () : unit = and await () : unit =
suspend suspend { handle = (fun ~run:_ ~resume k -> A.set st (Some (resume, k))) }
{ handle = (fun ~ls ~run:_ ~resume k -> A.set st (Some (ls, resume, k))) }
in in
let t = { Dla_.release; await } in let t = { Dla_.release; await } in

View file

@ -8,13 +8,14 @@ open Types_
type suspension = unit Exn_bt.result -> unit type suspension = unit Exn_bt.result -> unit
(** A suspended computation *) (** A suspended computation *)
[@@@ifge 5.0]
type task = unit -> unit type task = unit -> unit
type suspension_handler = { type suspension_handler = {
handle: handle:
ls:task_ls ->
run:(name:string -> task -> unit) -> run:(name:string -> task -> unit) ->
resume:(ls:task_ls -> suspension -> unit Exn_bt.result -> unit) -> resume:(suspension -> unit Exn_bt.result -> unit) ->
suspension -> suspension ->
unit; unit;
} }
@ -28,6 +29,8 @@ type suspension_handler = {
eventually); eventually);
- a [run] function that can be used to start tasks to perform some - a [run] function that can be used to start tasks to perform some
computation. computation.
- a [resume] function to resume the suspended computation. This
must be called exactly once, in all situations.
This means that a fork-join primitive, for example, can use a single call This means that a fork-join primitive, for example, can use a single call
to {!suspend} to: to {!suspend} to:
@ -37,9 +40,9 @@ type suspension_handler = {
runs in parallel with the other calls. The calls must coordinate so runs in parallel with the other calls. The calls must coordinate so
that, once they are all done, the suspended caller is resumed with the that, once they are all done, the suspended caller is resumed with the
aggregated result of the computation. aggregated result of the computation.
- use [resume] exactly
*) *)
[@@@ifge 5.0]
[@@@ocaml.alert "-unstable"] [@@@ocaml.alert "-unstable"]
type _ Effect.t += type _ Effect.t +=
@ -63,30 +66,29 @@ val suspend : suspension_handler -> unit
and a task runner function. and a task runner function.
*) *)
type with_suspend_handler =
| WSH : {
on_suspend: unit -> 'state;
(** on_suspend called when [f()] suspends itself. *)
run: 'state -> name:string -> task -> unit;
(** run used to schedule new tasks *)
resume: 'state -> suspension -> unit Exn_bt.result -> unit;
(** resume run the suspension. Must be called exactly once. *)
}
-> with_suspend_handler
val with_suspend : with_suspend_handler -> (unit -> unit) -> unit
(** [with_suspend wsh f]
runs [f()] in an environment where [suspend] will work.
If [f()] suspends with suspension handler [h],
this calls [wsh.on_suspend()] to capture the current state [st].
Then [h.handle ~st ~run ~resume k] is called, where [k] is the suspension.
The suspension should always be passed exactly once to
[resume]. [run] should be used to start other tasks.
*)
[@@@endif] [@@@endif]
val prepare_for_await : unit -> Dla_.t val prepare_for_await : unit -> Dla_.t
(** Our stub for DLA. Unstable. *) (** Our stub for DLA. Unstable. *)
val with_suspend :
on_suspend:(unit -> task_ls) ->
run:(name:string -> task -> unit) ->
resume:(ls:task_ls -> suspension -> unit Exn_bt.result -> unit) ->
(unit -> unit) ->
unit
(** [with_suspend ~name ~on_suspend ~run ~resume f]
runs [f()] in an environment where [suspend]
will work (on OCaml 5) or do nothing (on OCaml 4.xx).
If [f()] suspends with suspension handler [h],
this calls [h ~run ~resume k] where [k] is the suspension.
The suspension should always be passed exactly once to
[resume]. [run] should be used to start other tasks.
@param on_suspend called when [f()] suspends itself.
@param name used for tracing, if not [""].
@param run used to schedule new tasks
@param resume run the suspension. Must be called exactly once.
This will not do anything on OCaml 4.x.
*)

View file

@ -20,6 +20,8 @@ type task_full = {
ls: task_ls; ls: task_ls;
} }
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
type worker_state = { type worker_state = {
pool_id_: Id.t; (** Unique per pool *) pool_id_: Id.t; (** Unique per pool *)
mutable thread: Thread.t; mutable thread: Thread.t;
@ -32,8 +34,6 @@ type worker_state = {
allowed to push into the queue, but other workers allowed to push into the queue, but other workers
can come and steal from it if they're idle. *) can come and steal from it if they're idle. *)
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
type state = { type state = {
id_: Id.t; id_: Id.t;
active: bool A.t; (** Becomes [false] when the pool is shutdown. *) active: bool A.t; (** Becomes [false] when the pool is shutdown. *)
@ -125,12 +125,13 @@ let run_task_now_ (self : state) ~runner (w : worker_state) ~name ~ls task :
!(w.cur_ls) !(w.cur_ls)
in in
let run_another_task ~name task' = let run_another_task ls ~name task' =
let w = find_current_worker_ () in let w = find_current_worker_ () in
schedule_task_ self w ~name ~ls:[||] task' let ls' = Array.copy ls in
schedule_task_ self w ~name ~ls:ls' task'
in in
let resume ~ls k r = let resume ls k r =
let w = find_current_worker_ () in let w = find_current_worker_ () in
schedule_task_ self w ~name ~ls (fun () -> k r) schedule_task_ self w ~name ~ls (fun () -> k r)
in in
@ -138,10 +139,19 @@ let run_task_now_ (self : state) ~runner (w : worker_state) ~name ~ls task :
(* run the task now, catching errors *) (* run the task now, catching errors *)
(try (try
(* run [task()] and handle [suspend] in it *) (* run [task()] and handle [suspend] in it *)
Suspend_.with_suspend task ~run:run_another_task ~resume ~on_suspend [@@@ifge 5.0]
Suspend_.with_suspend (WSH {
on_suspend;
run=run_another_task;
resume;
}) task
[@@@else_]
task ()
[@@@endif]
with e -> with e ->
let bt = Printexc.get_raw_backtrace () in let bt = Printexc.get_raw_backtrace () in
self.on_exn e bt); self.on_exn e bt);
exit_span_ (); exit_span_ ();
after_task runner _ctx; after_task runner _ctx;
w.cur_ls := [||] w.cur_ls := [||]

View file

@ -48,7 +48,7 @@ module State_ = struct
Suspend_.suspend Suspend_.suspend
{ {
Suspend_.handle = Suspend_.handle =
(fun ~ls ~run:_ ~resume suspension -> (fun ~run:_ ~resume suspension ->
while while
let old_st = A.get self in let old_st = A.get self in
match old_st with match old_st with
@ -59,7 +59,7 @@ module State_ = struct
| Left_solved left -> | Left_solved left ->
(* other thread is done, no risk of race condition *) (* other thread is done, no risk of race condition *)
A.set self (Both_solved (left, right)); A.set self (Both_solved (left, right));
resume ~ls suspension (Ok ()); resume suspension (Ok ());
false false
| Right_solved _ | Both_solved _ -> assert false | Right_solved _ | Both_solved _ -> assert false
do do
@ -113,19 +113,19 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit =
max 1 (1 + (n / Moonpool.Private.num_domains ())) max 1 (1 + (n / Moonpool.Private.num_domains ()))
in in
let start_tasks ~ls ~run ~resume (suspension : Suspend_.suspension) = let start_tasks ~run ~resume (suspension : Suspend_.suspension) =
let task_for ~offset ~len_range = let task_for ~offset ~len_range =
match f offset (offset + len_range - 1) with match f offset (offset + len_range - 1) with
| () -> | () ->
if A.fetch_and_add missing (-len_range) = len_range then if A.fetch_and_add missing (-len_range) = len_range then
(* all tasks done successfully *) (* all tasks done successfully *)
resume ~ls suspension (Ok ()) resume suspension (Ok ())
| exception exn -> | exception exn ->
let bt = Printexc.get_raw_backtrace () in let bt = Printexc.get_raw_backtrace () in
if not (A.exchange has_failed true) then if not (A.exchange has_failed true) then
(* first one to fail, and [missing] must be >= 2 (* first one to fail, and [missing] must be >= 2
because we're not decreasing it. *) because we're not decreasing it. *)
resume ~ls suspension (Error (exn, bt)) resume suspension (Error (exn, bt))
in in
let i = ref 0 in let i = ref 0 in
@ -143,9 +143,9 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit =
Suspend_.suspend Suspend_.suspend
{ {
Suspend_.handle = Suspend_.handle =
(fun ~ls ~run ~resume suspension -> (fun ~run ~resume suspension ->
(* run tasks, then we'll resume [suspension] *) (* run tasks, then we'll resume [suspension] *)
start_tasks ~run ~ls ~resume suspension); start_tasks ~run ~resume suspension);
} }
) )