perf fork-join: in both f g only run f in the background

`g` can run immediately on same thread, otherwise we just suspend the
computation and start a new task for nothing.
This commit is contained in:
Simon Cruanes 2023-11-06 00:09:38 -05:00
parent 6e6a2a1faa
commit d2be2db0ef

View file

@ -3,91 +3,100 @@
module A = Atomic_ module A = Atomic_
module State_ = struct module State_ = struct
type 'a single_res = type error = exn * Printexc.raw_backtrace
| St_none type 'a or_error = ('a, error) result
| St_some of 'a
| St_fail of exn * Printexc.raw_backtrace
type ('a, 'b) t = { type ('a, 'b) t =
mutable suspension: | Init
((unit, exn * Printexc.raw_backtrace) result -> unit) option; | Left_solved of 'a or_error
(** suspended caller *) | Right_solved of 'b or_error * Suspend_.suspension
left: 'a single_res; | Both_solved of 'a or_error * 'b or_error
right: 'b single_res;
}
let get_exn (self : _ t A.t) = let get_exn_ (self : _ t A.t) =
match A.get self with match A.get self with
| { left = St_fail (e, bt); _ } | { right = St_fail (e, bt); _ } -> | Both_solved (Ok a, Ok b) -> a, b
Printexc.raise_with_backtrace e bt | Both_solved (Error (exn, bt), _) | Both_solved (_, Error (exn, bt)) ->
| { left = St_some x; right = St_some y; _ } -> x, y Printexc.raise_with_backtrace exn bt
| _ -> assert false | _ -> assert false
let check_if_state_complete_ (self : _ t) : unit = let rec set_left_ (self : _ t A.t) (left : _ or_error) =
match self.left, self.right, self.suspension with let old_st = A.get self in
| St_some _, St_some _, Some f -> f (Ok ()) match old_st with
| St_fail (e, bt), _, Some f | _, St_fail (e, bt), Some f -> | Init ->
f (Error (e, bt)) let new_st = Left_solved left in
| _ -> () if not (A.compare_and_set self old_st new_st) then (
Domain_.relax ();
let set_left_ (self : _ t A.t) (x : _ single_res) = set_left_ self left
while )
let old_st = A.get self in | Right_solved (right, cont) ->
let new_st = { old_st with left = x } in let new_st = Both_solved (left, right) in
if A.compare_and_set self old_st new_st then ( if not (A.compare_and_set self old_st new_st) then (
check_if_state_complete_ new_st; Domain_.relax ();
false set_left_ self left
) else ) else
true cont (Ok ())
do | Left_solved _ | Both_solved _ -> assert false
Domain_.relax ()
done
let set_right_ (self : _ t A.t) (y : _ single_res) = let rec set_right_ (self : _ t A.t) (right : _ or_error) : unit =
while let old_st = A.get self in
let old_st = A.get self in match old_st with
let new_st = { old_st with right = y } in | Left_solved left ->
if A.compare_and_set self old_st new_st then ( let new_st = Both_solved (left, right) in
check_if_state_complete_ new_st; if not (A.compare_and_set self old_st new_st) then set_right_ self right
false | Init ->
) else (* we are first arrived, we suspend until the left computation is done *)
true Suspend_.suspend
do {
Domain_.relax () Suspend_.handle =
done (fun ~run:_ suspension ->
while
let old_st = A.get self in
match old_st with
| Init ->
not
(A.compare_and_set self old_st
(Right_solved (right, suspension)))
| Left_solved left ->
(* other thread is done, no risk of race condition *)
A.set self (Both_solved (left, right));
suspension (Ok ());
false
| Right_solved _ | Both_solved _ -> assert false
do
()
done);
}
| Right_solved _ | Both_solved _ -> assert false
end end
let both f g : _ * _ = let both f g : _ * _ =
let open State_ in let module ST = State_ in
let st = A.make { suspension = None; left = St_none; right = St_none } in let st = A.make ST.Init in
let start_tasks ~run () : unit = let runner =
run (fun () -> match Runner.get_current_runner () with
try | None -> invalid_arg "Fork_join.both must be run from within a runner"
let res = f () in | Some r -> r
set_left_ st (St_some res)
with e ->
let bt = Printexc.get_raw_backtrace () in
set_left_ st (St_fail (e, bt)));
run (fun () ->
try
let res = g () in
set_right_ st (St_some res)
with e ->
let bt = Printexc.get_raw_backtrace () in
set_right_ st (St_fail (e, bt)))
in in
Suspend_.suspend (* start computing [f] in the background *)
{ Runner.run_async runner (fun () ->
Suspend_.handle = try
(fun ~run suspension -> let res = f () in
(* nothing else is started, no race condition possible *) ST.set_left_ st (Ok res)
(A.get st).suspension <- Some suspension; with exn ->
start_tasks ~run ()); let bt = Printexc.get_raw_backtrace () in
}; ST.set_left_ st (Error (exn, bt)));
get_exn st
let res_right =
try Ok (g ())
with exn ->
let bt = Printexc.get_raw_backtrace () in
Error (exn, bt)
in
ST.set_right_ st res_right;
ST.get_exn_ st
let both_ignore f g = ignore (both f g : _ * _) let both_ignore f g = ignore (both f g : _ * _)