1 | use std::{ |
2 | fmt, |
3 | io::{self, BufRead, Write}, |
4 | }; |
5 | |
6 | use serde::{de::DeserializeOwned, Deserialize, Serialize}; |
7 | |
8 | use crate::error::ExtractError; |
9 | |
10 | #[derive (Serialize, Deserialize, Debug, Clone)] |
11 | #[serde(untagged)] |
12 | pub enum Message { |
13 | Request(Request), |
14 | Response(Response), |
15 | Notification(Notification), |
16 | } |
17 | |
18 | impl From<Request> for Message { |
19 | fn from(request: Request) -> Message { |
20 | Message::Request(request) |
21 | } |
22 | } |
23 | |
24 | impl From<Response> for Message { |
25 | fn from(response: Response) -> Message { |
26 | Message::Response(response) |
27 | } |
28 | } |
29 | |
30 | impl From<Notification> for Message { |
31 | fn from(notification: Notification) -> Message { |
32 | Message::Notification(notification) |
33 | } |
34 | } |
35 | |
36 | #[derive (Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] |
37 | #[serde(transparent)] |
38 | pub struct RequestId(IdRepr); |
39 | |
40 | #[derive (Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] |
41 | #[serde(untagged)] |
42 | enum IdRepr { |
43 | I32(i32), |
44 | String(String), |
45 | } |
46 | |
47 | impl From<i32> for RequestId { |
48 | fn from(id: i32) -> RequestId { |
49 | RequestId(IdRepr::I32(id)) |
50 | } |
51 | } |
52 | |
53 | impl From<String> for RequestId { |
54 | fn from(id: String) -> RequestId { |
55 | RequestId(IdRepr::String(id)) |
56 | } |
57 | } |
58 | |
59 | impl fmt::Display for RequestId { |
60 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
61 | match &self.0 { |
62 | IdRepr::I32(it: &i32) => fmt::Display::fmt(self:it, f), |
63 | // Use debug here, to make it clear that `92` and `"92"` are |
64 | // different, and to reduce WTF factor if the sever uses `" "` as an |
65 | // ID. |
66 | IdRepr::String(it: &String) => fmt::Debug::fmt(self:it, f), |
67 | } |
68 | } |
69 | } |
70 | |
71 | #[derive (Debug, Serialize, Deserialize, Clone)] |
72 | pub struct Request { |
73 | pub id: RequestId, |
74 | pub method: String, |
75 | #[serde(default = "serde_json::Value::default" )] |
76 | #[serde(skip_serializing_if = "serde_json::Value::is_null" )] |
77 | pub params: serde_json::Value, |
78 | } |
79 | |
80 | #[derive (Debug, Serialize, Deserialize, Clone)] |
81 | pub struct Response { |
82 | // JSON RPC allows this to be null if it was impossible |
83 | // to decode the request's id. Ignore this special case |
84 | // and just die horribly. |
85 | pub id: RequestId, |
86 | #[serde(skip_serializing_if = "Option::is_none" )] |
87 | pub result: Option<serde_json::Value>, |
88 | #[serde(skip_serializing_if = "Option::is_none" )] |
89 | pub error: Option<ResponseError>, |
90 | } |
91 | |
92 | #[derive (Debug, Serialize, Deserialize, Clone)] |
93 | pub struct ResponseError { |
94 | pub code: i32, |
95 | pub message: String, |
96 | #[serde(skip_serializing_if = "Option::is_none" )] |
97 | pub data: Option<serde_json::Value>, |
98 | } |
99 | |
/// Well-known error codes. Each discriminant is the wire value, so
/// `ErrorCode::MethodNotFound as i32` yields `-32601` for [`ResponseError::code`].
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
pub enum ErrorCode {
    // ---- Defined by JSON-RPC ----
    /// The received text was not valid JSON.
    ParseError = -32700,
    /// The JSON was not a valid request object.
    InvalidRequest = -32600,
    /// The requested method does not exist or is unavailable.
    MethodNotFound = -32601,
    /// The method parameters were invalid.
    InvalidParams = -32602,
    /// An internal error occurred while handling the message.
    InternalError = -32603,
    /// Lower bound of the JSON-RPC range reserved for server errors.
    ServerErrorStart = -32099,
    /// Upper bound of the JSON-RPC range reserved for server errors.
    ServerErrorEnd = -32000,

    /// A notification or request arrived before the server received the
    /// `initialize` request.
    ServerNotInitialized = -32002,
    /// Error whose kind is not otherwise known.
    UnknownErrorCode = -32001,

    // ---- Defined by the protocol ----
    /// The client cancelled a request and the server noticed the cancellation.
    RequestCanceled = -32800,

