Compare commits

...

6 commits

Author SHA1 Message Date
Simon Cruanes
89b8ed3221
wip 2025-11-12 00:36:30 -05:00
Simon Cruanes
b53a067234
fix test 2025-11-12 00:30:40 -05:00
Simon Cruanes
2c3cc8892a
consolidate thread-local-storage into single record 2025-11-12 00:30:39 -05:00
Simon Cruanes
ee7972910f
breaking: remove around_task from schedulers 2025-11-12 00:25:02 -05:00
Simon Cruanes
2ce3fa7d3e
docs 2025-11-12 00:25:02 -05:00
Simon Cruanes
8770d4fb9c
repro for #41 2025-11-12 00:25:02 -05:00
19 changed files with 145 additions and 128 deletions

View file

@ -67,6 +67,14 @@ bench-pi:
'./_build/default/benchs/pi.exe -n $(PI_NSTEPS) -j 16 -mode forkjoin -kind=pool' \ './_build/default/benchs/pi.exe -n $(PI_NSTEPS) -j 16 -mode forkjoin -kind=pool' \
'./_build/default/benchs/pi.exe -n $(PI_NSTEPS) -j 20 -mode forkjoin -kind=pool' './_build/default/benchs/pi.exe -n $(PI_NSTEPS) -j 20 -mode forkjoin -kind=pool'
bench-repro-41:
dune build $(DUNE_OPTS_BENCH) examples/repro_41/run.exe
hyperfine --warmup=1 \
"./_build/default/examples/repro_41/run.exe 4 domainslib" \
"./_build/default/examples/repro_41/run.exe 4 moonpool" \
"./_build/default/examples/repro_41/run.exe 5 moonpool" \
"./_build/default/examples/repro_41/run.exe 5 seq"
.PHONY: test clean bench-fib bench-pi .PHONY: test clean bench-fib bench-pi
VERSION=$(shell awk '/^version:/ {print $$2}' moonpool.opam) VERSION=$(shell awk '/^version:/ {print $$2}' moonpool.opam)

5
examples/repro_41/dune Normal file
View file

@ -0,0 +1,5 @@
(executables
(names run)
(enabled_if
(>= %{ocaml_version} 5.0))
(libraries moonpool trace trace-tef domainslib))

54
examples/repro_41/run.ml Normal file
View file

@ -0,0 +1,54 @@
(* fibo.ml *)
let cutoff = 25
let input = 40
let rec fibo_seq n =
if n <= 1 then
n
else
fibo_seq (n - 1) + fibo_seq (n - 2)
let rec fibo_domainslib ctx n =
if n <= cutoff then
fibo_seq n
else
let open Domainslib in
let fut1 = Task.async ctx (fun () -> fibo_domainslib ctx (n - 1)) in
let fut2 = Task.async ctx (fun () -> fibo_domainslib ctx (n - 2)) in
Task.await ctx fut1 + Task.await ctx fut2
let rec fibo_moonpool ctx n =
if n <= cutoff then
fibo_seq n
else
let open Moonpool in
let fut1 = Fut.spawn ~on:ctx (fun () -> fibo_moonpool ctx (n - 1)) in
let fut2 = Fut.spawn ~on:ctx (fun () -> fibo_moonpool ctx (n - 2)) in
Fut.await fut1 + Fut.await fut2
let usage =
"fibo.exe <num_domains> [ domainslib | moonpool | moonpool_fifo | seq ]"
let num_domains = try int_of_string Sys.argv.(1) with _ -> failwith usage
let implem = try Sys.argv.(2) with _ -> failwith usage
let () =
let output =
match implem with
| "moonpool" ->
let open Moonpool in
let ctx = Ws_pool.create ~num_threads:num_domains () in
Ws_pool.run_wait_block ctx (fun () -> fibo_moonpool ctx input)
| "moonpool_fifo" ->
let open Moonpool in
let ctx = Fifo_pool.create ~num_threads:num_domains () in
Ws_pool.run_wait_block ctx (fun () -> fibo_moonpool ctx input)
| "domainslib" ->
let open Domainslib in
let pool = Task.setup_pool ~num_domains () in
Task.run pool (fun () -> fibo_domainslib pool input)
| "seq" -> fibo_seq input
| _ -> failwith usage
in
print_int output;
print_newline ()

View file

