1 | use std::{ |
2 | fmt, |
3 | io::{self, BufRead, Write}, |
4 | }; |
5 | |
6 | use serde::de::DeserializeOwned; |
7 | use serde_derive::{Deserialize, Serialize}; |
8 | |
9 | use crate::error::ExtractError; |
10 | |
11 | #[derive (Serialize, Deserialize, Debug, Clone)] |
12 | #[serde(untagged)] |
13 | pub enum Message { |
14 | Request(Request), |
15 | Response(Response), |
16 | Notification(Notification), |
17 | } |
18 | |
19 | impl From<Request> for Message { |
20 | fn from(request: Request) -> Message { |
21 | Message::Request(request) |
22 | } |
23 | } |
24 | |
25 | impl From<Response> for Message { |
26 | fn from(response: Response) -> Message { |
27 | Message::Response(response) |
28 | } |
29 | } |
30 | |
31 | impl From<Notification> for Message { |
32 | fn from(notification: Notification) -> Message { |
33 | Message::Notification(notification) |
34 | } |
35 | } |
36 | |
37 | #[derive (Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] |
38 | #[serde(transparent)] |
39 | pub struct RequestId(IdRepr); |
40 | |
41 | #[derive (Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] |
42 | #[serde(untagged)] |
43 | enum IdRepr { |
44 | I32(i32), |
45 | String(String), |
46 | } |
47 | |
48 | impl From<i32> for RequestId { |
49 | fn from(id: i32) -> RequestId { |
50 | RequestId(IdRepr::I32(id)) |
51 | } |
52 | } |
53 | |
54 | impl From<String> for RequestId { |
55 | fn from(id: String) -> RequestId { |
56 | RequestId(IdRepr::String(id)) |
57 | } |
58 | } |
59 | |
60 | impl fmt::Display for RequestId { |
61 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
62 | match &self.0 { |
63 | IdRepr::I32(it: &i32) => fmt::Display::fmt(self:it, f), |
64 | // Use debug here, to make it clear that `92` and `"92"` are |
65 | // different, and to reduce WTF factor if the sever uses `" "` as an |
66 | // ID. |
67 | IdRepr::String(it: &String) => fmt::Debug::fmt(self:it, f), |
68 | } |
69 | } |
70 | } |
71 | |
72 | #[derive (Debug, Serialize, Deserialize, Clone)] |
73 | pub struct Request { |
74 | pub id: RequestId, |
75 | pub method: String, |
76 | #[serde(default = "serde_json::Value::default" )] |
77 | #[serde(skip_serializing_if = "serde_json::Value::is_null" )] |
78 | pub params: serde_json::Value, |
79 | } |
80 | |
81 | #[derive (Debug, Serialize, Deserialize, Clone)] |
82 | pub struct Response { |
83 | // JSON RPC allows this to be null if it was impossible |
84 | // to decode the request's id. Ignore this special case |
85 | // and just die horribly. |
86 | pub id: RequestId, |
87 | #[serde(skip_serializing_if = "Option::is_none" )] |
88 | pub result: Option<serde_json::Value>, |
89 | #[serde(skip_serializing_if = "Option::is_none" )] |
90 | pub error: Option<ResponseError>, |
91 | } |
92 | |
93 | #[derive (Debug, Serialize, Deserialize, Clone)] |
94 | pub struct ResponseError { |
95 | pub code: i32, |
96 | pub message: String, |
97 | #[serde(skip_serializing_if = "Option::is_none" )] |
98 | pub data: Option<serde_json::Value>, |
99 | } |
100 | |
/// Predefined JSON-RPC / LSP error codes.
///
/// Every variant has an explicit discriminant equal to its wire value, so
/// `ErrorCode::MethodNotFound as i32` yields the number to put in
/// [`ResponseError::code`].
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
pub enum ErrorCode {
    // Defined by JSON RPC:
    ParseError = -32700,
    InvalidRequest = -32600,
    MethodNotFound = -32601,
    InvalidParams = -32602,
    InternalError = -32603,
    /// Lower bound of the `-32099..=-32000` server-error range. Per the LSP
    /// specification this marks the reserved range rather than a specific
    /// error.
    ServerErrorStart = -32099,
    /// Upper bound of the `-32099..=-32000` server-error range (see
    /// `ServerErrorStart`).
    ServerErrorEnd = -32000,

    /// Error code indicating that a server received a notification or
    /// request before the server has received the `initialize` request.
    ServerNotInitialized = -32002,
    UnknownErrorCode = -32001,

    // Defined by the protocol:
    /// The client has canceled a request and a server has detected
    /// the cancel.
    RequestCanceled = -32800,