    /// The server detected that a document's content was modified outside normal
    /// conditions. A server should NOT send this code merely because it noticed a
    /// content change among its unprocessed messages: a result computed on an older
    /// state may still be useful to the client.
    ///
    /// If a client decides a result is no longer useful, the client should cancel
    /// the request.
    ContentModified = -32801,

    /// The server cancelled the request. Only use this code for requests that
    /// explicitly support being cancelled by the server.
    ///
    /// @since 3.17.0
    ServerCancelled = -32802,

    /// The request was syntactically correct (e.g. the method name was known and
    /// the parameters were valid) but still failed. The error message should explain
    /// in human-readable terms why it failed.
    ///
    /// @since 3.17.0
    RequestFailed = -32803,
}
147 | |
148 | #[derive (Debug, Serialize, Deserialize, Clone)] |
149 | pub struct Notification { |
150 | pub method: String, |
151 | #[serde(default = "serde_json::Value::default" )] |
152 | #[serde(skip_serializing_if = "serde_json::Value::is_null" )] |
153 | pub params: serde_json::Value, |
154 | } |
155 | |
156 | impl Message { |
157 | pub fn read(r: &mut impl BufRead) -> io::Result<Option<Message>> { |
158 | Message::_read(r) |
159 | } |
160 | fn _read(r: &mut dyn BufRead) -> io::Result<Option<Message>> { |
161 | let text = match read_msg_text(r)? { |
162 | None => return Ok(None), |
163 | Some(text) => text, |
164 | }; |
165 | let msg = serde_json::from_str(&text)?; |
166 | Ok(Some(msg)) |
167 | } |
168 | pub fn write(self, w: &mut impl Write) -> io::Result<()> { |
169 | self._write(w) |
170 | } |
171 | fn _write(self, w: &mut dyn Write) -> io::Result<()> { |
172 | #[derive (Serialize)] |
173 | struct JsonRpc { |
174 | jsonrpc: &'static str, |
175 | #[serde(flatten)] |
176 | msg: Message, |
177 | } |
178 | let text = serde_json::to_string(&JsonRpc { jsonrpc: "2.0" , msg: self })?; |
179 | write_msg_text(w, &text) |
180 | } |
181 | } |
182 | |
183 | impl Response { |
184 | pub fn new_ok<R: Serialize>(id: RequestId, result: R) -> Response { |
185 | Response { id, result: Some(serde_json::to_value(result).unwrap()), error: None } |
186 | } |
187 | pub fn new_err(id: RequestId, code: i32, message: String) -> Response { |
188 | let error: ResponseError = ResponseError { code, message, data: None }; |
189 | Response { id, result: None, error: Some(error) } |
190 | } |
191 | } |
192 | |
193 | impl Request { |
194 | pub fn new<P: Serialize>(id: RequestId, method: String, params: P) -> Request { |
195 | Request { id, method, params: serde_json::to_value(params).unwrap() } |
196 | } |
197 | pub fn extract<P: DeserializeOwned>( |
198 | self, |
199 | method: &str, |
200 | ) -> Result<(RequestId, P), ExtractError<Request>> { |
201 | if self.method != method { |
202 | return Err(ExtractError::MethodMismatch(self)); |
203 | } |
204 | match serde_json::from_value(self.params) { |
205 | Ok(params: P) => Ok((self.id, params)), |
206 | Err(error: Error) => Err(ExtractError::JsonError { method: self.method, error }), |
207 | } |
208 | } |
209 | |
210 | pub(crate) fn is_shutdown(&self) -> bool { |
211 | self.method == "shutdown" |
212 | } |
213 | pub(crate) fn is_initialize(&self) -> bool { |
214 | self.method == "initialize" |
215 | } |
216 | } |
217 | |
218 | impl Notification { |
219 | pub fn new(method: String, params: impl Serialize) -> Notification { |
220 | Notification { method, params: serde_json::to_value(params).unwrap() } |
221 | } |
222 | pub fn extract<P: DeserializeOwned>( |
223 | self, |
224 | method: &str, |
225 | ) -> Result<P, ExtractError<Notification>> { |
226 | if self.method != method { |
227 | return Err(ExtractError::MethodMismatch(self)); |
228 | } |
229 | match serde_json::from_value(self.params) { |
230 | Ok(params: P) => Ok(params), |
231 | Err(error: Error) => Err(ExtractError::JsonError { method: self.method, error }), |
232 | } |
233 | } |
234 | pub(crate) fn is_exit(&self) -> bool { |
235 | self.method == "exit" |
236 | } |
237 | pub(crate) fn is_initialized(&self) -> bool { |
238 | self.method == "initialized" |
239 | } |
240 | } |
241 | |
242 | fn read_msg_text(inp: &mut dyn BufRead) -> io::Result<Option<String>> { |
243 | fn invalid_data(error: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error { |
244 | io::Error::new(io::ErrorKind::InvalidData, error) |
245 | } |
246 | macro_rules! invalid_data { |
247 | ($($tt:tt)*) => (invalid_data(format!($($tt)*))) |
248 | } |
249 | |
250 | let mut size = None; |
251 | let mut buf = String::new(); |
252 | loop { |
253 | buf.clear(); |
254 | if inp.read_line(&mut buf)? == 0 { |
255 | return Ok(None); |
256 | } |
257 | if !buf.ends_with(" \r\n" ) { |
258 | return Err(invalid_data!("malformed header: {:?}" , buf)); |
259 | } |
260 | let buf = &buf[..buf.len() - 2]; |
261 | if buf.is_empty() { |
262 | break; |
263 | } |
264 | let mut parts = buf.splitn(2, ": " ); |
265 | let header_name = parts.next().unwrap(); |
266 | let header_value = |
267 | parts.next().ok_or_else(|| invalid_data(format!("malformed header: {:?}" , buf)))?; |
268 | if header_name.eq_ignore_ascii_case("Content-Length" ) { |
269 | size = Some(header_value.parse::<usize>().map_err(invalid_data)?); |
270 | } |
271 | } |
272 | let size: usize = size.ok_or_else(|| invalid_data("no Content-Length" .to_string()))?; |
273 | let mut buf = buf.into_bytes(); |
274 | buf.resize(size, 0); |
275 | inp.read_exact(&mut buf)?; |
276 | let buf = String::from_utf8(buf).map_err(invalid_data)?; |
277 | log::debug!("< {}" , buf); |
278 | Ok(Some(buf)) |
279 | } |
280 | |
281 | fn write_msg_text(out: &mut dyn Write, msg: &str) -> io::Result<()> { |
282 | log::debug!("> {}" , msg); |
283 | write!(out, "Content-Length: {}\r\n\r\n" , msg.len())?; |
284 | out.write_all(buf:msg.as_bytes())?; |
285 | out.flush()?; |
286 | Ok(()) |
287 | } |
288 | |
#[cfg(test)]
mod tests {
    use super::{Message, Notification, Request, RequestId};

    // Fixtures use raw strings so the JSON keys are exactly `jsonrpc`, `id`, `method`
    // and `params` (no stray whitespace inside the quoted names).

    #[test]
    fn shutdown_with_explicit_null() {
        let text = r#"{"jsonrpc": "2.0", "id": 3, "method": "shutdown", "params": null }"#;
        let msg: Message = serde_json::from_str(text).unwrap();

        assert!(
            matches!(msg, Message::Request(req) if req.id == 3.into() && req.method == "shutdown")
        );
    }

    #[test]
    fn shutdown_with_no_params() {
        let text = r#"{"jsonrpc": "2.0", "id": 3, "method": "shutdown"}"#;
        let msg: Message = serde_json::from_str(text).unwrap();

        assert!(
            matches!(msg, Message::Request(req) if req.id == 3.into() && req.method == "shutdown")
        );
    }

    #[test]
    fn notification_with_explicit_null() {
        let text = r#"{"jsonrpc": "2.0", "method": "exit", "params": null }"#;
        let msg: Message = serde_json::from_str(text).unwrap();

        assert!(matches!(msg, Message::Notification(not) if not.method == "exit"));
    }

    #[test]
    fn notification_with_no_params() {
        let text = r#"{"jsonrpc": "2.0", "method": "exit"}"#;
        let msg: Message = serde_json::from_str(text).unwrap();

        assert!(matches!(msg, Message::Notification(not) if not.method == "exit"));
    }

    #[test]
    fn serialize_request_with_null_params() {
        let msg = Message::Request(Request {
            id: RequestId::from(3),
            method: "shutdown".into(),
            params: serde_json::Value::Null,
        });
        let serialized = serde_json::to_string(&msg).unwrap();

        assert_eq!(r#"{"id":3,"method":"shutdown"}"#, serialized);
    }

    #[test]
    fn serialize_notification_with_null_params() {
        let msg = Message::Notification(Notification {
            method: "exit".into(),
            params: serde_json::Value::Null,
        });
        let serialized = serde_json::to_string(&msg).unwrap();

        assert_eq!(r#"{"method":"exit"}"#, serialized);
    }

    #[test]
    fn write_frames_message_with_content_length() {
        let msg = Message::Notification(Notification {
            method: "exit".into(),
            params: serde_json::Value::Null,
        });
        let mut out = Vec::new();
        msg.write(&mut out).unwrap();

        let body = r#"{"jsonrpc":"2.0","method":"exit"}"#;
        let expected = format!("Content-Length: {}\r\n\r\n{}", body.len(), body);
        assert_eq!(String::from_utf8(out).unwrap(), expected);
    }
}
352 | |