1 | //===--- JSONTransport.cpp - sending and receiving LSP messages over JSON -===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Tools/lsp-server-support/Transport.h" |
10 | #include "mlir/Support/ToolUtilities.h" |
11 | #include "mlir/Tools/lsp-server-support/Logging.h" |
12 | #include "mlir/Tools/lsp-server-support/Protocol.h" |
13 | #include "llvm/ADT/SmallString.h" |
14 | #include "llvm/Support/Errno.h" |
15 | #include "llvm/Support/Error.h" |
16 | #include <optional> |
17 | #include <system_error> |
18 | #include <utility> |
19 | |
20 | using namespace mlir; |
21 | using namespace mlir::lsp; |
22 | |
23 | //===----------------------------------------------------------------------===// |
24 | // Reply |
25 | //===----------------------------------------------------------------------===// |
26 | |
27 | namespace { |
28 | /// Function object to reply to an LSP call. |
29 | /// Each instance must be called exactly once, otherwise: |
30 | /// - if there was no reply, an error reply is sent |
31 | /// - if there were multiple replies, only the first is sent |
32 | class Reply { |
33 | public: |
34 | Reply(const llvm::json::Value &id, StringRef method, JSONTransport &transport, |
35 | std::mutex &transportOutputMutex); |
36 | Reply(Reply &&other); |
37 | Reply &operator=(Reply &&) = delete; |
38 | Reply(const Reply &) = delete; |
39 | Reply &operator=(const Reply &) = delete; |
40 | |
41 | void operator()(llvm::Expected<llvm::json::Value> reply); |
42 | |
43 | private: |
44 | StringRef method; |
45 | std::atomic<bool> replied = {false}; |
46 | llvm::json::Value id; |
47 | JSONTransport *transport; |
48 | std::mutex &transportOutputMutex; |
49 | }; |
50 | } // namespace |
51 | |
52 | Reply::Reply(const llvm::json::Value &id, llvm::StringRef method, |
53 | JSONTransport &transport, std::mutex &transportOutputMutex) |
54 | : id(id), transport(&transport), |
55 | transportOutputMutex(transportOutputMutex) {} |
56 | |
57 | Reply::Reply(Reply &&other) |
58 | : replied(other.replied.load()), id(std::move(other.id)), |
59 | transport(other.transport), |
60 | transportOutputMutex(other.transportOutputMutex) { |
61 | other.transport = nullptr; |
62 | } |
63 | |
64 | void Reply::operator()(llvm::Expected<llvm::json::Value> reply) { |
65 | if (replied.exchange(i: true)) { |
66 | Logger::error(fmt: "Replied twice to message {0}({1})" , vals&: method, vals&: id); |
67 | assert(false && "must reply to each call only once!" ); |
68 | return; |
69 | } |
70 | assert(transport && "expected valid transport to reply to" ); |
71 | |
72 | std::lock_guard<std::mutex> transportLock(transportOutputMutex); |
73 | if (reply) { |
74 | Logger::info(fmt: "--> reply:{0}({1})" , vals&: method, vals&: id); |
75 | transport->reply(id: std::move(id), result: std::move(reply)); |
76 | } else { |
77 | llvm::Error error = reply.takeError(); |
78 | Logger::info(fmt: "--> reply:{0}({1})" , vals&: method, vals&: id, vals&: error); |
79 | transport->reply(id: std::move(id), result: std::move(error)); |
80 | } |
81 | } |
82 | |
83 | //===----------------------------------------------------------------------===// |
84 | // MessageHandler |
85 | //===----------------------------------------------------------------------===// |
86 | |
87 | bool MessageHandler::onNotify(llvm::StringRef method, llvm::json::Value value) { |
88 | Logger::info(fmt: "--> {0}" , vals&: method); |
89 | |
90 | if (method == "exit" ) |
91 | return false; |
92 | if (method == "$cancel" ) { |
93 | // TODO: Add support for cancelling requests. |
94 | } else { |
95 | auto it = notificationHandlers.find(Key: method); |
96 | if (it != notificationHandlers.end()) |
97 | it->second(std::move(value)); |
98 | } |
99 | return true; |
100 | } |
101 | |
102 | bool MessageHandler::onCall(llvm::StringRef method, llvm::json::Value params, |
103 | llvm::json::Value id) { |
104 | Logger::info(fmt: "--> {0}({1})" , vals&: method, vals&: id); |
105 | |
106 | Reply reply(id, method, transport, transportOutputMutex); |
107 | |
108 | auto it = methodHandlers.find(Key: method); |
109 | if (it != methodHandlers.end()) { |
110 | it->second(std::move(params), std::move(reply)); |
111 | } else { |
112 | reply(llvm::make_error<LSPError>(Args: "method not found: " + method.str(), |
113 | Args: ErrorCode::MethodNotFound)); |
114 | } |
115 | return true; |
116 | } |
117 | |
118 | bool MessageHandler::onReply(llvm::json::Value id, |
119 | llvm::Expected<llvm::json::Value> result) { |
120 | // TODO: Add support for reply callbacks when support for outgoing messages is |
121 | // added. For now, we just log an error on any replies received. |
122 | Callback<llvm::json::Value> replyHandler = |
123 | [&id](llvm::Expected<llvm::json::Value> result) { |
124 | Logger::error( |
125 | fmt: "received a reply with ID {0}, but there was no such call" , vals&: id); |
126 | if (!result) |
127 | llvm::consumeError(Err: result.takeError()); |
128 | }; |
129 | |
130 | // Log and run the reply handler. |
131 | if (result) |
132 | replyHandler(std::move(result)); |
133 | else |
134 | replyHandler(result.takeError()); |
135 | return true; |
136 | } |
137 | |
138 | //===----------------------------------------------------------------------===// |
139 | // JSONTransport |
140 | //===----------------------------------------------------------------------===// |
141 | |
142 | /// Encode the given error as a JSON object. |
143 | static llvm::json::Object encodeError(llvm::Error error) { |
144 | std::string message; |
145 | ErrorCode code = ErrorCode::UnknownErrorCode; |
146 | auto handlerFn = [&](const LSPError &lspError) -> llvm::Error { |
147 | message = lspError.message; |
148 | code = lspError.code; |
149 | return llvm::Error::success(); |
150 | }; |
151 | if (llvm::Error unhandled = llvm::handleErrors(E: std::move(error), Hs&: handlerFn)) |
152 | message = llvm::toString(E: std::move(unhandled)); |
153 | |
154 | return llvm::json::Object{ |
155 | {.K: "message" , .V: std::move(message)}, |
156 | {.K: "code" , .V: int64_t(code)}, |
157 | }; |
158 | } |
159 | |
160 | /// Decode the given JSON object into an error. |
161 | llvm::Error decodeError(const llvm::json::Object &o) { |
162 | StringRef msg = o.getString(K: "message" ).value_or(u: "Unspecified error" ); |
163 | if (std::optional<int64_t> code = o.getInteger(K: "code" )) |
164 | return llvm::make_error<LSPError>(Args: msg.str(), Args: ErrorCode(*code)); |
165 | return llvm::make_error<llvm::StringError>(Args: llvm::inconvertibleErrorCode(), |
166 | Args: msg.str()); |
167 | } |
168 | |
169 | void JSONTransport::notify(StringRef method, llvm::json::Value params) { |
170 | sendMessage(msg: llvm::json::Object{ |
171 | {.K: "jsonrpc" , .V: "2.0" }, |
172 | {.K: "method" , .V: method}, |
173 | {.K: "params" , .V: std::move(params)}, |
174 | }); |
175 | } |
176 | void JSONTransport::call(StringRef method, llvm::json::Value params, |
177 | llvm::json::Value id) { |
178 | sendMessage(msg: llvm::json::Object{ |
179 | {.K: "jsonrpc" , .V: "2.0" }, |
180 | {.K: "id" , .V: std::move(id)}, |
181 | {.K: "method" , .V: method}, |
182 | {.K: "params" , .V: std::move(params)}, |
183 | }); |
184 | } |
185 | void JSONTransport::reply(llvm::json::Value id, |
186 | llvm::Expected<llvm::json::Value> result) { |
187 | if (result) { |
188 | return sendMessage(msg: llvm::json::Object{ |
189 | {.K: "jsonrpc" , .V: "2.0" }, |
190 | {.K: "id" , .V: std::move(id)}, |
191 | {.K: "result" , .V: std::move(*result)}, |
192 | }); |
193 | } |
194 | |
195 | sendMessage(msg: llvm::json::Object{ |
196 | {.K: "jsonrpc" , .V: "2.0" }, |
197 | {.K: "id" , .V: std::move(id)}, |
198 | {.K: "error" , .V: encodeError(error: result.takeError())}, |
199 | }); |
200 | } |
201 | |
202 | llvm::Error JSONTransport::run(MessageHandler &handler) { |
203 | std::string json; |
204 | while (!feof(stream: in)) { |
205 | if (ferror(stream: in)) { |
206 | return llvm::errorCodeToError( |
207 | EC: std::error_code(errno, std::system_category())); |
208 | } |
209 | |
210 | if (succeeded(result: readMessage(json))) { |
211 | if (llvm::Expected<llvm::json::Value> doc = llvm::json::parse(JSON: json)) { |
212 | if (!handleMessage(msg: std::move(*doc), handler)) |
213 | return llvm::Error::success(); |
214 | } else { |
215 | Logger::error(fmt: "JSON parse error: {0}" , vals: llvm::toString(E: doc.takeError())); |
216 | } |
217 | } |
218 | } |
219 | return llvm::errorCodeToError(EC: std::make_error_code(e: std::errc::io_error)); |
220 | } |
221 | |
222 | void JSONTransport::sendMessage(llvm::json::Value msg) { |
223 | outputBuffer.clear(); |
224 | llvm::raw_svector_ostream os(outputBuffer); |
225 | os << llvm::formatv(Fmt: prettyOutput ? "{0:2}\n" : "{0}" , Vals&: msg); |
226 | out << "Content-Length: " << outputBuffer.size() << "\r\n\r\n" |
227 | << outputBuffer; |
228 | out.flush(); |
229 | Logger::debug(fmt: ">>> {0}\n" , vals&: outputBuffer); |
230 | } |
231 | |
232 | bool JSONTransport::handleMessage(llvm::json::Value msg, |
233 | MessageHandler &handler) { |
234 | // Message must be an object with "jsonrpc":"2.0". |
235 | llvm::json::Object *object = msg.getAsObject(); |
236 | if (!object || |
237 | object->getString(K: "jsonrpc" ) != std::optional<StringRef>("2.0" )) |
238 | return false; |
239 | |
240 | // `id` may be any JSON value. If absent, this is a notification. |
241 | std::optional<llvm::json::Value> id; |
242 | if (llvm::json::Value *i = object->get(K: "id" )) |
243 | id = std::move(*i); |
244 | std::optional<StringRef> method = object->getString(K: "method" ); |
245 | |
246 | // This is a response. |
247 | if (!method) { |
248 | if (!id) |
249 | return false; |
250 | if (auto *err = object->getObject(K: "error" )) |
251 | return handler.onReply(id: std::move(*id), result: decodeError(o: *err)); |
252 | // result should be given, use null if not. |
253 | llvm::json::Value result = nullptr; |
254 | if (llvm::json::Value *r = object->get(K: "result" )) |
255 | result = std::move(*r); |
256 | return handler.onReply(id: std::move(*id), result: std::move(result)); |
257 | } |
258 | |
259 | // Params should be given, use null if not. |
260 | llvm::json::Value params = nullptr; |
261 | if (llvm::json::Value *p = object->get(K: "params" )) |
262 | params = std::move(*p); |
263 | |
264 | if (id) |
265 | return handler.onCall(method: *method, params: std::move(params), id: std::move(*id)); |
266 | return handler.onNotify(method: *method, value: std::move(params)); |
267 | } |
268 | |
269 | /// Tries to read a line up to and including \n. |
270 | /// If failing, feof(), ferror(), or shutdownRequested() will be set. |
271 | LogicalResult readLine(std::FILE *in, SmallVectorImpl<char> &out) { |
272 | // Big enough to hold any reasonable header line. May not fit content lines |
273 | // in delimited mode, but performance doesn't matter for that mode. |
274 | static constexpr int bufSize = 128; |
275 | size_t size = 0; |
276 | out.clear(); |
277 | for (;;) { |
278 | out.resize_for_overwrite(N: size + bufSize); |
279 | if (!std::fgets(s: &out[size], n: bufSize, stream: in)) |
280 | return failure(); |
281 | |
282 | clearerr(stream: in); |
283 | |
284 | // If the line contained null bytes, anything after it (including \n) will |
285 | // be ignored. Fortunately this is not a legal header or JSON. |
286 | size_t read = std::strlen(s: &out[size]); |
287 | if (read > 0 && out[size + read - 1] == '\n') { |
288 | out.resize(N: size + read); |
289 | return success(); |
290 | } |
291 | size += read; |
292 | } |
293 | } |
294 | |
295 | // Returns std::nullopt when: |
296 | // - ferror(), feof(), or shutdownRequested() are set. |
297 | // - Content-Length is missing or empty (protocol error) |
298 | LogicalResult JSONTransport::readStandardMessage(std::string &json) { |
299 | // A Language Server Protocol message starts with a set of HTTP headers, |
300 | // delimited by \r\n, and terminated by an empty line (\r\n). |
301 | unsigned long long contentLength = 0; |
302 | llvm::SmallString<128> line; |
303 | while (true) { |
304 | if (feof(stream: in) || ferror(stream: in) || failed(result: readLine(in, out&: line))) |
305 | return failure(); |
306 | |
307 | // Content-Length is a mandatory header, and the only one we handle. |
308 | StringRef lineRef = line; |
309 | if (lineRef.consume_front(Prefix: "Content-Length: " )) { |
310 | llvm::getAsUnsignedInteger(Str: lineRef.trim(), Radix: 0, Result&: contentLength); |
311 | } else if (!lineRef.trim().empty()) { |
312 | // It's another header, ignore it. |
313 | continue; |
314 | } else { |
315 | // An empty line indicates the end of headers. Go ahead and read the JSON. |
316 | break; |
317 | } |
318 | } |
319 | |
320 | // The fuzzer likes crashing us by sending "Content-Length: 9999999999999999" |
321 | if (contentLength == 0 || contentLength > 1 << 30) |
322 | return failure(); |
323 | |
324 | json.resize(n: contentLength); |
325 | for (size_t pos = 0, read; pos < contentLength; pos += read) { |
326 | read = std::fread(ptr: &json[pos], size: 1, n: contentLength - pos, stream: in); |
327 | if (read == 0) |
328 | return failure(); |
329 | |
330 | // If we're done, the error was transient. If we're not done, either it was |
331 | // transient or we'll see it again on retry. |
332 | clearerr(stream: in); |
333 | pos += read; |
334 | } |
335 | return success(); |
336 | } |
337 | |
338 | /// For lit tests we support a simplified syntax: |
339 | /// - messages are delimited by '// -----' on a line by itself |
340 | /// - lines starting with // are ignored. |
341 | /// This is a testing path, so favor simplicity over performance here. |
342 | /// When returning failure: feof(), ferror(), or shutdownRequested() will be |
343 | /// set. |
344 | LogicalResult JSONTransport::readDelimitedMessage(std::string &json) { |
345 | json.clear(); |
346 | llvm::SmallString<128> line; |
347 | while (succeeded(result: readLine(in, out&: line))) { |
348 | StringRef lineRef = line.str().trim(); |
349 | if (lineRef.starts_with(Prefix: "//" )) { |
350 | // Found a delimiter for the message. |
351 | if (lineRef == kDefaultSplitMarker) |
352 | break; |
353 | continue; |
354 | } |
355 | |
356 | json += line; |
357 | } |
358 | |
359 | return failure(isFailure: ferror(stream: in)); |
360 | } |
361 | |