1 | //===--- JSONTransport.cpp - sending and receiving LSP messages over JSON -===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Tools/lsp-server-support/Transport.h" |
10 | #include "mlir/Support/ToolUtilities.h" |
11 | #include "mlir/Tools/lsp-server-support/Logging.h" |
12 | #include "mlir/Tools/lsp-server-support/Protocol.h" |
13 | #include "llvm/ADT/SmallString.h" |
14 | #include "llvm/Support/Errno.h" |
15 | #include "llvm/Support/Error.h" |
16 | #include <optional> |
17 | #include <system_error> |
18 | #include <utility> |
19 | |
20 | using namespace mlir; |
21 | using namespace mlir::lsp; |
22 | |
23 | //===----------------------------------------------------------------------===// |
24 | // Reply |
25 | //===----------------------------------------------------------------------===// |
26 | |
27 | namespace { |
28 | /// Function object to reply to an LSP call. |
29 | /// Each instance must be called exactly once, otherwise: |
30 | /// - if there was no reply, an error reply is sent |
31 | /// - if there were multiple replies, only the first is sent |
32 | class Reply { |
33 | public: |
34 | Reply(const llvm::json::Value &id, StringRef method, JSONTransport &transport, |
35 | std::mutex &transportOutputMutex); |
36 | Reply(Reply &&other); |
37 | Reply &operator=(Reply &&) = delete; |
38 | Reply(const Reply &) = delete; |
39 | Reply &operator=(const Reply &) = delete; |
40 | |
41 | void operator()(llvm::Expected<llvm::json::Value> reply); |
42 | |
43 | private: |
44 | std::string method; |
45 | std::atomic<bool> replied = {false}; |
46 | llvm::json::Value id; |
47 | JSONTransport *transport; |
48 | std::mutex &transportOutputMutex; |
49 | }; |
50 | } // namespace |
51 | |
52 | Reply::Reply(const llvm::json::Value &id, llvm::StringRef method, |
53 | JSONTransport &transport, std::mutex &transportOutputMutex) |
54 | : method(method), id(id), transport(&transport), |
55 | transportOutputMutex(transportOutputMutex) {} |
56 | |
57 | Reply::Reply(Reply &&other) |
58 | : method(other.method), replied(other.replied.load()), |
59 | id(std::move(other.id)), transport(other.transport), |
60 | transportOutputMutex(other.transportOutputMutex) { |
61 | other.transport = nullptr; |
62 | } |
63 | |
64 | void Reply::operator()(llvm::Expected<llvm::json::Value> reply) { |
65 | if (replied.exchange(i: true)) { |
66 | Logger::error(fmt: "Replied twice to message {0}({1})" , vals&: method, vals&: id); |
67 | assert(false && "must reply to each call only once!" ); |
68 | return; |
69 | } |
70 | assert(transport && "expected valid transport to reply to" ); |
71 | |
72 | std::lock_guard<std::mutex> transportLock(transportOutputMutex); |
73 | if (reply) { |
74 | Logger::info(fmt: "--> reply:{0}({1})" , vals&: method, vals&: id); |
75 | transport->reply(id: std::move(id), result: std::move(reply)); |
76 | } else { |
77 | llvm::Error error = reply.takeError(); |
78 | Logger::info(fmt: "--> reply:{0}({1}): {2}" , vals&: method, vals&: id, vals&: error); |
79 | transport->reply(id: std::move(id), result: std::move(error)); |
80 | } |
81 | } |
82 | |
83 | //===----------------------------------------------------------------------===// |
84 | // MessageHandler |
85 | //===----------------------------------------------------------------------===// |
86 | |
87 | bool MessageHandler::onNotify(llvm::StringRef method, llvm::json::Value value) { |
88 | Logger::info(fmt: "--> {0}" , vals&: method); |
89 | |
90 | if (method == "exit" ) |
91 | return false; |
92 | if (method == "$cancel" ) { |
93 | // TODO: Add support for cancelling requests. |
94 | } else { |
95 | auto it = notificationHandlers.find(Key: method); |
96 | if (it != notificationHandlers.end()) |
97 | it->second(std::move(value)); |
98 | } |
99 | return true; |
100 | } |
101 | |
102 | bool MessageHandler::onCall(llvm::StringRef method, llvm::json::Value params, |
103 | llvm::json::Value id) { |
104 | Logger::info(fmt: "--> {0}({1})" , vals&: method, vals&: id); |
105 | |
106 | Reply reply(id, method, transport, transportOutputMutex); |
107 | |
108 | auto it = methodHandlers.find(Key: method); |
109 | if (it != methodHandlers.end()) { |
110 | it->second(std::move(params), std::move(reply)); |
111 | } else { |
112 | reply(llvm::make_error<LSPError>(Args: "method not found: " + method.str(), |
113 | Args: ErrorCode::MethodNotFound)); |
114 | } |
115 | return true; |
116 | } |
117 | |
118 | bool MessageHandler::onReply(llvm::json::Value id, |
119 | llvm::Expected<llvm::json::Value> result) { |
120 | // Find the response handler in the mapping. If it exists, move it out of the |
121 | // mapping and erase it. |
122 | ResponseHandlerTy responseHandler; |
123 | { |
124 | std::lock_guard<std::mutex> responseHandlersLock(responseHandlersMutex); |
125 | auto it = responseHandlers.find(Key: debugString(op&: id)); |
126 | if (it != responseHandlers.end()) { |
127 | responseHandler = std::move(it->second); |
128 | responseHandlers.erase(I: it); |
129 | } |
130 | } |
131 | |
132 | // If we found a response handler, invoke it. Otherwise, log an error. |
133 | if (responseHandler.second) { |
134 | Logger::info(fmt: "--> reply:{0}({1})" , vals&: responseHandler.first, vals&: id); |
135 | responseHandler.second(std::move(id), std::move(result)); |
136 | } else { |
137 | Logger::error( |
138 | fmt: "received a reply with ID {0}, but there was no such outgoing request" , |
139 | vals&: id); |
140 | if (!result) |
141 | llvm::consumeError(Err: result.takeError()); |
142 | } |
143 | return true; |
144 | } |
145 | |
146 | //===----------------------------------------------------------------------===// |
147 | // JSONTransport |
148 | //===----------------------------------------------------------------------===// |
149 | |
150 | /// Encode the given error as a JSON object. |
151 | static llvm::json::Object encodeError(llvm::Error error) { |
152 | std::string message; |
153 | ErrorCode code = ErrorCode::UnknownErrorCode; |
154 | auto handlerFn = [&](const LSPError &lspError) -> llvm::Error { |
155 | message = lspError.message; |
156 | code = lspError.code; |
157 | return llvm::Error::success(); |
158 | }; |
159 | if (llvm::Error unhandled = llvm::handleErrors(E: std::move(error), Hs&: handlerFn)) |
160 | message = llvm::toString(E: std::move(unhandled)); |
161 | |
162 | return llvm::json::Object{ |
163 | {.K: "message" , .V: std::move(message)}, |
164 | {.K: "code" , .V: int64_t(code)}, |
165 | }; |
166 | } |
167 | |
168 | /// Decode the given JSON object into an error. |
169 | llvm::Error decodeError(const llvm::json::Object &o) { |
170 | StringRef msg = o.getString(K: "message" ).value_or(u: "Unspecified error" ); |
171 | if (std::optional<int64_t> code = o.getInteger(K: "code" )) |
172 | return llvm::make_error<LSPError>(Args: msg.str(), Args: ErrorCode(*code)); |
173 | return llvm::make_error<llvm::StringError>(Args: llvm::inconvertibleErrorCode(), |
174 | Args: msg.str()); |
175 | } |
176 | |
177 | void JSONTransport::notify(StringRef method, llvm::json::Value params) { |
178 | sendMessage(msg: llvm::json::Object{ |
179 | {.K: "jsonrpc" , .V: "2.0" }, |
180 | {.K: "method" , .V: method}, |
181 | {.K: "params" , .V: std::move(params)}, |
182 | }); |
183 | } |
184 | void JSONTransport::call(StringRef method, llvm::json::Value params, |
185 | llvm::json::Value id) { |
186 | sendMessage(msg: llvm::json::Object{ |
187 | {.K: "jsonrpc" , .V: "2.0" }, |
188 | {.K: "id" , .V: std::move(id)}, |
189 | {.K: "method" , .V: method}, |
190 | {.K: "params" , .V: std::move(params)}, |
191 | }); |
192 | } |
193 | void JSONTransport::reply(llvm::json::Value id, |
194 | llvm::Expected<llvm::json::Value> result) { |
195 | if (result) { |
196 | return sendMessage(msg: llvm::json::Object{ |
197 | {.K: "jsonrpc" , .V: "2.0" }, |
198 | {.K: "id" , .V: std::move(id)}, |
199 | {.K: "result" , .V: std::move(*result)}, |
200 | }); |
201 | } |
202 | |
203 | sendMessage(msg: llvm::json::Object{ |
204 | {.K: "jsonrpc" , .V: "2.0" }, |
205 | {.K: "id" , .V: std::move(id)}, |
206 | {.K: "error" , .V: encodeError(error: result.takeError())}, |
207 | }); |
208 | } |
209 | |
210 | llvm::Error JSONTransport::run(MessageHandler &handler) { |
211 | std::string json; |
212 | while (!in->isEndOfInput()) { |
213 | if (in->hasError()) { |
214 | return llvm::errorCodeToError( |
215 | EC: std::error_code(errno, std::system_category())); |
216 | } |
217 | |
218 | if (succeeded(Result: in->readMessage(json))) { |
219 | if (llvm::Expected<llvm::json::Value> doc = llvm::json::parse(JSON: json)) { |
220 | if (!handleMessage(msg: std::move(*doc), handler)) |
221 | return llvm::Error::success(); |
222 | } else { |
223 | Logger::error(fmt: "JSON parse error: {0}" , vals: llvm::toString(E: doc.takeError())); |
224 | } |
225 | } |
226 | } |
227 | return llvm::errorCodeToError(EC: std::make_error_code(e: std::errc::io_error)); |
228 | } |
229 | |
230 | void JSONTransport::sendMessage(llvm::json::Value msg) { |
231 | outputBuffer.clear(); |
232 | llvm::raw_svector_ostream os(outputBuffer); |
233 | os << llvm::formatv(Fmt: prettyOutput ? "{0:2}\n" : "{0}" , Vals&: msg); |
234 | out << "Content-Length: " << outputBuffer.size() << "\r\n\r\n" |
235 | << outputBuffer; |
236 | out.flush(); |
237 | Logger::debug(fmt: ">>> {0}\n" , vals&: outputBuffer); |
238 | } |
239 | |
240 | bool JSONTransport::handleMessage(llvm::json::Value msg, |
241 | MessageHandler &handler) { |
242 | // Message must be an object with "jsonrpc":"2.0". |
243 | llvm::json::Object *object = msg.getAsObject(); |
244 | if (!object || |
245 | object->getString(K: "jsonrpc" ) != std::optional<StringRef>("2.0" )) |
246 | return false; |
247 | |
248 | // `id` may be any JSON value. If absent, this is a notification. |
249 | std::optional<llvm::json::Value> id; |
250 | if (llvm::json::Value *i = object->get(K: "id" )) |
251 | id = std::move(*i); |
252 | std::optional<StringRef> method = object->getString(K: "method" ); |
253 | |
254 | // This is a response. |
255 | if (!method) { |
256 | if (!id) |
257 | return false; |
258 | if (auto *err = object->getObject(K: "error" )) |
259 | return handler.onReply(id: std::move(*id), result: decodeError(o: *err)); |
260 | // result should be given, use null if not. |
261 | llvm::json::Value result = nullptr; |
262 | if (llvm::json::Value *r = object->get(K: "result" )) |
263 | result = std::move(*r); |
264 | return handler.onReply(id: std::move(*id), result: std::move(result)); |
265 | } |
266 | |
267 | // Params should be given, use null if not. |
268 | llvm::json::Value params = nullptr; |
269 | if (llvm::json::Value *p = object->get(K: "params" )) |
270 | params = std::move(*p); |
271 | |
272 | if (id) |
273 | return handler.onCall(method: *method, params: std::move(params), id: std::move(*id)); |
274 | return handler.onNotify(method: *method, value: std::move(params)); |
275 | } |
276 | |
277 | /// Tries to read a line up to and including \n. |
278 | /// If failing, feof(), ferror(), or shutdownRequested() will be set. |
279 | LogicalResult readLine(std::FILE *in, SmallVectorImpl<char> &out) { |
280 | // Big enough to hold any reasonable header line. May not fit content lines |
281 | // in delimited mode, but performance doesn't matter for that mode. |
282 | static constexpr int bufSize = 128; |
283 | size_t size = 0; |
284 | out.clear(); |
285 | for (;;) { |
286 | out.resize_for_overwrite(N: size + bufSize); |
287 | if (!std::fgets(s: &out[size], n: bufSize, stream: in)) |
288 | return failure(); |
289 | |
290 | clearerr(stream: in); |
291 | |
292 | // If the line contained null bytes, anything after it (including \n) will |
293 | // be ignored. Fortunately this is not a legal header or JSON. |
294 | size_t read = std::strlen(s: &out[size]); |
295 | if (read > 0 && out[size + read - 1] == '\n') { |
296 | out.resize(N: size + read); |
297 | return success(); |
298 | } |
299 | size += read; |
300 | } |
301 | } |
302 | |
303 | // Returns std::nullopt when: |
304 | // - ferror(), feof(), or shutdownRequested() are set. |
305 | // - Content-Length is missing or empty (protocol error) |
306 | LogicalResult |
307 | JSONTransportInputOverFile::readStandardMessage(std::string &json) { |
308 | // A Language Server Protocol message starts with a set of HTTP headers, |
309 | // delimited by \r\n, and terminated by an empty line (\r\n). |
310 | unsigned long long contentLength = 0; |
311 | llvm::SmallString<128> line; |
312 | while (true) { |
313 | if (feof(stream: in) || hasError() || failed(Result: readLine(in, out&: line))) |
314 | return failure(); |
315 | |
316 | // Content-Length is a mandatory header, and the only one we handle. |
317 | StringRef lineRef = line; |
318 | if (lineRef.consume_front(Prefix: "Content-Length: " )) { |
319 | llvm::getAsUnsignedInteger(Str: lineRef.trim(), Radix: 0, Result&: contentLength); |
320 | } else if (!lineRef.trim().empty()) { |
321 | // It's another header, ignore it. |
322 | continue; |
323 | } else { |
324 | // An empty line indicates the end of headers. Go ahead and read the JSON. |
325 | break; |
326 | } |
327 | } |
328 | |
329 | // The fuzzer likes crashing us by sending "Content-Length: 9999999999999999" |
330 | if (contentLength == 0 || contentLength > 1 << 30) |
331 | return failure(); |
332 | |
333 | json.resize(n: contentLength); |
334 | for (size_t pos = 0, read; pos < contentLength; pos += read) { |
335 | read = std::fread(ptr: &json[pos], size: 1, n: contentLength - pos, stream: in); |
336 | if (read == 0) |
337 | return failure(); |
338 | |
339 | // If we're done, the error was transient. If we're not done, either it was |
340 | // transient or we'll see it again on retry. |
341 | clearerr(stream: in); |
342 | pos += read; |
343 | } |
344 | return success(); |
345 | } |
346 | |
347 | /// For lit tests we support a simplified syntax: |
348 | /// - messages are delimited by '// -----' on a line by itself |
349 | /// - lines starting with // are ignored. |
350 | /// This is a testing path, so favor simplicity over performance here. |
351 | /// When returning failure: feof(), ferror(), or shutdownRequested() will be |
352 | /// set. |
353 | LogicalResult |
354 | JSONTransportInputOverFile::readDelimitedMessage(std::string &json) { |
355 | json.clear(); |
356 | llvm::SmallString<128> line; |
357 | while (succeeded(Result: readLine(in, out&: line))) { |
358 | StringRef lineRef = line.str().trim(); |
359 | if (lineRef.starts_with(Prefix: "//" )) { |
360 | // Found a delimiter for the message. |
361 | if (lineRef == kDefaultSplitMarker) |
362 | break; |
363 | continue; |
364 | } |
365 | |
366 | json += line; |
367 | } |
368 | |
369 | return failure(IsFailure: ferror(stream: in)); |
370 | } |
371 | |