@ -6,18 +6,15 @@ type ('a, 'b) create_args =
?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) -> ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) -> ?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
?on_exn:(exn -> Printexc.raw_backtrace -> unit) -> ?on_exn:(exn -> Printexc.raw_backtrace -> unit) ->
?around_task:(t -> 'b) * (t -> 'b -> unit) ->
?name:string -> ?name:string ->
'a 'a
(** Arguments used in {!create}. See {!create} for explanations. *) (** Arguments used in {!create}. See {!create} for explanations. *)
let create ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?name () : t = let create ?on_init_thread ?on_exit_thread ?on_exn ?name () : t =
Fifo_pool.create ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?name Fifo_pool.create ?on_init_thread ?on_exit_thread ?on_exn ?name ~num_threads:1
~num_threads:1 () ()
let with_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?name () f = let with_ ?on_init_thread ?on_exit_thread ?on_exn ?name () f =
let pool = let pool = create ?on_init_thread ?on_exit_thread ?on_exn ?name () in
create ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?name ()
in
let@ () = Fun.protect ~finally:(fun () -> shutdown pool) in let@ () = Fun.protect ~finally:(fun () -> shutdown pool) in
f pool f pool

View file

@ -13,7 +13,6 @@ type ('a, 'b) create_args =
?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) -> ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) -> ?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
?on_exn:(exn -> Printexc.raw_backtrace -> unit) -> ?on_exn:(exn -> Printexc.raw_backtrace -> unit) ->
?around_task:(t -> 'b) * (t -> 'b -> unit) ->
?name:string -> ?name:string ->
'a 'a
(** Arguments used in {!create}. See {!create} for explanations. *) (** Arguments used in {!create}. See {!create} for explanations. *)

View file