    /// The server detected that the content of a document got
    /// modified outside normal conditions. A server should
    /// NOT send this error code if it detects a content change
    /// in its unprocessed messages. The result, even if computed
    /// on an older state, might still be useful for the client.
    ///
    /// If a client decides that a result is not of any use anymore
    /// the client should cancel the request.
    ContentModified = -32801,

    /// The server cancelled the request. This error code should
    /// only be used for requests that explicitly support being
    /// server cancellable.
    ///
    /// @since 3.17.0
    ServerCancelled = -32802,

    /// A request failed but it was syntactically correct, e.g. the
    /// method name was known and the parameters were valid. The error
    /// message should contain human readable information about why
    /// the request failed.
    ///
    /// @since 3.17.0
    RequestFailed = -32803,
}
148 | |
149 | #[derive (Debug, Serialize, Deserialize, Clone)] |
150 | pub struct Notification { |
151 | pub method: String, |
152 | #[serde(default = "serde_json::Value::default" )] |
153 | #[serde(skip_serializing_if = "serde_json::Value::is_null" )] |
154 | pub params: serde_json::Value, |
155 | } |
156 | |
/// Wraps `error` in an [`io::Error`] of kind [`io::ErrorKind::InvalidData`].
///
/// Used in this module for framing and payload errors, so the read/write
/// helpers can report them through `io::Result`.
fn invalid_data(error: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error {
    io::Error::new(io::ErrorKind::InvalidData, error)
}
160 | |
// `format!`-style shorthand for `invalid_data`: `invalid_data!("bad {}", x)`
// builds an `InvalidData` `io::Error` from the formatted message.
macro_rules! invalid_data {
    ($($fmt_args:tt)*) => {
        invalid_data(format!($($fmt_args)*))
    };
}
164 | |
165 | impl Message { |
166 | pub fn read(r: &mut impl BufRead) -> io::Result<Option<Message>> { |
167 | Message::_read(r) |
168 | } |
169 | fn _read(r: &mut dyn BufRead) -> io::Result<Option<Message>> { |
170 | let text = match read_msg_text(r)? { |
171 | None => return Ok(None), |
172 | Some(text) => text, |
173 | }; |
174 | |
175 | let msg = match serde_json::from_str(&text) { |
176 | Ok(msg) => msg, |
177 | Err(e) => { |
178 | return Err(invalid_data!("malformed LSP payload: {:?}" , e)); |
179 | } |
180 | }; |
181 | |
182 | Ok(Some(msg)) |
183 | } |
184 | pub fn write(self, w: &mut impl Write) -> io::Result<()> { |
185 | self._write(w) |
186 | } |
187 | fn _write(self, w: &mut dyn Write) -> io::Result<()> { |
188 | #[derive (Serialize)] |
189 | struct JsonRpc { |
190 | jsonrpc: &'static str, |
191 | #[serde(flatten)] |
192 | msg: Message, |
193 | } |
194 | let text = serde_json::to_string(&JsonRpc { jsonrpc: "2.0" , msg: self })?; |
195 | write_msg_text(w, &text) |
196 | } |
197 | } |
198 | |
199 | impl Response { |
200 | pub fn new_ok<R: serde::Serialize>(id: RequestId, result: R) -> Response { |
201 | Response { id, result: Some(serde_json::to_value(result).unwrap()), error: None } |
202 | } |
203 | pub fn new_err(id: RequestId, code: i32, message: String) -> Response { |
204 | let error: ResponseError = ResponseError { code, message, data: None }; |
205 | Response { id, result: None, error: Some(error) } |
206 | } |
207 | } |
208 | |
209 | impl Request { |
210 | pub fn new<P: serde::Serialize>(id: RequestId, method: String, params: P) -> Request { |
211 | Request { id, method, params: serde_json::to_value(params).unwrap() } |
212 | } |
213 | pub fn extract<P: DeserializeOwned>( |
214 | self, |
215 | method: &str, |
216 | ) -> Result<(RequestId, P), ExtractError<Request>> { |
217 | if self.method != method { |
218 | return Err(ExtractError::MethodMismatch(self)); |
219 | } |
220 | match serde_json::from_value(self.params) { |
221 | Ok(params: P) => Ok((self.id, params)), |
222 | Err(error: Error) => Err(ExtractError::JsonError { method: self.method, error }), |
223 | } |
224 | } |
225 | |
226 | pub(crate) fn is_shutdown(&self) -> bool { |
227 | self.method == "shutdown" |
228 | } |
229 | pub(crate) fn is_initialize(&self) -> bool { |
230 | self.method == "initialize" |
231 | } |
232 | } |
233 | |
234 | impl Notification { |
235 | pub fn new(method: String, params: impl serde::Serialize) -> Notification { |
236 | Notification { method, params: serde_json::to_value(params).unwrap() } |
237 | } |
238 | pub fn extract<P: DeserializeOwned>( |
239 | self, |
240 | method: &str, |
241 | ) -> Result<P, ExtractError<Notification>> { |
242 | if self.method != method { |
243 | return Err(ExtractError::MethodMismatch(self)); |
244 | } |
245 | match serde_json::from_value(self.params) { |
246 | Ok(params: P) => Ok(params), |
247 | Err(error: Error) => Err(ExtractError::JsonError { method: self.method, error }), |
248 | } |
249 | } |
250 | pub(crate) fn is_exit(&self) -> bool { |
251 | self.method == "exit" |
252 | } |
253 | pub(crate) fn is_initialized(&self) -> bool { |
254 | self.method == "initialized" |
255 | } |
256 | } |
257 | |
258 | fn read_msg_text(inp: &mut dyn BufRead) -> io::Result<Option<String>> { |
259 | let mut size = None; |
260 | let mut buf = String::new(); |
261 | loop { |
262 | buf.clear(); |
263 | if inp.read_line(&mut buf)? == 0 { |
264 | return Ok(None); |
265 | } |
266 | if !buf.ends_with(" \r\n" ) { |
267 | return Err(invalid_data!("malformed header: {:?}" , buf)); |
268 | } |
269 | let buf = &buf[..buf.len() - 2]; |
270 | if buf.is_empty() { |
271 | break; |
272 | } |
273 | let mut parts = buf.splitn(2, ": " ); |
274 | let header_name = parts.next().unwrap(); |
275 | let header_value = |
276 | parts.next().ok_or_else(|| invalid_data!("malformed header: {:?}" , buf))?; |
277 | if header_name.eq_ignore_ascii_case("Content-Length" ) { |
278 | size = Some(header_value.parse::<usize>().map_err(invalid_data)?); |
279 | } |
280 | } |
281 | let size: usize = size.ok_or_else(|| invalid_data!("no Content-Length" ))?; |
282 | let mut buf = buf.into_bytes(); |
283 | buf.resize(size, 0); |
284 | inp.read_exact(&mut buf)?; |
285 | let buf = String::from_utf8(buf).map_err(invalid_data)?; |
286 | log::debug!("< {}" , buf); |
287 | Ok(Some(buf)) |
288 | } |
289 | |
290 | fn write_msg_text(out: &mut dyn Write, msg: &str) -> io::Result<()> { |
291 | log::debug!("> {}" , msg); |
292 | write!(out, "Content-Length: {}\r\n\r\n" , msg.len())?; |
293 | out.write_all(buf:msg.as_bytes())?; |
294 | out.flush()?; |
295 | Ok(()) |
296 | } |
297 | |
#[cfg(test)]
mod tests {
    use super::{Message, Notification, Request, RequestId};

    // An explicit `"params": null` must still classify as a request.
    #[test]
    fn shutdown_with_explicit_null() {
        let text = r#"{"jsonrpc": "2.0", "id": 3, "method": "shutdown", "params": null}"#;
        let msg: Message = serde_json::from_str(text).unwrap();

        assert!(
            matches!(msg, Message::Request(req) if req.id == 3.into() && req.method == "shutdown")
        );
    }

    // A missing `params` member falls back to the serde default.
    #[test]
    fn shutdown_with_no_params() {
        let text = r#"{"jsonrpc": "2.0", "id": 3, "method": "shutdown"}"#;
        let msg: Message = serde_json::from_str(text).unwrap();

        assert!(
            matches!(msg, Message::Request(req) if req.id == 3.into() && req.method == "shutdown")
        );
    }

    #[test]
    fn notification_with_explicit_null() {
        let text = r#"{"jsonrpc": "2.0", "method": "exit", "params": null}"#;
        let msg: Message = serde_json::from_str(text).unwrap();

        assert!(matches!(msg, Message::Notification(not) if not.method == "exit"));
    }

    #[test]
    fn notification_with_no_params() {
        let text = r#"{"jsonrpc": "2.0", "method": "exit"}"#;
        let msg: Message = serde_json::from_str(text).unwrap();

        assert!(matches!(msg, Message::Notification(not) if not.method == "exit"));
    }

    // Null params are skipped on output rather than written as `null`.
    #[test]
    fn serialize_request_with_null_params() {
        let msg = Message::Request(Request {
            id: RequestId::from(3),
            method: "shutdown".into(),
            params: serde_json::Value::Null,
        });
        let serialized = serde_json::to_string(&msg).unwrap();

        assert_eq!(r#"{"id":3,"method":"shutdown"}"#, serialized);
    }

    #[test]
    fn serialize_notification_with_null_params() {
        let msg = Message::Notification(Notification {
            method: "exit".into(),
            params: serde_json::Value::Null,
        });
        let serialized = serde_json::to_string(&msg).unwrap();

        assert_eq!(r#"{"method":"exit"}"#, serialized);
    }
}
361 | |