diff --git a/src/jsonrpc2.ml b/src/jsonrpc2.ml index b276a198..7098c8e2 100644 --- a/src/jsonrpc2.ml +++ b/src/jsonrpc2.ml @@ -60,9 +60,13 @@ module Make (IO : IO) : S with module IO = IO = struct ic: IO.in_channel; oc: IO.out_channel; s: server; + mutable id_counter: int; + pending_responses: (Req_id.t, server_request_handler_pair) Hashtbl.t; } - let create ~ic ~oc server : t = { ic; oc; s = server } + let create ~ic ~oc server : t = + { ic; oc; s = server; id_counter = 0; pending_responses = Hashtbl.create 8 } + let create_stdio server : t = create ~ic:IO.stdin ~oc:IO.stdout server (* send a single message *) @@ -82,6 +86,27 @@ module Make (IO : IO) : S with module IO = IO = struct let json = Jsonrpc.Notification.yojson_of_t m in send_json_ self json + let send_server_req (self : t) (m : Jsonrpc.Request.t) : unit IO.t = + let json = Jsonrpc.Request.yojson_of_t m in + send_json_ self json + + (** Returns a new, unused [Req_id.t] to send a server request. *) + let fresh_lsp_id (self : t) : Req_id.t = + let id = self.id_counter in + self.id_counter <- id + 1; + `Int id + + (** Registers a new handler for a request response. The return indicates + whether a value was inserted or not (in which case it's already present). *) + let register_server_request_response_handler (self : t) (id : Req_id.t) + (handler : server_request_handler_pair) : bool = + if Hashtbl.mem self.pending_responses id then + false + else ( + let () = Hashtbl.add self.pending_responses id handler in + true + ) + let try_ f = IO.catch (fun () -> @@ -89,6 +114,134 @@ module Make (IO : IO) : S with module IO = IO = struct Ok x) (fun e -> IO.return (Error e)) + (** Sends a server notification to the LSP client. *) + let server_notification (self : t) (n : Lsp.Server_notification.t) : unit IO.t + = + let msg = Lsp.Server_notification.to_jsonrpc n in + send_server_notif self msg + + (** Given a [server_request_handler_pair] consisting of some server request + and its handler, sends this request to the LSP client and adds the handler + to a table of pending responses. The request will later be handled by + [handle_response], which will call the provided handler and delete it from + the table of pending responses. *) + let server_request (self : t) (req : server_request_handler_pair) : + Req_id.t IO.t = + let (Request_and_handler (r, _)) = req in + let id = fresh_lsp_id self in + let msg = Lsp.Server_request.to_jsonrpc_request r ~id in + let has_inserted = register_server_request_response_handler self id req in + if has_inserted then + let* () = send_server_req self msg in + return id + else + IO.failwith "failed to register server request: id was already used" + + (** Wraps some action and, in case the [IO.t] request has failed, logs the + failure to the LSP client. *) + let with_error_handler (self : t) (action : unit -> unit IO.t) : unit IO.t = + IO.catch action (fun e -> + let msg = + Lsp.Types.LogMessageParams.create ~type_:Lsp.Types.MessageType.Error + ~message:(Printexc.to_string e) + in + let msg = + Lsp.Server_notification.LogMessage msg + |> Lsp.Server_notification.to_jsonrpc + in + send_server_notif self msg) + + let handle_notification (self : t) (n : Jsonrpc.Notification.t) : unit IO.t = + match Lsp.Client_notification.of_jsonrpc n with + | Ok n -> + with_error_handler self (fun () -> + self.s#on_notification n ~notify_back:(server_notification self) + ~server_request:(server_request self)) + | Error e -> IO.failwith (spf "cannot decode notification: %s" e) + + let handle_request (self : t) (r : Jsonrpc.Request.t) : unit IO.t = + let protect ~id f = + IO.catch f (fun e -> + let message = + spf "%s\n%s" (Printexc.to_string e) (Printexc.get_backtrace ()) + in + Log.err (fun k -> k "error: %s" message); + let r = + Jsonrpc.Response.error id + (Jsonrpc.Response.Error.make + ~code:Jsonrpc.Response.Error.Code.InternalError ~message ()) + in + send_response self r) + in + (* request, so we need to reply *) + let id = r.id in + IO.catch + (fun () -> + match Lsp.Client_request.of_jsonrpc r with + | Ok (Lsp.Client_request.E r) -> + protect ~id (fun () -> + let* reply = + self.s#on_request r ~id ~notify_back:(server_notification self) + ~server_request:(server_request self) + in + let reply_json = Lsp.Client_request.yojson_of_result r reply in + let response = Jsonrpc.Response.ok id reply_json in + send_response self response) + | Error e -> IO.failwith (spf "cannot decode request: %s" e)) + (fun e -> + let message = + spf "%s\n%s" (Printexc.to_string e) (Printexc.get_backtrace ()) + in + Log.err (fun k -> k "error: %s" message); + let r = + Jsonrpc.Response.error id + (Jsonrpc.Response.Error.make + ~code:Jsonrpc.Response.Error.Code.InternalError ~message ()) + in + send_response self r) + + let handle_response (self : t) (r : Jsonrpc.Response.t) : unit IO.t = + match Hashtbl.find_opt self.pending_responses r.id with + | None -> + IO.failwith + @@ Printf.sprintf "server request not found for response of id %s" + @@ Req_id.to_string r.id + | Some (Request_and_handler (req, handler)) -> + let () = Hashtbl.remove self.pending_responses r.id in + (match r.result with + | Error err -> with_error_handler self (fun () -> handler @@ Error err) + | Ok json -> + let r = Lsp.Server_request.response_of_json req json in + with_error_handler self (fun () -> handler @@ Ok r)) + + let handle_batch_response (self : t) (rs : Jsonrpc.Response.t list) : + unit IO.t = + let rec go = function + | [] -> IO.return () + | r :: rs -> + let* () = handle_response self r in + go rs + in + go rs + + let handle_batch_call (self : t) + (cs : + [ `Notification of Jsonrpc.Notification.t + | `Request of Jsonrpc.Request.t + ] + list) : unit IO.t = + let rec go = function + | [] -> IO.return () + | c :: cs -> + let* () = + match c with + | `Notification n -> handle_notification self n + | `Request r -> handle_request self r + in + go cs + in + go cs + (* read a full message *) let read_msg (self : t) : (Jsonrpc.Packet.t, exn) result IO.t = let rec read_headers acc = @@ -137,8 +290,9 @@ module Make (IO : IO) : S with module IO = IO = struct Log.debug (fun k -> k "got json %s" (J.to_string j)); (match Jsonrpc.Packet.t_of_yojson j with | m -> IO.return @@ Ok m - | exception _ -> - Log.err (fun k -> k "cannot decode json message"); + | exception exn -> + Log.err (fun k -> + k "cannot decode json message: %s" (Printexc.to_string exn)); IO.return (Error (E (ErrorCode.ParseError, "cannot decode json")))) | exception _ -> IO.return @@ -150,72 +304,12 @@ module Make (IO : IO) : S with module IO = IO = struct let run ?(shutdown = fun _ -> false) (self : t) : unit IO.t = let process_msg r = let module M = Jsonrpc.Packet in - let protect ~id f = - IO.catch f (fun e -> - let message = - spf "%s\n%s" (Printexc.to_string e) (Printexc.get_backtrace ()) - in - Log.err (fun k -> k "error: %s" message); - let r = - Jsonrpc.Response.error id - (Jsonrpc.Response.Error.make - ~code:Jsonrpc.Response.Error.Code.InternalError ~message ()) - in - send_response self r) - in match r with - | M.Notification n -> - (* notification *) - (match Lsp.Client_notification.of_jsonrpc n with - | Ok n -> - IO.catch - (fun () -> - self.s#on_notification n ~notify_back:(fun n -> - let msg = Lsp.Server_notification.to_jsonrpc n in - send_server_notif self msg)) - (fun e -> - let msg = - Lsp.Types.LogMessageParams.create - ~type_:Lsp.Types.MessageType.Error - ~message:(Printexc.to_string e) - in - let msg = - Lsp.Server_notification.LogMessage msg - |> Lsp.Server_notification.to_jsonrpc - in - send_server_notif self msg) - | Error e -> IO.failwith (spf "cannot decode notification: %s" e)) - | M.Request r -> - (* request, so we need to reply *) - let id = r.id in - IO.catch - (fun () -> - match Lsp.Client_request.of_jsonrpc r with - | Ok (Lsp.Client_request.E r) -> - protect ~id (fun () -> - let* reply = - self.s#on_request r ~id ~notify_back:(fun n -> - let msg = Lsp.Server_notification.to_jsonrpc n in - send_server_notif self msg) - in - let reply_json = - Lsp.Client_request.yojson_of_result r reply - in - let response = Jsonrpc.Response.ok id reply_json in - send_response self response) - | Error e -> IO.failwith (spf "cannot decode request: %s" e)) - (fun e -> - let message = - spf "%s\n%s" (Printexc.to_string e) (Printexc.get_backtrace ()) - in - Log.err (fun k -> k "error: %s" message); - let r = - Jsonrpc.Response.error id - (Jsonrpc.Response.Error.make - ~code:Jsonrpc.Response.Error.Code.InternalError ~message ()) - in - send_response self r) - | _p -> IO.failwith "neither notification nor request" + | M.Notification n -> handle_notification self n + | M.Request r -> handle_request self r + | M.Response r -> handle_response self r + | M.Batch_response rs -> handle_batch_response self rs + | M.Batch_call cs -> handle_batch_call self cs in let rec loop () = if shutdown () then diff --git a/src/server.ml b/src/server.ml index 320b8a6b..a25d397a 100644 --- a/src/server.ml +++ b/src/server.ml @@ -29,17 +29,32 @@ module Make (IO : IO) = struct module DiagnosticSeverity = DiagnosticSeverity module Req_id = Req_id + (** A variant carrying a [Lsp.Server_request.t] and a handler for its return + value. The request is stored in order to allow us to discriminate its + existential variable. *) + type server_request_handler_pair = + | Request_and_handler : + 'from_server Lsp.Server_request.t + * (('from_server, Jsonrpc.Response.Error.t) result -> unit IO.t) + -> server_request_handler_pair + + type send_request = server_request_handler_pair -> Req_id.t IO.t + (** The type of the action that sends a request from the server to the client + and handles its response. *) + (** The server baseclass *) class virtual base_server = object method virtual on_notification : notify_back:(Lsp.Server_notification.t -> unit IO.t) -> + server_request:send_request -> Lsp.Client_notification.t -> unit IO.t method virtual on_request : 'a. notify_back:(Lsp.Server_notification.t -> unit IO.t) -> + server_request:send_request -> id:Req_id.t -> 'a Lsp.Client_request.t -> 'a IO.t @@ -53,8 +68,8 @@ module Make (IO : IO) = struct end (** A wrapper to more easily reply to notifications *) - class notify_back ~notify_back ~workDoneToken ~partialResultToken:_ ?version - ?(uri : DocumentUri.t option) () = + class notify_back ~notify_back ~server_request ~workDoneToken + ~partialResultToken:_ ?version ?(uri : DocumentUri.t option) () = object val mutable uri = uri method set_uri u = uri <- Some u @@ -109,7 +124,15 @@ module Make (IO : IO) = struct | None -> IO.return () method send_notification (n : Lsp.Server_notification.t) = notify_back n - (** Send a notification (general purpose method) *) + (** Send a notification from the server to the client (general purpose method) *) + + method send_request + : 'from_server. + 'from_server Lsp.Server_request.t -> + (('from_server, Jsonrpc.Response.Error.t) result -> unit IO.t) -> + Req_id.t IO.t = + fun r h -> server_request @@ Request_and_handler (r, h) + (** Send a request from the server to the client (general purpose method) *) end type nonrec doc_state = doc_state = { @@ -271,8 +294,12 @@ module Make (IO : IO) = struct @since 0.3 *) method on_request : type r. - notify_back:_ -> id:Req_id.t -> r Lsp.Client_request.t -> r IO.t = - fun ~notify_back ~id (r : _ Lsp.Client_request.t) -> + notify_back:_ -> + server_request:_ -> + id:Req_id.t -> + r Lsp.Client_request.t -> + r IO.t = + fun ~notify_back ~server_request ~id (r : _ Lsp.Client_request.t) -> Log.debug (fun k -> k "handle request[id=%s] " (Req_id.to_string id)); @@ -286,7 +313,7 @@ module Make (IO : IO) = struct let notify_back = new notify_back ~partialResultToken:None ~workDoneToken:i.workDoneToken - ~notify_back () + ~notify_back ~server_request () in self#on_req_initialize ~notify_back i | Lsp.Client_request.TextDocumentHover @@ -299,7 +326,8 @@ module Make (IO : IO) = struct | Some doc_st -> let notify_back = new notify_back - ~workDoneToken ~partialResultToken:None ~uri ~notify_back () + ~workDoneToken ~partialResultToken:None ~uri ~notify_back + ~server_request () in self#on_req_hover ~notify_back ~id ~uri ~pos:position ~workDoneToken doc_st) @@ -319,7 +347,8 @@ module Make (IO : IO) = struct | Some doc_st -> let notify_back = new notify_back - ~partialResultToken ~workDoneToken ~uri ~notify_back () + ~partialResultToken ~workDoneToken ~uri ~notify_back + ~server_request () in self#on_req_completion ~notify_back ~id ~uri ~workDoneToken ~partialResultToken ~pos:position ~ctx:context doc_st) @@ -330,7 +359,8 @@ module Make (IO : IO) = struct k "req: definition '%s'" (DocumentUri.to_path uri)); let notify_back = new notify_back - ~workDoneToken ~partialResultToken ~uri ~notify_back () + ~workDoneToken ~partialResultToken ~uri ~notify_back + ~server_request () in (match Hashtbl.find_opt docs uri with @@ -345,7 +375,8 @@ module Make (IO : IO) = struct k "req: codelens '%s'" (DocumentUri.to_path uri)); let notify_back = new notify_back - ~workDoneToken ~partialResultToken ~uri ~notify_back () + ~workDoneToken ~partialResultToken ~uri ~notify_back + ~server_request () in (match Hashtbl.find_opt docs uri with @@ -357,7 +388,8 @@ module Make (IO : IO) = struct Log.debug (fun k -> k "req: codelens resolve"); let notify_back = new notify_back - ~workDoneToken:None ~partialResultToken:None ~notify_back () + ~workDoneToken:None ~partialResultToken:None ~notify_back + ~server_request () in self#on_req_code_lens_resolve ~notify_back ~id cl | Lsp.Client_request.ExecuteCommand @@ -365,14 +397,17 @@ module Make (IO : IO) = struct Log.debug (fun k -> k "req: execute command '%s'" command); let notify_back = new notify_back - ~workDoneToken ~partialResultToken:None ~notify_back () + ~workDoneToken ~partialResultToken:None ~notify_back + ~server_request () in self#on_req_execute_command ~notify_back ~id ~workDoneToken command arguments | Lsp.Client_request.DocumentSymbol { textDocument = d; workDoneToken; partialResultToken } -> let notify_back = - new notify_back ~workDoneToken ~partialResultToken ~notify_back () + new notify_back + ~workDoneToken ~partialResultToken ~notify_back ~server_request + () in self#on_req_symbol ~notify_back ~id ~uri:d.uri ~workDoneToken ~partialResultToken () @@ -380,7 +415,8 @@ module Make (IO : IO) = struct let notify_back = new notify_back ~workDoneToken:a.workDoneToken - ~partialResultToken:a.partialResultToken ~notify_back () + ~partialResultToken:a.partialResultToken ~notify_back + ~server_request () in self#on_req_code_action ~notify_back ~id a | Lsp.Client_request.CodeActionResolve _ @@ -420,7 +456,8 @@ module Make (IO : IO) = struct | Lsp.Client_request.UnknownRequest _ -> let notify_back = new notify_back - ~workDoneToken:None ~partialResultToken:None ~notify_back () + ~workDoneToken:None ~partialResultToken:None ~notify_back + ~server_request () in self#on_request_unhandled ~notify_back ~id r @@ -448,8 +485,8 @@ module Make (IO : IO) = struct IO.return () (** Override to handle unprocessed notifications *) - method on_notification ~notify_back (n : Lsp.Client_notification.t) - : unit IO.t = + method on_notification ~notify_back ~server_request + (n : Lsp.Client_notification.t) : unit IO.t = let open Lsp.Types in match n with | Lsp.Client_notification.TextDocumentDidOpen @@ -459,7 +496,7 @@ module Make (IO : IO) = struct let notify_back = new notify_back ~uri:doc.uri ~workDoneToken:None ~partialResultToken:None - ~version:doc.version ~notify_back () + ~version:doc.version ~notify_back ~server_request () in let st = { @@ -479,7 +516,7 @@ module Make (IO : IO) = struct let notify_back = new notify_back ~workDoneToken:None ~partialResultToken:None ~uri:doc.uri - ~notify_back () + ~notify_back ~server_request () in self#on_notif_doc_did_close ~notify_back:(notify_back : notify_back) @@ -491,7 +528,7 @@ module Make (IO : IO) = struct let notify_back = new notify_back ~workDoneToken:None ~partialResultToken:None ~uri:doc.uri - ~notify_back () + ~notify_back ~server_request () in let old_doc = @@ -557,7 +594,8 @@ module Make (IO : IO) = struct | Lsp.Client_notification.LogTrace _ -> let notify_back = new notify_back - ~workDoneToken:None ~partialResultToken:None ~notify_back () + ~workDoneToken:None ~partialResultToken:None ~notify_back + ~server_request () in self#on_notification_unhandled ~notify_back:(notify_back : notify_back)