From c39435d8eb76cce802664a62498e538bda21743b Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Tue, 27 Feb 2024 22:31:25 -0500 Subject: [PATCH] fix fifo_pool: resume can be called from another worker we might schedule on worker 1, suspend, resume on worker 2, and resume from there. --- src/core/fifo_pool.ml | 71 ++++++++++++++++++++++++------------------- 1 file changed, 39 insertions(+), 32 deletions(-) diff --git a/src/core/fifo_pool.ml b/src/core/fifo_pool.ml index 60894f92..df4837c2 100644 --- a/src/core/fifo_pool.ml +++ b/src/core/fifo_pool.ml @@ -4,10 +4,17 @@ include Runner let ( let@ ) = ( @@ ) let k_storage = Task_local_storage.Private_.Storage.k_storage -type task_full = { - f: unit -> unit; - ls: Task_local_storage.storage; -} +type task_full = + | T_start of { + ls: Task_local_storage.storage; + f: task; + } + | T_resume : { + ls: Task_local_storage.storage; + k: 'a -> unit; + x: 'a; + } + -> task_full type state = { threads: Thread.t array; @@ -23,56 +30,56 @@ let schedule_ (self : state) (task : task_full) : unit = try Bb_queue.push self.q task with Bb_queue.Closed -> raise Shutdown type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task +type worker_state = { cur_ls: Task_local_storage.storage ref } + +let k_worker_state : worker_state option ref TLS.key = + TLS.new_key (fun () -> ref None) let worker_thread_ (self : state) (runner : t) ~on_exn ~around_task : unit = - let cur_ls : Task_local_storage.storage ref = - ref Task_local_storage.Private_.Storage.dummy - in - TLS.set k_storage (Some cur_ls); + let w = { cur_ls = ref Task_local_storage.Private_.Storage.dummy } in + TLS.get k_worker_state := Some w; + TLS.set k_storage (Some w.cur_ls); TLS.get Runner.For_runner_implementors.k_cur_runner := Some runner; let (AT_pair (before_task, after_task)) = around_task in - let cur_span = ref Tracing_.dummy_span in - - let[@inline] exit_span_ () = - Tracing_.exit_span !cur_span; - cur_span := Tracing_.dummy_span - in - let on_suspend () = - exit_span_ (); - !cur_ls - in - - let run_another_task ls task' = - let ls' = Task_local_storage.Private_.Storage.copy ls in - schedule_ self { f = task'; ls = ls' } + match !(TLS.get k_worker_state) with + | None -> assert false + | Some w -> !(w.cur_ls) in + let run_another_task ls task' = schedule_ self @@ T_start { f = task'; ls } in + let resume ls k res = schedule_ self @@ T_resume { ls; k; x = res } in let run_task (task : task_full) : unit = - cur_ls := task.ls; + let ls = + match task with + | T_start { ls; _ } | T_resume { ls; _ } -> ls + in + w.cur_ls := ls; let _ctx = before_task runner in - let resume ls k res = schedule_ self { f = (fun () -> k res); ls } in - (* run the task now, catching errors, handling effects *) (try - Suspend_.with_suspend - (WSH { run = run_another_task; resume; on_suspend }) - task.f + match task with + | T_start { f = task; _ } -> + (* run [task()] and handle [suspend] in it *) + Suspend_.with_suspend + (WSH { on_suspend; run = run_another_task; resume }) + task + | T_resume { k; x; _ } -> + (* this is already in an effect handler *) + k x with e -> let bt = Printexc.get_raw_backtrace () in on_exn e bt); - exit_span_ (); after_task runner _ctx; - cur_ls := Task_local_storage.Private_.Storage.dummy + w.cur_ls := Task_local_storage.Private_.Storage.dummy in let main_loop () = let continue = ref true in while !continue do - assert (!cur_span = Tracing_.dummy_span); match Bb_queue.pop self.q with | task -> run_task task | exception Bb_queue.Closed -> continue := false @@ -123,7 +130,7 @@ let create ?(on_init_thread = default_thread_init_exit_) { threads = Array.make num_threads dummy; q = Bb_queue.create () } in - let run_async ~ls f = schedule_ pool { f; ls } in + let run_async ~ls f = schedule_ pool @@ T_start { f; ls } in let runner = Runner.For_runner_implementors.create