@ -10,7 +10,6 @@ let ( let@ ) = ( @@ )
type state = { type state = {
threads: Thread.t array; threads: Thread.t array;
q: task_full Bb_queue.t; (** Queue for tasks. *) q: task_full Bb_queue.t; (** Queue for tasks. *)
around_task: WL.around_task;
mutable as_runner: t; mutable as_runner: t;
(* init options *) (* init options *)
name: string option; name: string option;
@ -28,11 +27,6 @@ type worker_state = {
let[@inline] size_ (self : state) = Array.length self.threads let[@inline] size_ (self : state) = Array.length self.threads
let[@inline] num_tasks_ (self : state) : int = Bb_queue.size self.q let[@inline] num_tasks_ (self : state) : int = Bb_queue.size self.q
(*
get_thread_state = TLS.get_opt k_worker_state
*)
let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = () let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = ()
let shutdown_ ~wait (self : state) : unit = let shutdown_ ~wait (self : state) : unit =
@ -43,13 +37,10 @@ type ('a, 'b) create_args =
?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) -> ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) -> ?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
?on_exn:(exn -> Printexc.raw_backtrace -> unit) -> ?on_exn:(exn -> Printexc.raw_backtrace -> unit) ->
?around_task:(t -> 'b) * (t -> 'b -> unit) ->
?num_threads:int -> ?num_threads:int ->
?name:string -> ?name:string ->
'a 'a
let default_around_task_ : WL.around_task = AT_pair (ignore, fun _ _ -> ())
(** Run [task] as is, on the pool. *) (** Run [task] as is, on the pool. *)
let schedule_ (self : state) (task : task_full) : unit = let schedule_ (self : state) (task : task_full) : unit =
try Bb_queue.push self.q task with Bb_queue.Closed -> raise Shutdown try Bb_queue.push self.q task with Bb_queue.Closed -> raise Shutdown
@ -88,7 +79,6 @@ let cleanup (self : worker_state) : unit =
let worker_ops : worker_state WL.ops = let worker_ops : worker_state WL.ops =
let runner (st : worker_state) = st.st.as_runner in let runner (st : worker_state) = st.st.as_runner in
let around_task st = st.st.around_task in
let on_exn (st : worker_state) (ebt : Exn_bt.t) = let on_exn (st : worker_state) (ebt : Exn_bt.t) =
st.st.on_exn (Exn_bt.exn ebt) (Exn_bt.bt ebt) st.st.on_exn (Exn_bt.exn ebt) (Exn_bt.bt ebt)
in in
@ -96,7 +86,6 @@ let worker_ops : worker_state WL.ops =
WL.schedule = schedule_w; WL.schedule = schedule_w;
runner; runner;
get_next_task; get_next_task;
around_task;
on_exn; on_exn;
before_start; before_start;
cleanup; cleanup;
@ -104,19 +93,11 @@ let worker_ops : worker_state WL.ops =
let create_ ?(on_init_thread = default_thread_init_exit_) let create_ ?(on_init_thread = default_thread_init_exit_)
?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ()) ?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ())
?around_task ~threads ?name () : state = ~threads ?name () : state =
(* wrapper *)
let around_task =
match around_task with
| Some (f, g) -> WL.AT_pair (f, g)
| None -> default_around_task_
in
let self = let self =
{ {
threads; threads;
q = Bb_queue.create (); q = Bb_queue.create ();
around_task;
as_runner = Runner.dummy; as_runner = Runner.dummy;
name; name;
on_init_thread; on_init_thread;
@ -127,8 +108,7 @@ let create_ ?(on_init_thread = default_thread_init_exit_)
self.as_runner <- runner_of_state self; self.as_runner <- runner_of_state self;
self self
let create ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads let create ?on_init_thread ?on_exit_thread ?on_exn ?num_threads ?name () : t =
?name () : t =
let num_domains = Domain_pool_.max_number_of_domains () in let num_domains = Domain_pool_.max_number_of_domains () in
(* number of threads to run *) (* number of threads to run *)
@ -140,8 +120,7 @@ let create ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads
let pool = let pool =
let dummy_thread = Thread.self () in let dummy_thread = Thread.self () in
let threads = Array.make num_threads dummy_thread in let threads = Array.make num_threads dummy_thread in
create_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ~threads ?name create_ ?on_init_thread ?on_exit_thread ?on_exn ~threads ?name ()
()
in in
let runner = runner_of_state pool in let runner = runner_of_state pool in
@ -181,11 +160,9 @@ let create ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads
runner runner
let with_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads let with_ ?on_init_thread ?on_exit_thread ?on_exn ?num_threads ?name () f =
?name () f =
let pool = let pool =
create ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads create ?on_init_thread ?on_exit_thread ?on_exn ?num_threads ?name ()
?name ()
in in
let@ () = Fun.protect ~finally:(fun () -> shutdown pool) in let@ () = Fun.protect ~finally:(fun () -> shutdown pool) in
f pool f pool

View file

@ -20,7 +20,6 @@ type ('a, 'b) create_args =
?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) -> ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) -> ?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
?on_exn:(exn -> Printexc.raw_backtrace -> unit) -> ?on_exn:(exn -> Printexc.raw_backtrace -> unit) ->
?around_task:(t -> 'b) * (t -> 'b -> unit) ->
?num_threads:int -> ?num_threads:int ->
?name:string -> ?name:string ->
'a 'a
@ -35,9 +34,6 @@ val create : (unit -> t, _) create_args
[Domain.recommended_domain_count()], ie one worker per CPU core. On OCaml [Domain.recommended_domain_count()], ie one worker per CPU core. On OCaml
4 the default is [4] (since there is only one domain). 4 the default is [4] (since there is only one domain).
@param on_exit_thread called at the end of each worker thread in the pool. @param on_exit_thread called at the end of each worker thread in the pool.
@param around_task
a pair of [before, after] functions ran around each task. See
{!Pool.create_args}.
@param name name for the pool, used in tracing (since 0.6) *) @param name name for the pool, used in tracing (since 0.6) *)
val with_ : (unit -> (t -> 'a) -> 'a, _) create_args val with_ : (unit -> (t -> 'a) -> 'a, _) create_args

View file

@ -9,19 +9,19 @@ let k_local_hmap : Hmap.t FLS.t = FLS.create ()
(** Access the local [hmap], or an empty one if not set *) (** Access the local [hmap], or an empty one if not set *)
let[@inline] get_local_hmap () : Hmap.t = let[@inline] get_local_hmap () : Hmap.t =
match TLS.get_exn k_cur_fiber with match TLS.get_exn k_cur_st with
| exception TLS.Not_set -> Hmap.empty | exception TLS.Not_set -> Hmap.empty
| fiber -> FLS.get fiber ~default:Hmap.empty k_local_hmap | { cur_fiber = fiber; _ } -> FLS.get fiber ~default:Hmap.empty k_local_hmap
let[@inline] set_local_hmap (h : Hmap.t) : unit = let[@inline] set_local_hmap (h : Hmap.t) : unit =
match TLS.get_exn k_cur_fiber with match TLS.get_exn k_cur_st with
| exception TLS.Not_set -> () | exception TLS.Not_set -> ()
| fiber -> FLS.set fiber k_local_hmap h | { cur_fiber = fiber; _ } -> FLS.set fiber k_local_hmap h
let[@inline] update_local_hmap (f : Hmap.t -> Hmap.t) : unit = let[@inline] update_local_hmap (f : Hmap.t -> Hmap.t) : unit =
match TLS.get_exn k_cur_fiber with match TLS.get_exn k_cur_st with
| exception TLS.Not_set -> () | exception TLS.Not_set -> ()
| fiber -> | { cur_fiber = fiber; _ } ->
let h = FLS.get fiber ~default:Hmap.empty k_local_hmap in let h = FLS.get fiber ~default:Hmap.empty k_local_hmap in
let h = f h in let h = f h in
FLS.set fiber k_local_hmap h FLS.set fiber k_local_hmap h

View file

@ -1,6 +1,7 @@
exception Oh_no of Exn_bt.t exception Oh_no of Exn_bt.t
let main' ?(block_signals = false) () (f : Runner.t -> 'a) : 'a = let main' ?(block_signals = false) () (f : Runner.t -> 'a) : 'a =
let module WL = Worker_loop_ in
let worker_st = let worker_st =
Fifo_pool.Private_.create_single_threaded_state ~thread:(Thread.self ()) Fifo_pool.Private_.create_single_threaded_state ~thread:(Thread.self ())
~on_exn:(fun e bt -> raise (Oh_no (Exn_bt.make e bt))) ~on_exn:(fun e bt -> raise (Oh_no (Exn_bt.make e bt)))
@ -8,15 +9,17 @@ let main' ?(block_signals = false) () (f : Runner.t -> 'a) : 'a =
in in
let runner = Fifo_pool.Private_.runner_of_state worker_st in let runner = Fifo_pool.Private_.runner_of_state worker_st in
try try
let fiber = Fut.spawn ~on:runner (fun () -> f runner) in let fut = Fut.spawn ~on:runner (fun () -> f runner) in
Fut.on_result fiber (fun _ -> Runner.shutdown_without_waiting runner); Fut.on_result fut (fun _ -> Runner.shutdown_without_waiting runner);
Thread_local_storage.set Runner.For_runner_implementors.k_cur_st
{ cur_fiber = Picos.Fiber.create ~forbid:true fut; runner };
(* run the main thread *) (* run the main thread *)
Worker_loop_.worker_loop worker_st WL.worker_loop worker_st ~block_signals (* do not disturb existing thread *)
~block_signals (* do not disturb existing thread *)
~ops:Fifo_pool.Private_.worker_ops; ~ops:Fifo_pool.Private_.worker_ops;
match Fut.peek fiber with match Fut.peek fut with
| Some (Ok x) -> x | Some (Ok x) -> x
| Some (Error ebt) -> Exn_bt.raise ebt | Some (Error ebt) -> Exn_bt.raise ebt
| None -> assert false | None -> assert false

View file

@ -47,7 +47,12 @@ module For_runner_implementors = struct
let create ~size ~num_tasks ~shutdown ~run_async () : t = let create ~size ~num_tasks ~shutdown ~run_async () : t =
{ size; num_tasks; shutdown; run_async } { size; num_tasks; shutdown; run_async }
let k_cur_runner : t TLS.t = Types_.k_cur_runner type nonrec thread_local_state = thread_local_state = {
mutable runner: t;
mutable cur_fiber: fiber;
}
let k_cur_st : thread_local_state TLS.t = Types_.k_cur_st
end end
let dummy : t = let dummy : t =

View file

@ -72,7 +72,13 @@ module For_runner_implementors : sig
{b NOTE}: the runner should support DLA and {!Suspend_} on OCaml 5.x, so {b NOTE}: the runner should support DLA and {!Suspend_} on OCaml 5.x, so
that {!Fork_join} and other 5.x features work properly. *) that {!Fork_join} and other 5.x features work properly. *)
val k_cur_runner : t Thread_local_storage.t type thread_local_state = {
mutable runner: t;
mutable cur_fiber: fiber;
}
(** State set in thread-local-storage for worker threads *)
val k_cur_st : thread_local_state Thread_local_storage.t
(** Key that should be used by each runner to store itself in TLS on every (** Key that should be used by each runner to store itself in TLS on every
thread it controls, so that tasks running on these threads can access the thread it controls, so that tasks running on these threads can access the
runner. This is necessary for {!get_current_runner} to work. *) runner. This is necessary for {!get_current_runner} to work. *)

