diff --git a/src/core/ws_pool.ml b/src/core/ws_pool.ml index 16ad4fc7..956dd9ce 100644 --- a/src/core/ws_pool.ml +++ b/src/core/ws_pool.ml @@ -14,13 +14,20 @@ module Id = struct let equal : t -> t -> bool = ( == ) end -type task_full = { - f: task; - ls: Task_local_storage.storage; -} - type around_task = AT_pair : (t -> 'a) * (t -> 'a -> unit) -> around_task +type task_full = + | T_start of { + ls: Task_local_storage.storage; + f: task; + } + | T_resume : { + ls: Task_local_storage.storage; + k: 'a -> unit; + x: 'a; + } + -> task_full + type worker_state = { pool_id_: Id.t; (** Unique per pool *) mutable thread: Thread.t; @@ -73,10 +80,8 @@ let[@inline] try_wake_someone_ (self : state) : unit = ) (** Run [task] as is, on the pool. *) -let schedule_task_ (self : state) ~ls (w : worker_state option) (f : task) - : unit = +let schedule_task_ (self : state) ~w (task : task_full) : unit = (* Printf.printf "schedule task now (%d)\n%!" (Thread.id @@ Thread.self ()); *) - let task = { f; ls } in match w with | Some w when Id.equal self.id_ w.pool_id_ -> (* we're on this same pool, schedule in the worker's state. Otherwise @@ -105,40 +110,59 @@ let schedule_task_ (self : state) ~ls (w : worker_state option) (f : task) raise Shutdown (** Run this task, now. Must be called from a worker. *) -let run_task_now_ (self : state) ~runner (w : worker_state) ~ls task : - unit = +let run_task_now_ (self : state) ~runner ~(w : worker_state) (task : task_full) + : unit = (* Printf.printf "run task now (%d)\n%!" (Thread.id @@ Thread.self ()); *) let (AT_pair (before_task, after_task)) = self.around_task in + + let ls = + match task with + | T_start { ls; _ } | T_resume { ls; _ } -> ls + in + w.cur_ls := ls; let _ctx = before_task runner in let[@inline] on_suspend () = - let w' = find_current_worker_ () in - let ls= !(w.cur_ls) in - w', ls + let w = + match find_current_worker_ () with + | Some w -> w + | None -> assert false + in + let ls = !(w.cur_ls) in + ls in - let run_another_task (w,ls) task' = + let run_another_task ls (task' : task) = + let w = + match find_current_worker_ () with + | Some w when Id.equal w.pool_id_ self.id_ -> Some w + | _ -> None + in let ls' = Task_local_storage.Private_.Storage.copy ls in - schedule_task_ self w ~ls:ls' task' + schedule_task_ self ~w @@ T_start { ls = ls'; f = task' } in - let resume (w,ls) k r = - schedule_task_ self w ~ls (fun () -> k r) + let resume ls k x = + let w = + match find_current_worker_ () with + | Some w when Id.equal w.pool_id_ self.id_ -> Some w + | _ -> None + in + schedule_task_ self ~w @@ T_resume { ls; k; x } in (* run the task now, catching errors *) (try - (* run [task()] and handle [suspend] in it *) -[@@@ifge 5.0] - Suspend_.with_suspend (WSH { - on_suspend; - run=run_another_task; - resume; - }) task -[@@@else_] - task () -[@@@endif] + match task with + | T_start { f = task; _ } -> + (* run [task()] and handle [suspend] in it *) + Suspend_.with_suspend + (WSH { on_suspend; run = run_another_task; resume }) + task + | T_resume { k; x; _ } -> + (* this is already in an effect handler *) + k x with e -> let bt = Printexc.get_raw_backtrace () in self.on_exn e bt); @@ -146,9 +170,9 @@ let run_task_now_ (self : state) ~runner (w : worker_state) ~ls task : after_task runner _ctx; w.cur_ls := Task_local_storage.Private_.Storage.dummy -let[@inline] run_async_ (self : state) ~ls (f : task) : unit = +let run_async_ (self : state) ~ls (f : task) : unit = let w = find_current_worker_ () in - schedule_task_ self w ~ls f + schedule_task_ self ~w @@ T_start { f; ls } (* TODO: function to schedule many tasks from the outside. - build a queue @@ -194,7 +218,7 @@ let worker_run_self_tasks_ (self : state) ~runner w : unit = match WSQ.pop w.q with | Some task -> try_wake_someone_ self; - run_task_now_ self ~runner w ~ls:task.ls task.f + run_task_now_ self ~runner ~w task | None -> continue := false done @@ -207,7 +231,7 @@ let worker_thread_ (self : state) ~(runner : t) (w : worker_state) : unit = worker_run_self_tasks_ self ~runner w; try_steal () and run_task task : unit = - run_task_now_ self ~runner w ~ls:task.ls task.f; + run_task_now_ self ~runner ~w task; main () and try_steal () = match try_to_steal_work_once_ self w with @@ -266,7 +290,8 @@ type ('a, 'b) create_args = 'a (** Arguments used in {!create}. See {!create} for explanations. *) -let dummy_task_ = { f = ignore; ls = Task_local_storage.Private_.Storage.dummy ; } +let dummy_task_ : task_full = + T_start { f = ignore; ls = Task_local_storage.Private_.Storage.dummy } let create ?(on_init_thread = default_thread_init_exit_) ?(on_exit_thread = default_thread_init_exit_) ?(on_exn = fun _ _ -> ())