diff --git a/src/fork_join.ml b/src/fork_join.ml index ac5ba5d7..8ad61cec 100644 --- a/src/fork_join.ml +++ b/src/fork_join.ml @@ -3,91 +3,100 @@ module A = Atomic_ module State_ = struct - type 'a single_res = - | St_none - | St_some of 'a - | St_fail of exn * Printexc.raw_backtrace + type error = exn * Printexc.raw_backtrace + type 'a or_error = ('a, error) result - type ('a, 'b) t = { - mutable suspension: - ((unit, exn * Printexc.raw_backtrace) result -> unit) option; - (** suspended caller *) - left: 'a single_res; - right: 'b single_res; - } + type ('a, 'b) t = + | Init + | Left_solved of 'a or_error + | Right_solved of 'b or_error * Suspend_.suspension + | Both_solved of 'a or_error * 'b or_error - let get_exn (self : _ t A.t) = + let get_exn_ (self : _ t A.t) = match A.get self with - | { left = St_fail (e, bt); _ } | { right = St_fail (e, bt); _ } -> - Printexc.raise_with_backtrace e bt - | { left = St_some x; right = St_some y; _ } -> x, y + | Both_solved (Ok a, Ok b) -> a, b + | Both_solved (Error (exn, bt), _) | Both_solved (_, Error (exn, bt)) -> + Printexc.raise_with_backtrace exn bt | _ -> assert false - let check_if_state_complete_ (self : _ t) : unit = - match self.left, self.right, self.suspension with - | St_some _, St_some _, Some f -> f (Ok ()) - | St_fail (e, bt), _, Some f | _, St_fail (e, bt), Some f -> - f (Error (e, bt)) - | _ -> () - - let set_left_ (self : _ t A.t) (x : _ single_res) = - while - let old_st = A.get self in - let new_st = { old_st with left = x } in - if A.compare_and_set self old_st new_st then ( - check_if_state_complete_ new_st; - false + let rec set_left_ (self : _ t A.t) (left : _ or_error) = + let old_st = A.get self in + match old_st with + | Init -> + let new_st = Left_solved left in + if not (A.compare_and_set self old_st new_st) then ( + Domain_.relax (); + set_left_ self left + ) + | Right_solved (right, cont) -> + let new_st = Both_solved (left, right) in + if not (A.compare_and_set self old_st new_st) then ( + Domain_.relax (); + set_left_ self left ) else - true - do - Domain_.relax () - done + cont (Ok ()) + | Left_solved _ | Both_solved _ -> assert false - let set_right_ (self : _ t A.t) (y : _ single_res) = - while - let old_st = A.get self in - let new_st = { old_st with right = y } in - if A.compare_and_set self old_st new_st then ( - check_if_state_complete_ new_st; - false - ) else - true - do - Domain_.relax () - done + let rec set_right_ (self : _ t A.t) (right : _ or_error) : unit = + let old_st = A.get self in + match old_st with + | Left_solved left -> + let new_st = Both_solved (left, right) in + if not (A.compare_and_set self old_st new_st) then set_right_ self right + | Init -> + (* we are first arrived, we suspend until the left computation is done *) + Suspend_.suspend + { + Suspend_.handle = + (fun ~run:_ suspension -> + while + let old_st = A.get self in + match old_st with + | Init -> + not + (A.compare_and_set self old_st + (Right_solved (right, suspension))) + | Left_solved left -> + (* other thread is done, no risk of race condition *) + A.set self (Both_solved (left, right)); + suspension (Ok ()); + false + | Right_solved _ | Both_solved _ -> assert false + do + () + done); + } + | Right_solved _ | Both_solved _ -> assert false end let both f g : _ * _ = - let open State_ in - let st = A.make { suspension = None; left = St_none; right = St_none } in + let module ST = State_ in + let st = A.make ST.Init in - let start_tasks ~run () : unit = - run (fun () -> - try - let res = f () in - set_left_ st (St_some res) - with e -> - let bt = Printexc.get_raw_backtrace () in - set_left_ st (St_fail (e, bt))); - - run (fun () -> - try - let res = g () in - set_right_ st (St_some res) - with e -> - let bt = Printexc.get_raw_backtrace () in - set_right_ st (St_fail (e, bt))) + let runner = + match Runner.get_current_runner () with + | None -> invalid_arg "Fork_join.both must be run from within a runner" + | Some r -> r in - Suspend_.suspend - { - Suspend_.handle = - (fun ~run suspension -> - (* nothing else is started, no race condition possible *) - (A.get st).suspension <- Some suspension; - start_tasks ~run ()); - }; - get_exn st + (* start computing [f] in the background *) + Runner.run_async runner (fun () -> + try + let res = f () in + ST.set_left_ st (Ok res) + with exn -> + let bt = Printexc.get_raw_backtrace () in + ST.set_left_ st (Error (exn, bt))); + + let res_right = + try Ok (g ()) + with exn -> + let bt = Printexc.get_raw_backtrace () in + Error (exn, bt) + in + + ST.set_right_ st res_right; + ST.get_exn_ st let both_ignore f g = ignore (both f g : _ * _)