View file

@ -11,8 +11,12 @@ type runner = {
num_tasks: unit -> int; num_tasks: unit -> int;
} }
let k_cur_runner : runner TLS.t = TLS.create () type thread_local_state = {
let k_cur_fiber : fiber TLS.t = TLS.create () mutable runner: runner;
mutable cur_fiber: fiber;
}
let k_cur_st : thread_local_state TLS.t = TLS.create ()
let _dummy_computation : Picos.Computation.packed = let _dummy_computation : Picos.Computation.packed =
let c = Picos.Computation.create () in let c = Picos.Computation.create () in
@ -20,11 +24,15 @@ let _dummy_computation : Picos.Computation.packed =
Picos.Computation.Packed c Picos.Computation.Packed c
let _dummy_fiber = Picos.Fiber.create_packed ~forbid:true _dummy_computation let _dummy_fiber = Picos.Fiber.create_packed ~forbid:true _dummy_computation
let[@inline] get_current_runner () : _ option = TLS.get_opt k_cur_runner
let[@inline] get_current_runner () : _ option =
match TLS.get_exn k_cur_st with
| st -> Some st.runner
| exception TLS.Not_set -> None
let[@inline] get_current_fiber () : fiber option = let[@inline] get_current_fiber () : fiber option =
match TLS.get_exn k_cur_fiber with match TLS.get_exn k_cur_st with
| f when f != _dummy_fiber -> Some f | { cur_fiber = f; _ } when f != _dummy_fiber -> Some f
| _ -> None | _ -> None
| exception TLS.Not_set -> None | exception TLS.Not_set -> None
@ -32,7 +40,7 @@ let error_get_current_fiber_ =
"Moonpool: get_current_fiber was called outside of a fiber." "Moonpool: get_current_fiber was called outside of a fiber."
let[@inline] get_current_fiber_exn () : fiber = let[@inline] get_current_fiber_exn () : fiber =
match TLS.get_exn k_cur_fiber with match TLS.get_exn k_cur_st with
| f when f != _dummy_fiber -> f | { cur_fiber = f; _ } when f != _dummy_fiber -> f
| _ -> failwith error_get_current_fiber_ | _ -> failwith error_get_current_fiber_
| exception TLS.Not_set -> failwith error_get_current_fiber_ | exception TLS.Not_set -> failwith error_get_current_fiber_

