add task_local_storage to core, modify how suspend works

This commit is contained in:
Simon Cruanes 2024-02-02 23:18:59 -05:00
parent f84414a412
commit c05a38d617
No known key found for this signature in database
GPG key ID: EBFFF6F283F3A2B4
13 changed files with 286 additions and 56 deletions

View file

@ -3,7 +3,7 @@
(name moonpool) (name moonpool)
(libraries moonpool.private) (libraries moonpool.private)
(flags :standard -open Moonpool_private) (flags :standard -open Moonpool_private)
(private_modules domain_pool_ util_pool_) (private_modules types_ domain_pool_ util_pool_)
(preprocess (preprocess
(action (action
(run %{project_root}/src/cpp/cpp.exe %{input-file})))) (run %{project_root}/src/cpp/cpp.exe %{input-file}))))

View file

@ -1,16 +1,17 @@
module TLS = Thread_local_storage_ open Types_
include Runner include Runner
let ( let@ ) = ( @@ ) let ( let@ ) = ( @@ )
type task_with_name = { type task_full = {
f: unit -> unit; f: unit -> unit;
name: string; name: string;
ls: task_ls;
} }
type state = { type state = {
threads: Thread.t array; threads: Thread.t array;
q: task_with_name Bb_queue.t; (** Queue for tasks. *) q: task_full Bb_queue.t; (** Queue for tasks. *)
} }
(** internal state *) (** internal state *)
@ -18,13 +19,16 @@ let[@inline] size_ (self : state) = Array.length self.threads
let[@inline] num_tasks_ (self : state) : int = Bb_queue.size self.q let[@inline] num_tasks_ (self : state) : int = Bb_queue.size self.q
(** Run [task] as is, on the pool. *) (** Run [task] as is, on the pool. *)
let schedule_ (self : state) (task : task_with_name) : unit = let schedule_ (self : state) (task : task_full) : unit =
try Bb_queue.push self.q task with Bb_queue.Closed -> raise Shutdown try Bb_queue.push self.q task with Bb_queue.Closed -> raise Shutdown
type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task
let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit = let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
let cur_ls : task_ls ref = ref [||] in
TLS.set Types_.k_ls_values (Some cur_ls);
TLS.get Runner.For_runner_implementors.k_cur_runner := Some runner; TLS.get Runner.For_runner_implementors.k_cur_runner := Some runner;
let (AT_pair (before_task, after_task)) = around_task in let (AT_pair (before_task, after_task)) = around_task in
let cur_span = ref Tracing_.dummy_span in let cur_span = ref Tracing_.dummy_span in
@ -34,20 +38,32 @@ let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit =
cur_span := Tracing_.dummy_span cur_span := Tracing_.dummy_span
in in
let run_another_task ~name task' = schedule_ self { f = task'; name } in let on_suspend () =
exit_span_ ();
!cur_ls
in
let run_task (task : task_with_name) : unit = let run_another_task ~name task' =
schedule_ self { f = task'; name; ls = [||] }
in
let run_task (task : task_full) : unit =
cur_ls := task.ls;
let _ctx = before_task runner in let _ctx = before_task runner in
cur_span := Tracing_.enter_span task.name; cur_span := Tracing_.enter_span task.name;
(* run the task now, catching errors *)
(try let resume ~ls k res =
Suspend_.with_suspend task.f ~name:task.name ~run:run_another_task schedule_ self { f = (fun () -> k res); name = task.name; ls }
~on_suspend:exit_span_ in
(* run the task now, catching errors, handling effects *)
(try Suspend_.with_suspend task.f ~run:run_another_task ~resume ~on_suspend
with e -> with e ->
let bt = Printexc.get_raw_backtrace () in let bt = Printexc.get_raw_backtrace () in
on_exn e bt); on_exn e bt);
exit_span_ (); exit_span_ ();
after_task runner _ctx after_task runner _ctx;
cur_ls := [||]
in in
let main_loop () = let main_loop () =
@ -100,7 +116,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
{ threads = Array.make num_threads dummy; q = Bb_queue.create () } { threads = Array.make num_threads dummy; q = Bb_queue.create () }
in in
let run_async ~name f = schedule_ pool { f; name } in let run_async ~name f = schedule_ pool { f; name; ls = [||] } in
let runner = let runner =
Runner.For_runner_implementors.create Runner.For_runner_implementors.create

View file

@ -42,6 +42,16 @@ let[@inline] is_done self : bool =
| Done _ -> true | Done _ -> true
| Waiting _ -> false | Waiting _ -> false
let[@inline] is_success self =
match A.get self.st with
| Done (Ok _) -> true
| _ -> false
let[@inline] is_failed self =
match A.get self.st with
| Done (Error _) -> true
| _ -> false
exception Not_ready exception Not_ready
let[@inline] get_or_fail self = let[@inline] get_or_fail self =
@ -427,14 +437,14 @@ let await (fut : 'a t) : 'a =
Suspend_.suspend Suspend_.suspend
{ {
Suspend_.handle = Suspend_.handle =
(fun ~name ~run k -> (fun ~ls ~run:_ ~resume k ->
on_result fut (function on_result fut (function
| Ok _ -> | Ok _ ->
(* schedule continuation with the same name *) (* schedule continuation with the same name *)
run ~name (fun () -> k (Ok ())) resume ~ls k (Ok ())
| Error (exn, bt) -> | Error (exn, bt) ->
(* fail continuation immediately *) (* fail continuation immediately *)
k (Error (exn, bt)))); resume ~ls k (Error (exn, bt))));
}; };
(* un-suspended: we should have a result! *) (* un-suspended: we should have a result! *)
get_or_fail_exn fut get_or_fail_exn fut
@ -452,3 +462,7 @@ end
include Infix include Infix
module Infix_local = Infix [@@deprecated "use Infix"] module Infix_local = Infix [@@deprecated "use Infix"]
module Private_ = struct
let[@inline] unsafe_promise_of_fut x = x
end

View file

@ -84,6 +84,14 @@ val is_done : _ t -> bool
(** Is the future resolved? This is the same as [peek fut |> Option.is_some]. (** Is the future resolved? This is the same as [peek fut |> Option.is_some].
@since 0.2 *) @since 0.2 *)
val is_success : _ t -> bool
(** Checks if the future is resolved with [Ok _] as a result.
@since NEXT_RELEASE *)
val is_failed : _ t -> bool
(** Checks if the future is resolved with [Error _] as a result.
@since NEXT_RELEASE *)
(** {2 Combinators} *) (** {2 Combinators} *)
val spawn : ?name:string -> on:Runner.t -> (unit -> 'a) -> 'a t val spawn : ?name:string -> on:Runner.t -> (unit -> 'a) -> 'a t
@ -268,3 +276,12 @@ include module type of Infix
module Infix_local = Infix module Infix_local = Infix
[@@deprecated "Use Infix"] [@@deprecated "Use Infix"]
(** @deprecated use Infix instead *) (** @deprecated use Infix instead *)
(**/**)
module Private_ : sig
val unsafe_promise_of_fut : 'a t -> 'a promise
(** please do not use *)
end
(**/**)

View file

@ -23,6 +23,7 @@ module Fut = Fut
module Lock = Lock module Lock = Lock
module Immediate_runner = Immediate_runner module Immediate_runner = Immediate_runner
module Runner = Runner module Runner = Runner
module Task_local_storage = Task_local_storage
module Thread_local_storage = Thread_local_storage_ module Thread_local_storage = Thread_local_storage_
module Ws_pool = Ws_pool module Ws_pool = Ws_pool

View file

@ -59,6 +59,7 @@ val await : 'a Fut.t -> 'a
module Lock = Lock module Lock = Lock
module Fut = Fut module Fut = Fut
module Chan = Chan module Chan = Chan
module Task_local_storage = Task_local_storage
module Thread_local_storage = Thread_local_storage_ module Thread_local_storage = Thread_local_storage_
(** A simple blocking queue. (** A simple blocking queue.
@ -187,8 +188,10 @@ module Atomic = Atomic_
(**/**) (**/**)
(** Private internals, with no stability guarantees *)
module Private : sig module Private : sig
module Ws_deque_ = Ws_deque_ module Ws_deque_ = Ws_deque_
(** A deque for work stealing, fixed size. *)
(** {2 Suspensions} *) (** {2 Suspensions} *)

View file

@ -1,33 +1,55 @@
type suspension = (unit, exn * Printexc.raw_backtrace) result -> unit open Types_
type suspension = unit Exn_bt.result -> unit
type task = unit -> unit type task = unit -> unit
type suspension_handler = { type suspension_handler = {
handle: name:string -> run:(name:string -> task -> unit) -> suspension -> unit; handle:
ls:task_ls ->
run:(name:string -> task -> unit) ->
resume:(ls:task_ls -> suspension -> unit Exn_bt.result -> unit) ->
suspension ->
unit;
} }
[@@unboxed] [@@unboxed]
[@@@ifge 5.0] [@@@ifge 5.0]
[@@@ocaml.alert "-unstable"] [@@@ocaml.alert "-unstable"]
type _ Effect.t += Suspend : suspension_handler -> unit Effect.t type _ Effect.t +=
| Suspend : suspension_handler -> unit Effect.t
| Yield : unit Effect.t
let[@inline] yield () = Effect.perform Yield
let[@inline] suspend h = Effect.perform (Suspend h) let[@inline] suspend h = Effect.perform (Suspend h)
let with_suspend ~name ~on_suspend ~(run : name:string -> task -> unit) let with_suspend ~on_suspend ~(run : name:string -> task -> unit)
~(resume : ls:task_ls -> suspension -> unit Exn_bt.result -> unit)
(f : unit -> unit) : unit = (f : unit -> unit) : unit =
let module E = Effect.Deep in let module E = Effect.Deep in
(* effect handler *) (* effect handler *)
let effc : type e. e Effect.t -> ((e, _) E.continuation -> _) option = let effc : type e. e Effect.t -> ((e, _) E.continuation -> _) option =
function function
| Suspend h -> | Suspend h ->
(* TODO: discontinue [k] if current fiber (if any) is cancelled? *)
Some Some
(fun k -> (fun k ->
on_suspend (); let ls = on_suspend () in
let k' : suspension = function let k' : suspension = function
| Ok () -> E.continue k () | Ok () -> E.continue k ()
| Error (exn, bt) -> E.discontinue_with_backtrace k exn bt | Error (exn, bt) -> E.discontinue_with_backtrace k exn bt
in in
h.handle ~name ~run k') h.handle ~ls ~run ~resume k')
| Yield ->
(* TODO: discontinue [k] if current fiber (if any) is cancelled? *)
Some
(fun k ->
let ls = on_suspend () in
let k' : suspension = function
| Ok () -> E.continue k ()
| Error (exn, bt) -> E.discontinue_with_backtrace k exn bt
in
resume ~ls k' (Ok ()))
| _ -> None | _ -> None
in in

View file

@ -3,13 +3,20 @@
This module is an implementation detail of Moonpool and should This module is an implementation detail of Moonpool and should
not be used outside of it, except by experts to implement {!Runner}. *) not be used outside of it, except by experts to implement {!Runner}. *)
type suspension = (unit, exn * Printexc.raw_backtrace) result -> unit open Types_
type suspension = unit Exn_bt.result -> unit
(** A suspended computation *) (** A suspended computation *)
type task = unit -> unit type task = unit -> unit
type suspension_handler = { type suspension_handler = {
handle: name:string -> run:(name:string -> task -> unit) -> suspension -> unit; handle:
ls:task_ls ->
run:(name:string -> task -> unit) ->
resume:(ls:task_ls -> suspension -> unit Exn_bt.result -> unit) ->
suspension ->
unit;
} }
[@@unboxed] [@@unboxed]
(** The handler that knows what to do with the suspended computation. (** The handler that knows what to do with the suspended computation.
@ -40,9 +47,16 @@ type _ Effect.t +=
(** The effect used to suspend the current thread and pass it, suspended, (** The effect used to suspend the current thread and pass it, suspended,
to the handler. The handler will ensure that the suspension is resumed later to the handler. The handler will ensure that the suspension is resumed later
once some computation has been done. *) once some computation has been done. *)
| Yield : unit Effect.t
(** The effect used to interrupt the current computation and immediately re-schedule
it on the same runner. *)
[@@@ocaml.alert "+unstable"] [@@@ocaml.alert "+unstable"]
val yield : unit -> unit
(** Interrupt current computation, and re-schedule it at the end of the
runner's job queue. *)
val suspend : suspension_handler -> unit val suspend : suspension_handler -> unit
(** [suspend h] jumps back to the nearest {!with_suspend} (** [suspend h] jumps back to the nearest {!with_suspend}
and calls [h.handle] with the current continuation [k] and calls [h.handle] with the current continuation [k]
@ -52,17 +66,24 @@ val suspend : suspension_handler -> unit
[@@@endif] [@@@endif]
val with_suspend : val with_suspend :
name:string -> on_suspend:(unit -> task_ls) ->
on_suspend:(unit -> unit) ->
run:(name:string -> task -> unit) -> run:(name:string -> task -> unit) ->
resume:(ls:task_ls -> suspension -> unit Exn_bt.result -> unit) ->
(unit -> unit) -> (unit -> unit) ->
unit unit
(** [with_suspend ~run f] runs [f()] in an environment where [suspend] (** [with_suspend ~name ~on_suspend ~run ~resume f]
will work. If [f()] suspends with suspension handler [h], runs [f()] in an environment where [suspend]
this calls [h ~run k] where [k] is the suspension. will work (on OCaml 5) or do nothing (on OCaml 4.xx).
The suspension should always run in a new task, via [run].
If [f()] suspends with suspension handler [h],
this calls [h ~run ~resume k] where [k] is the suspension.
The suspension should always be passed exactly once to
[resume]. [run] should be used to start other tasks.
@param on_suspend called when [f()] suspends itself. @param on_suspend called when [f()] suspends itself.
@param name used for tracing, if not [""].
@param run used to schedule new tasks
@param resume run the suspension. Must be called exactly once.
This will not do anything on OCaml 4.x. This will not do anything on OCaml 4.x.
*) *)

View file

@ -0,0 +1,53 @@
open Types_
module A = Atomic
type 'a key = 'a ls_key
let key_count_ = A.make 0
let new_key (type t) ~init () : t key =
let offset = A.fetch_and_add key_count_ 1 in
(module struct
type nonrec t = t
type ls_value += V of t
let offset = offset
let init = init
end : LS_KEY
with type t = t)
type ls_value += Dummy
(** Resize array of TLS values *)
let[@inline never] resize_ (cur : ls_value array ref) n =
let len = Array.length !cur in
let new_ls = Array.make (max n (len * 2)) Dummy in
Array.blit !cur 0 new_ls 0 len;
cur := new_ls
let[@inline] get_cur_ () : ls_value array ref =
match TLS.get k_ls_values with
| Some r -> r
| None -> failwith "Task local storage must be accessed from within a runner."
let get (type a) ((module K) : a key) : a =
let cur = get_cur_ () in
if K.offset >= Array.length !cur then resize_ cur K.offset;
match !cur.(K.offset) with
| K.V x -> (* common case first *) x
| Dummy ->
(* first time we access this *)
let v = K.init () in
!cur.(K.offset) <- K.V v;
v
| _ -> assert false
let set (type a) ((module K) : a key) (v : a) : unit =
let cur = get_cur_ () in
if K.offset >= Array.length !cur then resize_ cur K.offset;
!cur.(K.offset) <- K.V v
let with_value key x f =
let old = get key in
set key x;
Fun.protect ~finally:(fun () -> set key old) f

View file

@ -0,0 +1,43 @@
(** Task-local storage.
This storage is associated to the current task,
just like thread-local storage is associated with
the current thread. The storage is carried along in case
the current task is suspended.
@since NEXT_RELEASE
*)
type 'a key
(** A key used to access a particular (typed) storage slot on every task. *)
val new_key : init:(unit -> 'a) -> unit -> 'a key
(** [new_key ~init ()] makes a new key. Keys are expensive and
should never be allocated dynamically or in a loop.
The correct pattern is, at toplevel:
{[
let k_foo : foo Task_ocal_storage.key =
Task_local_storage.new_key ~init:(fun () -> make_foo ()) ()
(**)
(* use it: *)
let = Task_local_storage.get k_foo
]}
*)
val get : 'a key -> 'a
(** [get k] gets the value for the current task for key [k].
Must be run from inside a task running on a runner.
@raise Failure otherwise *)
val set : 'a key -> 'a -> unit
(** [set k v] sets the storage for [k] to [v].
Must be run from inside a task running on a runner.
@raise Failure otherwise *)
val with_value : 'a key -> 'a -> (unit -> 'b) -> 'b
(** [with_value k v f] sets [k] to [v] for the duration of the call
to [f()]. When [f()] returns (or fails), [k] is restored
to its old value. *)

26
src/core/types_.ml Normal file
View file

@ -0,0 +1,26 @@
module TLS = Thread_local_storage_
type ls_value = ..
(** Key for task local storage *)
module type LS_KEY = sig
type t
type ls_value += V of t
val offset : int
(** Unique offset *)
val init : unit -> t
end
type 'a ls_key = (module LS_KEY with type t = 'a)
(** A LS key (task local storage) *)
type task_ls = ls_value array
(** Store the current LS values for the current thread.
A worker thread is going to cycle through many tasks, each of which
has its own storage. This key allows tasks running on the worker
to access their own storage *)
let k_ls_values : task_ls ref option TLS.key = TLS.new_key (fun () -> None)

View file

@ -1,3 +1,4 @@
open Types_
module WSQ = Ws_deque_ module WSQ = Ws_deque_
module A = Atomic_ module A = Atomic_
module TLS = Thread_local_storage_ module TLS = Thread_local_storage_
@ -13,16 +14,18 @@ module Id = struct
let equal : t -> t -> bool = ( == ) let equal : t -> t -> bool = ( == )
end end
type task_with_name = { type task_full = {
f: task; f: task;
name: string; name: string;
ls: task_ls;
} }
type worker_state = { type worker_state = {
pool_id_: Id.t; (** Unique per pool *) pool_id_: Id.t; (** Unique per pool *)
mutable thread: Thread.t; mutable thread: Thread.t;
q: task_with_name WSQ.t; (** Work stealing queue *) q: task_full WSQ.t; (** Work stealing queue *)
mutable cur_span: int64; mutable cur_span: int64;
cur_ls: task_ls ref; (** Task storage *)
rng: Random.State.t; rng: Random.State.t;
} }
(** State for a given worker. Only this worker is (** State for a given worker. Only this worker is
@ -35,7 +38,7 @@ type state = {
id_: Id.t; id_: Id.t;
active: bool A.t; (** Becomes [false] when the pool is shutdown. *) active: bool A.t; (** Becomes [false] when the pool is shutdown. *)
workers: worker_state array; (** Fixed set of workers. *) workers: worker_state array; (** Fixed set of workers. *)
main_q: task_with_name Queue.t; main_q: task_full Queue.t;
(** Main queue for tasks coming from the outside *) (** Main queue for tasks coming from the outside *)
mutable n_waiting: int; (* protected by mutex *) mutable n_waiting: int; (* protected by mutex *)
mutable n_waiting_nonzero: bool; (** [n_waiting > 0] *) mutable n_waiting_nonzero: bool; (** [n_waiting > 0] *)
@ -72,10 +75,10 @@ let[@inline] try_wake_someone_ (self : state) : unit =
) )
(** Run [task] as is, on the pool. *) (** Run [task] as is, on the pool. *)
let schedule_task_ (self : state) ~name (w : worker_state option) (f : task) : let schedule_task_ (self : state) ~name ~ls (w : worker_state option) (f : task)
unit = : unit =
(* Printf.printf "schedule task now (%d)\n%!" (Thread.id @@ Thread.self ()); *) (* Printf.printf "schedule task now (%d)\n%!" (Thread.id @@ Thread.self ()); *)
let task = { f; name } in let task = { f; name; ls } in
match w with match w with
| Some w when Id.equal self.id_ w.pool_id_ -> | Some w when Id.equal self.id_ w.pool_id_ ->
(* we're on this same pool, schedule in the worker's state. Otherwise (* we're on this same pool, schedule in the worker's state. Otherwise
@ -104,9 +107,11 @@ let schedule_task_ (self : state) ~name (w : worker_state option) (f : task) :
raise Shutdown raise Shutdown
(** Run this task, now. Must be called from a worker. *) (** Run this task, now. Must be called from a worker. *)
let run_task_now_ (self : state) ~runner (w : worker_state) ~name task : unit = let run_task_now_ (self : state) ~runner (w : worker_state) ~name ~ls task :
unit =
(* Printf.printf "run task now (%d)\n%!" (Thread.id @@ Thread.self ()); *) (* Printf.printf "run task now (%d)\n%!" (Thread.id @@ Thread.self ()); *)
let (AT_pair (before_task, after_task)) = self.around_task in let (AT_pair (before_task, after_task)) = self.around_task in
w.cur_ls := ls;
let _ctx = before_task runner in let _ctx = before_task runner in
w.cur_span <- Tracing_.enter_span name; w.cur_span <- Tracing_.enter_span name;
@ -115,25 +120,32 @@ let run_task_now_ (self : state) ~runner (w : worker_state) ~name task : unit =
w.cur_span <- Tracing_.dummy_span w.cur_span <- Tracing_.dummy_span
in in
let on_suspend () =
exit_span_ ();
!(w.cur_ls)
in
let run_another_task ~name task' = let run_another_task ~name task' =
let w = find_current_worker_ () in let w = find_current_worker_ () in
schedule_task_ self w ~name task' schedule_task_ self w ~name ~ls:[||] task'
in in
let resume ~ls k r = schedule_task_ self (Some w) ~name ~ls (fun () -> k r) in
(* run the task now, catching errors *) (* run the task now, catching errors *)
(try (try
(* run [task()] and handle [suspend] in it *) (* run [task()] and handle [suspend] in it *)
Suspend_.with_suspend task ~name ~run:run_another_task Suspend_.with_suspend task ~run:run_another_task ~resume ~on_suspend
~on_suspend:exit_span_
with e -> with e ->
let bt = Printexc.get_raw_backtrace () in let bt = Printexc.get_raw_backtrace () in
self.on_exn e bt); self.on_exn e bt);
exit_span_ (); exit_span_ ();
after_task runner _ctx after_task runner _ctx;
w.cur_ls := [||]
let[@inline] run_async_ (self : state) ~name (f : task) : unit = let[@inline] run_async_ (self : state) ~name (f : task) : unit =
let w = find_current_worker_ () in let w = find_current_worker_ () in
schedule_task_ self w ~name f schedule_task_ self w ~name ~ls:[||] f
(* TODO: function to schedule many tasks from the outside. (* TODO: function to schedule many tasks from the outside.
- build a queue - build a queue
@ -150,11 +162,11 @@ let[@inline] wait_ (self : state) : unit =
self.n_waiting <- self.n_waiting - 1; self.n_waiting <- self.n_waiting - 1;
if self.n_waiting = 0 then self.n_waiting_nonzero <- false if self.n_waiting = 0 then self.n_waiting_nonzero <- false
exception Got_task of task_with_name exception Got_task of task_full
(** Try to steal a task *) (** Try to steal a task *)
let try_to_steal_work_once_ (self : state) (w : worker_state) : let try_to_steal_work_once_ (self : state) (w : worker_state) : task_full option
task_with_name option = =
let init = Random.State.int w.rng (Array.length self.workers) in let init = Random.State.int w.rng (Array.length self.workers) in
try try
@ -179,7 +191,7 @@ let worker_run_self_tasks_ (self : state) ~runner w : unit =
match WSQ.pop w.q with match WSQ.pop w.q with
| Some task -> | Some task ->
try_wake_someone_ self; try_wake_someone_ self;
run_task_now_ self ~runner w ~name:task.name task.f run_task_now_ self ~runner w ~name:task.name ~ls:task.ls task.f
| None -> continue := false | None -> continue := false
done done
@ -192,7 +204,7 @@ let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit =
worker_run_self_tasks_ self ~runner w; worker_run_self_tasks_ self ~runner w;
try_steal () try_steal ()
and run_task task : unit = and run_task task : unit =
run_task_now_ self ~runner w ~name:task.name task.f; run_task_now_ self ~runner w ~name:task.name ~ls:task.ls task.f;
main () main ()
and try_steal () = and try_steal () =
match try_to_steal_work_once_ self w with match try_to_steal_work_once_ self w with
@ -249,7 +261,7 @@ type ('a, 'b) create_args =
'a 'a
(** Arguments used in {!create}. See {!create} for explanations. *) (** Arguments used in {!create}. See {!create} for explanations. *)
let dummy_task_ = { f = ignore; name = "DUMMY_TASK" } let dummy_task_ = { f = ignore; ls = [||]; name = "DUMMY_TASK" }
let create ?(on_init_thread = default_thread_init_exit_) let create ?(on_init_thread = default_thread_init_exit_)
?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ()) ?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ())
@ -277,6 +289,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
cur_span = Tracing_.dummy_span; cur_span = Tracing_.dummy_span;
q = WSQ.create ~dummy:dummy_task_ (); q = WSQ.create ~dummy:dummy_task_ ();
rng = Random.State.make [| i |]; rng = Random.State.make [| i |];
cur_ls = ref [||];
}) })
in in
@ -318,6 +331,7 @@ let create ?(on_init_thread = default_thread_init_exit_)
let thread = Thread.self () in let thread = Thread.self () in
let t_id = Thread.id thread in let t_id = Thread.id thread in
on_init_thread ~dom_id:dom_idx ~t_id (); on_init_thread ~dom_id:dom_idx ~t_id ();
TLS.set k_ls_values (Some w.cur_ls);
(* set thread name *) (* set thread name *)
Option.iter Option.iter

View file

@ -48,7 +48,7 @@ module State_ = struct
Suspend_.suspend Suspend_.suspend
{ {
Suspend_.handle = Suspend_.handle =
(fun ~name:_ ~run:_ suspension -> (fun ~ls ~run:_ ~resume suspension ->
while while
let old_st = A.get self in let old_st = A.get self in
match old_st with match old_st with
@ -59,7 +59,7 @@ module State_ = struct
| Left_solved left -> | Left_solved left ->
(* other thread is done, no risk of race condition *) (* other thread is done, no risk of race condition *)
A.set self (Both_solved (left, right)); A.set self (Both_solved (left, right));
suspension (Ok ()); resume ~ls suspension (Ok ());
false false
| Right_solved _ | Both_solved _ -> assert false | Right_solved _ | Both_solved _ -> assert false
do do
@ -113,19 +113,19 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit =
max 1 (1 + (n / Moonpool.Private.num_domains ())) max 1 (1 + (n / Moonpool.Private.num_domains ()))
in in
let start_tasks ~name ~run (suspension : Suspend_.suspension) = let start_tasks ~ls ~run ~resume (suspension : Suspend_.suspension) =
let task_for ~offset ~len_range = let task_for ~offset ~len_range =
match f offset (offset + len_range - 1) with match f offset (offset + len_range - 1) with
| () -> | () ->
if A.fetch_and_add missing (-len_range) = len_range then if A.fetch_and_add missing (-len_range) = len_range then
(* all tasks done successfully *) (* all tasks done successfully *)
run ~name (fun () -> suspension (Ok ())) resume ~ls suspension (Ok ())
| exception exn -> | exception exn ->
let bt = Printexc.get_raw_backtrace () in let bt = Printexc.get_raw_backtrace () in
if not (A.exchange has_failed true) then if not (A.exchange has_failed true) then
(* first one to fail, and [missing] must be >= 2 (* first one to fail, and [missing] must be >= 2
because we're not decreasing it. *) because we're not decreasing it. *)
run ~name (fun () -> suspension (Error (exn, bt))) resume ~ls suspension (Error (exn, bt))
in in
let i = ref 0 in let i = ref 0 in
@ -135,7 +135,7 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit =
let len_range = min chunk_size (n - offset) in let len_range = min chunk_size (n - offset) in
assert (offset + len_range <= n); assert (offset + len_range <= n);
run ~name (fun () -> task_for ~offset ~len_range); run ~name:"" (fun () -> task_for ~offset ~len_range);
i := !i + len_range i := !i + len_range
done done
in in
@ -143,9 +143,9 @@ let for_ ?chunk_size n (f : int -> int -> unit) : unit =
Suspend_.suspend Suspend_.suspend
{ {
Suspend_.handle = Suspend_.handle =
(fun ~name ~run suspension -> (fun ~ls ~run ~resume suspension ->
(* run tasks, then we'll resume [suspension] *) (* run tasks, then we'll resume [suspension] *)
start_tasks ~run ~name suspension); start_tasks ~run ~ls ~resume suspension);
} }
) )