From 0fecde07fc6ed4b0d65b1a9abbe4de567c123907 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Wed, 9 Jul 2025 22:06:34 -0400 Subject: [PATCH] test: update Lwt tests to use the new Moonpool_lwt --- test/lwt/echo_client.ml | 122 ++++++++++++++++++++++------------------ test/lwt/echo_server.ml | 14 +++-- test/lwt/hash_client.ml | 34 +++++------ test/lwt/hash_server.ml | 57 +++++++++++-------- 4 files changed, 128 insertions(+), 99 deletions(-) diff --git a/test/lwt/echo_client.ml b/test/lwt/echo_client.ml index 7143d8be..ba6cfea9 100644 --- a/test/lwt/echo_client.ml +++ b/test/lwt/echo_client.ml @@ -1,93 +1,105 @@ -module M = Moonpool module M_lwt = Moonpool_lwt module Trace = Trace_core let spf = Printf.sprintf +let await_lwt = Moonpool_lwt.await_lwt let ( let@ ) = ( @@ ) -let lock_stdout = M.Lock.create () -let main ~port ~runner ~n ~n_conn () : unit Lwt.t = +let main ~port ~n ~n_conn ~verbose ~msg_per_conn () : unit = let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "main" in - let remaining = Atomic.make n in - let all_done = Atomic.make 0 in - - let fut_exit, prom_exit = M.Fut.make () in - - Printf.printf "connecting to port %d\n%!" port; + let t0 = Unix.gettimeofday () in + Printf.printf + "connecting to port %d (%d msg per conn, %d conns total, %d max at a time)\n\ + %!" + port msg_per_conn n n_conn; let addr = Unix.ADDR_INET (Unix.inet_addr_loopback, port) in - let rec run_task () = + let token_pool = Lwt_pool.create n_conn (fun () -> Lwt.return_unit) in + let n_msg_total = ref 0 in + + let run_task () = (* Printf.printf "running task\n%!"; *) - let n = Atomic.fetch_and_add remaining (-1) in - if n > 0 then ( - (let _sp = - Trace.enter_manual_toplevel_span ~__FILE__ ~__LINE__ "connect.client" - in - Trace.message "connecting new client…"; - M_lwt.TCP_client.with_connect addr @@ fun ic oc -> - let buf = Bytes.create 32 in + let@ () = Lwt_pool.use token_pool in - for _j = 1 to 10 do - let _sp = - Trace.enter_manual_sub_span ~parent:_sp ~__FILE__ ~__LINE__ - "write.loop" - in + let@ () = M_lwt.spawn_lwt in + let _sp = + Trace.enter_manual_span ~parent:None ~__FILE__ ~__LINE__ "connect.client" + in + Trace.message "connecting new client…"; - let s = spf "hello %d" _j in - M_lwt.IO_out.output_string oc s; - M_lwt.IO_out.flush oc; + let ic, oc = Lwt_io.open_connection addr |> await_lwt in - (* read back something *) - M_lwt.IO_in.really_input ic buf 0 (String.length s); - (let@ () = M.Lock.with_ lock_stdout in - Printf.printf "read: %s\n%!" - (Bytes.sub_string buf 0 (String.length s))); - Trace.exit_manual_span _sp; - () - done; - Trace.exit_manual_span _sp); + let cleanup () = + Trace.message "closing connection"; + Lwt_io.close ic |> await_lwt; + Lwt_io.close oc |> await_lwt + in - (* run another task *) M.Runner.run_async runner run_task - ) else ( - (* if we're the last to exit, resolve the promise *) - let n_already_done = Atomic.fetch_and_add all_done 1 in - if n_already_done = n_conn - 1 then ( - (let@ () = M.Lock.with_ lock_stdout in - Printf.printf "all done\n%!"); - M.Fut.fulfill prom_exit @@ Ok () - ) - ) + let@ () = Fun.protect ~finally:cleanup in + + let buf = Bytes.create 32 in + + for _j = 1 to msg_per_conn do + let _sp = + Trace.enter_manual_span + ~parent:(Some (Trace.ctx_of_span _sp)) + ~__FILE__ ~__LINE__ "write.loop" + in + + let s = spf "hello %d" _j in + Lwt_io.write oc s |> await_lwt; + Lwt_io.flush oc |> await_lwt; + incr n_msg_total; + + (* read back something *) + Lwt_io.read_into_exactly ic buf 0 (String.length s) |> await_lwt; + if verbose then + Printf.printf "read: %s\n%!" (Bytes.sub_string buf 0 (String.length s)); + Trace.exit_manual_span _sp; + () + done; + Trace.exit_manual_span _sp in (* start the first [n_conn] tasks *) - for _i = 1 to n_conn do - M.Runner.run_async runner run_task - done; + let futs = List.init n (fun _ -> run_task ()) in + Lwt.join futs |> await_lwt; - (* exit when [fut_exit] is resolved *) - M_lwt.lwt_of_fut fut_exit + Printf.printf "all done\n%!"; + let elapsed = Unix.gettimeofday () -. t0 in + Printf.printf " sent %d messages in %.4fs (%.2f msg/s)\n%!" !n_msg_total + elapsed + (float !n_msg_total /. elapsed); + () let () = let@ () = Trace_tef.with_setup () in Trace.set_thread_name "main"; let port = ref 0 in - let j = ref 4 in let n_conn = ref 100 in let n = ref 50_000 in + let msg_per_conn = ref 10 in + let verbose = ref false in let opts = [ "-p", Arg.Set_int port, " port"; - "-j", Arg.Set_int j, " number of threads"; "-n", Arg.Set_int n, " total number of connections"; - "--n-conn", Arg.Set_int n_conn, " number of parallel connections"; + ( "--msg-per-conn", + Arg.Set_int msg_per_conn, + " messages sent per connection" ); + "-v", Arg.Set verbose, " verbose"; + ( "--n-conn", + Arg.Set_int n_conn, + " maximum number of connections opened simultaneously" ); ] |> Arg.align in Arg.parse opts ignore "echo client"; - let@ runner = M.Ws_pool.with_ ~name:"tpool" ~num_threads:!j () in (* Lwt_engine.set @@ new Lwt_engine.libev (); *) - Lwt_main.run @@ main ~runner ~port:!port ~n:!n ~n_conn:!n_conn () + M_lwt.lwt_main @@ fun _runner -> + main ~port:!port ~n:!n ~n_conn:!n_conn ~verbose:!verbose + ~msg_per_conn:!msg_per_conn () diff --git a/test/lwt/echo_server.ml b/test/lwt/echo_server.ml index 9047e5ff..3a6d6fde 100644 --- a/test/lwt/echo_server.ml +++ b/test/lwt/echo_server.ml @@ -11,7 +11,7 @@ let str_of_sockaddr = function | Unix.ADDR_INET (addr, port) -> spf "%s:%d" (Unix.string_of_inet_addr addr) port -let main ~port ~runner:_ () : unit Lwt.t = +let main ~port ~verbose ~runner:_ () : unit Lwt.t = let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "main" in let lwt_fut, _lwt_prom = Lwt.wait () in @@ -26,7 +26,8 @@ let main ~port ~runner:_ () : unit Lwt.t = ~data:(fun () -> [ "addr", `String (str_of_sockaddr client_addr) ]) in - Printf.printf "got new client on %s\n%!" (str_of_sockaddr client_addr); + if verbose then + Printf.printf "got new client on %s\n%!" (str_of_sockaddr client_addr); let buf = Bytes.create 32 in let continue = ref true in @@ -42,6 +43,8 @@ let main ~port ~runner:_ () : unit Lwt.t = Trace.message "write" ) done; + if verbose then + Printf.printf "done with client on %s\n%!" (str_of_sockaddr client_addr); Trace.exit_manual_span _sp; Trace.message "exit handle client" in @@ -58,10 +61,13 @@ let () = Trace.set_thread_name "main"; let port = ref 0 in let j = ref 4 in + let verbose = ref false in let opts = [ - "-p", Arg.Set_int port, " port"; "-j", Arg.Set_int j, " number of threads"; + "-v", Arg.Set verbose, " verbose"; + "-p", Arg.Set_int port, " port"; + "-j", Arg.Set_int j, " number of threads"; ] |> Arg.align in @@ -69,4 +75,4 @@ let () = let@ runner = M.Ws_pool.with_ ~name:"tpool" ~num_threads:!j () in (* Lwt_engine.set @@ new Lwt_engine.libev (); *) - Lwt_main.run @@ main ~runner ~port:!port () + Lwt_main.run @@ main ~runner ~port:!port ~verbose:!verbose () diff --git a/test/lwt/hash_client.ml b/test/lwt/hash_client.ml index 085666fb..1ea3fcea 100644 --- a/test/lwt/hash_client.ml +++ b/test/lwt/hash_client.ml @@ -8,10 +8,10 @@ module Str_tbl = Hashtbl.Make (struct let hash = Hashtbl.hash end) +let await_lwt = Moonpool_lwt.await_lwt let ( let@ ) = ( @@ ) -let lock_stdout = M.Lock.create () -let main ~port ~runner ~ext ~dir ~n_conn () : unit Lwt.t = +let main ~port ~ext ~dir ~n_conn () : unit = let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "main" in Printf.printf "hash dir=%S\n%!" dir; @@ -20,12 +20,15 @@ let main ~port ~runner ~ext ~dir ~n_conn () : unit Lwt.t = let addr = Unix.ADDR_INET (Unix.inet_addr_loopback, port) in (* TODO: *) - let run_task () : unit = - let _sp = Trace.enter_manual_toplevel_span ~__FILE__ ~__LINE__ "run-task" in + let run_task () : unit Lwt.t = + let@ () = M_lwt.spawn_lwt in + let _sp = + Trace.enter_manual_span ~parent:None ~__FILE__ ~__LINE__ "run-task" + in let seen = Str_tbl.create 16 in - M_lwt.TCP_client.with_connect_lwt addr @@ fun ic oc -> + let ic, oc = Lwt_io.open_connection addr |> await_lwt in let rec walk file : unit = if not (Sys.file_exists file) then () @@ -33,7 +36,9 @@ let main ~port ~runner ~ext ~dir ~n_conn () : unit Lwt.t = () else if Sys.is_directory file then ( let _sp = - Trace.enter_manual_sub_span ~parent:_sp ~__FILE__ ~__LINE__ "walk-dir" + Trace.enter_manual_span + ~parent:(Some (Trace.ctx_of_span _sp)) + ~__FILE__ ~__LINE__ "walk-dir" ~data:(fun () -> [ "d", `String file ]) in @@ -45,9 +50,8 @@ let main ~port ~runner ~ext ~dir ~n_conn () : unit Lwt.t = () else ( Str_tbl.add seen file (); - M_lwt.run_in_lwt_and_await (fun () -> Lwt_io.write_line oc file); - let res = M_lwt.run_in_lwt_and_await (fun () -> Lwt_io.read_line ic) in - let@ () = M.Lock.with_ lock_stdout in + Lwt_io.write_line oc file |> await_lwt; + let res = Lwt_io.read_line ic |> await_lwt in Printf.printf "%s\n%!" res ) in @@ -56,16 +60,14 @@ let main ~port ~runner ~ext ~dir ~n_conn () : unit Lwt.t = in (* start the first [n_conn] tasks *) - let futs = List.init n_conn (fun _ -> M.Fut.spawn ~on:runner run_task) in - - Lwt.join (List.map M_lwt.lwt_of_fut futs) + let futs = List.init n_conn (fun _ -> run_task ()) in + Lwt.join futs |> await_lwt let () = let@ () = Trace_tef.with_setup () in Trace.set_thread_name "main"; let port = ref 1234 in - let j = ref 4 in let n_conn = ref 100 in let ext = ref "" in let dir = ref "." in @@ -73,7 +75,6 @@ let () = let opts = [ "-p", Arg.Set_int port, " port"; - "-j", Arg.Set_int j, " number of threads"; "-d", Arg.Set_string dir, " directory to hash"; "--n-conn", Arg.Set_int n_conn, " number of parallel connections"; "--ext", Arg.Set_string ext, " extension to filter files"; @@ -82,7 +83,6 @@ let () = in Arg.parse opts ignore "echo client"; - let@ runner = M.Ws_pool.with_ ~name:"tpool" ~num_threads:!j () in (* Lwt_engine.set @@ new Lwt_engine.libev (); *) - Lwt_main.run - @@ main ~runner ~port:!port ~ext:!ext ~dir:!dir ~n_conn:!n_conn () + M_lwt.lwt_main @@ fun _runner -> + main ~port:!port ~ext:!ext ~dir:!dir ~n_conn:!n_conn () diff --git a/test/lwt/hash_server.ml b/test/lwt/hash_server.ml index a84f6ccb..3038430f 100644 --- a/test/lwt/hash_server.ml +++ b/test/lwt/hash_server.ml @@ -134,10 +134,11 @@ let sha_1 s = (* server that reads from sockets lists of files, and returns hashes of these files *) -module M = Moonpool module M_lwt = Moonpool_lwt module Trace = Trace_core +module Fut = Moonpool.Fut +let await_lwt = Moonpool_lwt.await_lwt let ( let@ ) = ( @@ ) let spf = Printf.sprintf @@ -165,7 +166,7 @@ let read_file filename : string = in In_channel.with_open_bin filename In_channel.input_all -let main ~port ~runner () : unit Lwt.t = +let main ~port ~runner () : unit = let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "main" in let lwt_fut, _lwt_prom = Lwt.wait () in @@ -173,38 +174,39 @@ let main ~port ~runner () : unit Lwt.t = (* TODO: handle exit?? *) Printf.printf "listening on port %d\n%!" port; - let handle_client client_addr ic oc = + let handle_client client_addr (ic, oc) = + let@ () = Moonpool_lwt.spawn_lwt in let _sp = - Trace.enter_manual_toplevel_span ~__FILE__ ~__LINE__ "handle.client" + Trace.enter_manual_span ~parent:None ~__FILE__ ~__LINE__ "handle.client" ~data:(fun () -> [ "addr", `String (str_of_sockaddr client_addr) ]) in try while true do Trace.message "read"; - let filename = - M_lwt.run_in_lwt_and_await (fun () -> Lwt_io.read_line ic) - |> String.trim - in + let filename = Lwt_io.read_line ic |> await_lwt |> String.trim in Trace.messagef (fun k -> k "hash %S" filename); match read_file filename with | exception e -> Printf.eprintf "error while reading %S:\n%s\n" filename (Printexc.to_string e); - M_lwt.run_in_lwt_and_await (fun () -> - Lwt_io.write_line oc (spf "%s: error" filename)); - M_lwt.run_in_lwt_and_await (fun () -> Lwt_io.flush oc) + Lwt_io.write_line oc (spf "%s: error" filename) |> await_lwt; + Lwt_io.flush oc |> await_lwt | content -> - (* got the content, now hash it *) - let hash = - let@ _sp = Trace.with_span ~__FILE__ ~__LINE__ "hash" in + (* got the content, now hash it in a background task *) + let hash : _ Fut.t = + let@ () = Moonpool.spawn ~on:runner in + let@ _sp = + Trace.with_span ~__FILE__ ~__LINE__ "hash" ~data:(fun () -> + [ "file", `String filename ]) + in sha_1 content |> to_hex in - M_lwt.run_in_lwt_and_await (fun () -> - Lwt_io.write_line oc (spf "%s: %s" filename hash)); - M_lwt.run_in_lwt_and_await (fun () -> Lwt_io.flush oc) + let hash = Fut.await hash in + Lwt_io.write_line oc (spf "%s: %s" filename hash) |> await_lwt; + Lwt_io.flush oc |> await_lwt done with End_of_file | Unix.Unix_error (Unix.ECONNRESET, _, _) -> Trace.exit_manual_span _sp; @@ -212,16 +214,17 @@ let main ~port ~runner () : unit Lwt.t = in let addr = Unix.ADDR_INET (Unix.inet_addr_loopback, port) in - let _server = M_lwt.TCP_server.establish_lwt ~runner addr handle_client in - Printf.printf "listening on port=%d\n%!" port; + let _server = + Lwt_io.establish_server_with_client_address addr handle_client |> await_lwt + in - lwt_fut + lwt_fut |> await_lwt let () = let@ () = Trace_tef.with_setup () in Trace.set_thread_name "main"; let port = ref 1234 in - let j = ref 4 in + let j = ref 0 in let opts = [ @@ -231,6 +234,14 @@ let () = in Arg.parse opts ignore "echo server"; - let@ runner = M.Ws_pool.with_ ~name:"tpool" ~num_threads:!j () in (* Lwt_engine.set @@ new Lwt_engine.libev (); *) - Lwt_main.run @@ main ~runner ~port:!port () + let@ runner = + let num_threads = + if !j = 0 then + None + else + Some !j + in + Moonpool.Ws_pool.with_ ?num_threads () + in + M_lwt.lwt_main @@ fun _main_runner -> main ~runner ~port:!port ()