View file

@ -13,15 +13,11 @@ type task_full =
} }
-> task_full -> task_full
type around_task =
| AT_pair : (Runner.t -> 'a) * (Runner.t -> 'a -> unit) -> around_task
exception No_more_tasks exception No_more_tasks
type 'st ops = { type 'st ops = {
schedule: 'st -> task_full -> unit; schedule: 'st -> task_full -> unit;
get_next_task: 'st -> task_full; (** @raise No_more_tasks *) get_next_task: 'st -> task_full; (** @raise No_more_tasks *)
around_task: 'st -> around_task;
on_exn: 'st -> Exn_bt.t -> unit; on_exn: 'st -> Exn_bt.t -> unit;
runner: 'st -> Runner.t; runner: 'st -> Runner.t;
before_start: 'st -> unit; before_start: 'st -> unit;
@ -106,7 +102,13 @@ end
module Fine_grained (Args : FINE_GRAINED_ARGS) () = struct module Fine_grained (Args : FINE_GRAINED_ARGS) () = struct
open Args open Args
let cur_fiber : fiber ref = ref _dummy_fiber let cur_st : Runner.For_runner_implementors.thread_local_state Lazy.t =
lazy
(match TLS.get_exn Runner.For_runner_implementors.k_cur_st with
| st -> st
| exception TLS.Not_set ->
failwith "Moonpool: worker loop: no current state set")
let runner = ops.runner st let runner = ops.runner st
type state = type state =
@ -117,15 +119,12 @@ module Fine_grained (Args : FINE_GRAINED_ARGS) () = struct
let state = ref New let state = ref New
let run_task (task : task_full) : unit = let run_task (task : task_full) : unit =
let (AT_pair (before_task, after_task)) = ops.around_task st in
let fiber = let fiber =
match task with match task with
| T_start { fiber; _ } | T_resume { fiber; _ } -> fiber | T_start { fiber; _ } | T_resume { fiber; _ } -> fiber
in in
cur_fiber := fiber; (Lazy.force cur_st).cur_fiber <- fiber;
TLS.set k_cur_fiber fiber;
let _ctx = before_task runner in
(* run the task now, catching errors, handling effects *) (* run the task now, catching errors, handling effects *)
assert (task != _dummy_task); assert (task != _dummy_task);
@ -140,10 +139,7 @@ module Fine_grained (Args : FINE_GRAINED_ARGS) () = struct
let ebt = Exn_bt.make e bt in let ebt = Exn_bt.make e bt in
ops.on_exn st ebt); ops.on_exn st ebt);
after_task runner _ctx; (Lazy.force cur_st).cur_fiber <- _dummy_fiber
cur_fiber := _dummy_fiber;
TLS.set k_cur_fiber _dummy_fiber
let setup ~block_signals () : unit = let setup ~block_signals () : unit =
if !state <> New then invalid_arg "worker_loop.setup: not a new instance"; if !state <> New then invalid_arg "worker_loop.setup: not a new instance";
@ -166,9 +162,9 @@ module Fine_grained (Args : FINE_GRAINED_ARGS) () = struct
with _ -> () with _ -> ()
); );
TLS.set Runner.For_runner_implementors.k_cur_runner runner; ops.before_start st;
(Lazy.force cur_st).runner <- runner;
ops.before_start st ()
let run ?(max_tasks = max_int) () : unit = let run ?(max_tasks = max_int) () : unit =
if !state <> Ready then invalid_arg "worker_loop.run: not setup"; if !state <> Ready then invalid_arg "worker_loop.run: not setup";
@ -186,7 +182,7 @@ module Fine_grained (Args : FINE_GRAINED_ARGS) () = struct
let teardown () = let teardown () =
if !state <> Torn_down then ( if !state <> Torn_down then (
state := Torn_down; state := Torn_down;
cur_fiber := _dummy_fiber; (Lazy.force cur_st).cur_fiber <- _dummy_fiber;
ops.cleanup st ops.cleanup st
) )
end end

