diff --git a/src/pool.ml b/src/pool.ml index a6b74dc6..52aa050b 100644 --- a/src/pool.ml +++ b/src/pool.ml @@ -87,6 +87,27 @@ let num_tasks (self : t) : int = Array.iter (fun q -> n := !n + Bb_queue.size q) self.qs; !n +(* DLA interop *) +let prepare_for_await () : Dla_.t = + (* current state *) + let st : + ((with_handler:bool -> task -> unit) * Suspend_types_.suspension) option + A.t = + A.make None + in + + let release () : unit = + match A.exchange st None with + | None -> () + | Some (run, k) -> run ~with_handler:true (fun () -> k (Ok ())) + and await () : unit = + Suspend_.suspend + { Suspend_types_.handle = (fun ~run k -> A.set st (Some (run, k))) } + in + + let t = { Dla_.release; await } in + t + exception Got_task of task type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task @@ -96,7 +117,7 @@ let worker_thread_ pool ~on_exn ~around_task (active : bool A.t) let num_qs = Array.length qs in let (AT_pair (before_task, after_task)) = around_task in - try + let main_loop () = while A.get active do (* last resort: block on my queue *) let pop_blocking () = @@ -117,12 +138,18 @@ let worker_thread_ pool ~on_exn ~around_task (active : bool A.t) in let _ctx = before_task pool in + (* run the task now, catching errors *) (try task () with e -> let bt = Printexc.get_raw_backtrace () in on_exn e bt); after_task pool _ctx done + in + + try + (* handle domain-local await *) + Dla_.using ~prepare_for_await ~while_running:main_loop with Bb_queue.Closed -> () let default_thread_init_exit_ ~dom_id:_ ~t_id:_ () = ()