View file

@ -18,15 +18,11 @@ type task_full =
val _dummy_task : task_full val _dummy_task : task_full
type around_task =
| AT_pair : (Runner.t -> 'a) * (Runner.t -> 'a -> unit) -> around_task
exception No_more_tasks exception No_more_tasks
type 'st ops = { type 'st ops = {
schedule: 'st -> task_full -> unit; schedule: 'st -> task_full -> unit;
get_next_task: 'st -> task_full; get_next_task: 'st -> task_full;
around_task: 'st -> around_task;
on_exn: 'st -> Exn_bt.t -> unit; on_exn: 'st -> Exn_bt.t -> unit;
runner: 'st -> Runner.t; runner: 'st -> Runner.t;
before_start: 'st -> unit; before_start: 'st -> unit;

View file

@ -28,7 +28,6 @@ type state = {
cond: Condition.t; cond: Condition.t;
mutable as_runner: t; mutable as_runner: t;
(* init options *) (* init options *)
around_task: WL.around_task;
name: string option; name: string option;
on_init_thread: dom_id:int -> t_id:int -> unit -> unit; on_init_thread: dom_id:int -> t_id:int -> unit -> unit;
on_exit_thread: dom_id:int -> t_id:int -> unit -> unit; on_exit_thread: dom_id:int -> t_id:int -> unit -> unit;
@ -56,7 +55,8 @@ let num_tasks_ (self : state) : int =
!n !n
(** TLS, used by worker to store their specific state and be able to retrieve it (** TLS, used by worker to store their specific state and be able to retrieve it
from tasks when we schedule new sub-tasks. *) from tasks when we schedule new sub-tasks. This way we can schedule the new
task directly in the local work queue, where it might be stolen. *)
let k_worker_state : worker_state TLS.t = TLS.create () let k_worker_state : worker_state TLS.t = TLS.create ()
let[@inline] get_current_worker_ () : worker_state option = let[@inline] get_current_worker_ () : worker_state option =
@ -180,8 +180,8 @@ and wait_on_main_queue (self : worker_state) : WL.task_full =
let before_start (self : worker_state) : unit = let before_start (self : worker_state) : unit =
let t_id = Thread.id @@ Thread.self () in let t_id = Thread.id @@ Thread.self () in
self.st.on_init_thread ~dom_id:self.dom_id ~t_id (); self.st.on_init_thread ~dom_id:self.dom_id ~t_id ();
TLS.set k_cur_fiber _dummy_fiber; TLS.set Runner.For_runner_implementors.k_cur_st
TLS.set Runner.For_runner_implementors.k_cur_runner self.st.as_runner; { cur_fiber = _dummy_fiber; runner = self.st.as_runner };
TLS.set k_worker_state self; TLS.set k_worker_state self;
(* set thread name *) (* set thread name *)
@ -198,7 +198,6 @@ let cleanup (self : worker_state) : unit =
let worker_ops : worker_state WL.ops = let worker_ops : worker_state WL.ops =
let runner (st : worker_state) = st.st.as_runner in let runner (st : worker_state) = st.st.as_runner in
let around_task st = st.st.around_task in
let on_exn (st : worker_state) (ebt : Exn_bt.t) = let on_exn (st : worker_state) (ebt : Exn_bt.t) =
st.st.on_exn (Exn_bt.exn ebt) (Exn_bt.bt ebt) st.st.on_exn (Exn_bt.exn ebt) (Exn_bt.bt ebt)
in in
@ -206,7 +205,6 @@ let worker_ops : worker_state WL.ops =
WL.schedule = schedule_from_w; WL.schedule = schedule_from_w;
runner; runner;
get_next_task; get_next_task;
around_task;
on_exn; on_exn;
before_start; before_start;
cleanup; cleanup;
@ -235,7 +233,6 @@ type ('a, 'b) create_args =
?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) -> ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) -> ?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
?on_exn:(exn -> Printexc.raw_backtrace -> unit) -> ?on_exn:(exn -> Printexc.raw_backtrace -> unit) ->
?around_task:(t -> 'b) * (t -> 'b -> unit) ->
?num_threads:int -> ?num_threads:int ->
?name:string -> ?name:string ->
'a 'a
@ -243,15 +240,8 @@ type ('a, 'b) create_args =
let create ?(on_init_thread = default_thread_init_exit_) let create ?(on_init_thread = default_thread_init_exit_)
?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ()) ?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ())
?around_task ?num_threads ?name () : t = ?num_threads ?name () : t =
let pool_id_ = Id.create () in let pool_id_ = Id.create () in
(* wrapper *)
let around_task =
match around_task with
| Some (f, g) -> WL.AT_pair (f, g)
| None -> WL.AT_pair (ignore, fun _ _ -> ())
in
let num_domains = Domain_pool_.max_number_of_domains () in let num_domains = Domain_pool_.max_number_of_domains () in
let num_threads = Util_pool_.num_threads ?num_threads () in let num_threads = Util_pool_.num_threads ?num_threads () in
@ -268,7 +258,6 @@ let create ?(on_init_thread = default_thread_init_exit_)
n_waiting_nonzero = true; n_waiting_nonzero = true;
mutex = Mutex.create (); mutex = Mutex.create ();
cond = Condition.create (); cond = Condition.create ();
around_task;
on_exn; on_exn;
on_init_thread; on_init_thread;
on_exit_thread; on_exit_thread;
@ -324,11 +313,9 @@ let create ?(on_init_thread = default_thread_init_exit_)
pool.as_runner pool.as_runner
let with_ ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads let with_ ?on_init_thread ?on_exit_thread ?on_exn ?num_threads ?name () f =
?name () f =
let pool = let pool =
create ?on_init_thread ?on_exit_thread ?on_exn ?around_task ?num_threads create ?on_init_thread ?on_exit_thread ?on_exn ?num_threads ?name ()
?name ()
in in
let@ () = Fun.protect ~finally:(fun () -> shutdown pool) in let@ () = Fun.protect ~finally:(fun () -> shutdown pool) in
f pool f pool

View file

@ -24,7 +24,6 @@ type ('a, 'b) create_args =
?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) -> ?on_init_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) -> ?on_exit_thread:(dom_id:int -> t_id:int -> unit -> unit) ->
?on_exn:(exn -> Printexc.raw_backtrace -> unit) -> ?on_exn:(exn -> Printexc.raw_backtrace -> unit) ->
?around_task:(t -> 'b) * (t -> 'b -> unit) ->
?num_threads:int -> ?num_threads:int ->
?name:string -> ?name:string ->
'a 'a
@ -40,11 +39,6 @@ val create : (unit -> t, _) create_args
[Domain.recommended_domain_count()], ie one worker thread per CPU core. On [Domain.recommended_domain_count()], ie one worker thread per CPU core. On
OCaml 4 the default is [4] (since there is only one domain). OCaml 4 the default is [4] (since there is only one domain).
@param on_exit_thread called at the end of each thread in the pool @param on_exit_thread called at the end of each thread in the pool
@param around_task
a pair of [before, after], where [before pool] is called before a task is
processed, on the worker thread about to run it, and returns [x]; and
[after pool x] is called by the same thread after the task is over. (since
0.2)
@param name @param name
a name for this thread pool, used if tracing is enabled (since 0.6) *) a name for this thread pool, used if tracing is enabled (since 0.6) *)

View file

@ -146,7 +146,7 @@ let work_ idx (st : worker_state) : unit =
let () = let () =
assert (Domain_.is_main_domain ()); assert (Domain_.is_main_domain ());
let w = { th_count = Atomic.make 1; q = Bb_queue.create () } in let w = { th_count = Atomic.make 1; q = Bb_queue.create () } in
(* thread that stays alive *) (* thread that stays alive since [th_count>0] will always hold *)
ignore (Thread.create (fun () -> work_ 0 w) () : Thread.t); ignore (Thread.create (fun () -> work_ 0 w) () : Thread.t);
domains_.(0) <- Lock.create (Some w, None) domains_.(0) <- Lock.create (Some w, None)
@ -154,7 +154,8 @@ let[@inline] max_number_of_domains () : int = Array.length domains_
let run_on (i : int) (f : unit -> unit) : unit = let run_on (i : int) (f : unit -> unit) : unit =
assert (i < Array.length domains_); assert (i < Array.length domains_);
let w =
let w : worker_state =
Lock.update_map domains_.(i) (function Lock.update_map domains_.(i) (function
| (Some w, _) as st -> | (Some w, _) as st ->
Atomic.incr w.th_count; Atomic.incr w.th_count;

View file

@ -7,8 +7,6 @@ end
module Fut = Moonpool.Fut module Fut = Moonpool.Fut
let default_around_task_ : WL.around_task = AT_pair (ignore, fun _ _ -> ())
let on_uncaught_exn : (Moonpool.Exn_bt.t -> unit) ref = let on_uncaught_exn : (Moonpool.Exn_bt.t -> unit) ref =
ref (fun ebt -> ref (fun ebt ->
Printf.eprintf "uncaught exception in moonpool-lwt:\n%s" (Exn_bt.show ebt)) Printf.eprintf "uncaught exception in moonpool-lwt:\n%s" (Exn_bt.show ebt))
@ -90,8 +88,6 @@ end
module Ops = struct module Ops = struct
type st = Scheduler_state.st type st = Scheduler_state.st
let around_task _ = default_around_task_
let schedule (self : st) t = let schedule (self : st) t =
if Atomic.get self.closed then if Atomic.get self.closed then
failwith "moonpool-lwt.schedule: scheduler is closed"; failwith "moonpool-lwt.schedule: scheduler is closed";
@ -122,15 +118,7 @@ module Ops = struct
() ()
let ops : st WL.ops = let ops : st WL.ops =
{ { schedule; get_next_task; on_exn; runner; before_start; cleanup }
schedule;
around_task;
get_next_task;
on_exn;
runner;
before_start;
cleanup;
}
let setup st = let setup st =
if Atomic.compare_and_set Scheduler_state.cur_st None (Some st) then if Atomic.compare_and_set Scheduler_state.cur_st None (Some st) then

View file

@ -14,14 +14,11 @@ let run ~kind () =
let pool = let pool =
let on_init_thread ~dom_id:_ ~t_id () = let on_init_thread ~dom_id:_ ~t_id () =
Trace.set_thread_name (Printf.sprintf "pool worker %d" t_id) Trace.set_thread_name (Printf.sprintf "pool worker %d" t_id)
and around_task =
( (fun self -> Trace.counter_int "n_tasks" (Ws_pool.num_tasks self)),
fun self () -> Trace.counter_int "n_tasks" (Ws_pool.num_tasks self) )
in in
match kind with match kind with
| `Simple -> Fifo_pool.create ~num_threads:3 ~on_init_thread ~around_task () | `Simple -> Fifo_pool.create ~num_threads:3 ~on_init_thread ()
| `Ws_pool -> Ws_pool.create ~num_threads:3 ~on_init_thread ~around_task () | `Ws_pool -> Ws_pool.create ~num_threads:3 ~on_init_thread ()
in in
(* make all threads busy *) (* make all threads busy *)