1//===--- JSONTransport.cpp - sending and receiving LSP messages over JSON -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Tools/lsp-server-support/Transport.h"
10#include "mlir/Support/ToolUtilities.h"
11#include "mlir/Tools/lsp-server-support/Logging.h"
12#include "mlir/Tools/lsp-server-support/Protocol.h"
13#include "llvm/ADT/SmallString.h"
14#include "llvm/Support/Errno.h"
15#include "llvm/Support/Error.h"
16#include <optional>
17#include <system_error>
18#include <utility>
19
20using namespace mlir;
21using namespace mlir::lsp;
22
23//===----------------------------------------------------------------------===//
24// Reply
25//===----------------------------------------------------------------------===//
26
27namespace {
28/// Function object to reply to an LSP call.
29/// Each instance must be called exactly once, otherwise:
30/// - if there was no reply, an error reply is sent
31/// - if there were multiple replies, only the first is sent
32class Reply {
33public:
34 Reply(const llvm::json::Value &id, StringRef method, JSONTransport &transport,
35 std::mutex &transportOutputMutex);
36 Reply(Reply &&other);
37 Reply &operator=(Reply &&) = delete;
38 Reply(const Reply &) = delete;
39 Reply &operator=(const Reply &) = delete;
40
41 void operator()(llvm::Expected<llvm::json::Value> reply);
42
43private:
44 StringRef method;
45 std::atomic<bool> replied = {false};
46 llvm::json::Value id;
47 JSONTransport *transport;
48 std::mutex &transportOutputMutex;
49};
50} // namespace
51
52Reply::Reply(const llvm::json::Value &id, llvm::StringRef method,
53 JSONTransport &transport, std::mutex &transportOutputMutex)
54 : id(id), transport(&transport),
55 transportOutputMutex(transportOutputMutex) {}
56
57Reply::Reply(Reply &&other)
58 : replied(other.replied.load()), id(std::move(other.id)),
59 transport(other.transport),
60 transportOutputMutex(other.transportOutputMutex) {
61 other.transport = nullptr;
62}
63
64void Reply::operator()(llvm::Expected<llvm::json::Value> reply) {
65 if (replied.exchange(i: true)) {
66 Logger::error(fmt: "Replied twice to message {0}({1})", vals&: method, vals&: id);
67 assert(false && "must reply to each call only once!");
68 return;
69 }
70 assert(transport && "expected valid transport to reply to");
71
72 std::lock_guard<std::mutex> transportLock(transportOutputMutex);
73 if (reply) {
74 Logger::info(fmt: "--> reply:{0}({1})", vals&: method, vals&: id);
75 transport->reply(id: std::move(id), result: std::move(reply));
76 } else {
77 llvm::Error error = reply.takeError();
78 Logger::info(fmt: "--> reply:{0}({1})", vals&: method, vals&: id, vals&: error);
79 transport->reply(id: std::move(id), result: std::move(error));
80 }
81}
82
83//===----------------------------------------------------------------------===//
84// MessageHandler
85//===----------------------------------------------------------------------===//
86
87bool MessageHandler::onNotify(llvm::StringRef method, llvm::json::Value value) {
88 Logger::info(fmt: "--> {0}", vals&: method);
89
90 if (method == "exit")
91 return false;
92 if (method == "$cancel") {
93 // TODO: Add support for cancelling requests.
94 } else {
95 auto it = notificationHandlers.find(Key: method);
96 if (it != notificationHandlers.end())
97 it->second(std::move(value));
98 }
99 return true;
100}
101
102bool MessageHandler::onCall(llvm::StringRef method, llvm::json::Value params,
103 llvm::json::Value id) {
104 Logger::info(fmt: "--> {0}({1})", vals&: method, vals&: id);
105
106 Reply reply(id, method, transport, transportOutputMutex);
107
108 auto it = methodHandlers.find(Key: method);
109 if (it != methodHandlers.end()) {
110 it->second(std::move(params), std::move(reply));
111 } else {
112 reply(llvm::make_error<LSPError>(Args: "method not found: " + method.str(),
113 Args: ErrorCode::MethodNotFound));
114 }
115 return true;
116}
117
118bool MessageHandler::onReply(llvm::json::Value id,
119 llvm::Expected<llvm::json::Value> result) {
120 // TODO: Add support for reply callbacks when support for outgoing messages is
121 // added. For now, we just log an error on any replies received.
122 Callback<llvm::json::Value> replyHandler =
123 [&id](llvm::Expected<llvm::json::Value> result) {
124 Logger::error(
125 fmt: "received a reply with ID {0}, but there was no such call", vals&: id);
126 if (!result)
127 llvm::consumeError(Err: result.takeError());
128 };
129
130 // Log and run the reply handler.
131 if (result)
132 replyHandler(std::move(result));
133 else
134 replyHandler(result.takeError());
135 return true;
136}
137
138//===----------------------------------------------------------------------===//
139// JSONTransport
140//===----------------------------------------------------------------------===//
141
142/// Encode the given error as a JSON object.
143static llvm::json::Object encodeError(llvm::Error error) {
144 std::string message;
145 ErrorCode code = ErrorCode::UnknownErrorCode;
146 auto handlerFn = [&](const LSPError &lspError) -> llvm::Error {
147 message = lspError.message;
148 code = lspError.code;
149 return llvm::Error::success();
150 };
151 if (llvm::Error unhandled = llvm::handleErrors(E: std::move(error), Hs&: handlerFn))
152 message = llvm::toString(E: std::move(unhandled));
153
154 return llvm::json::Object{
155 {.K: "message", .V: std::move(message)},
156 {.K: "code", .V: int64_t(code)},
157 };
158}
159
160/// Decode the given JSON object into an error.
161llvm::Error decodeError(const llvm::json::Object &o) {
162 StringRef msg = o.getString(K: "message").value_or(u: "Unspecified error");
163 if (std::optional<int64_t> code = o.getInteger(K: "code"))
164 return llvm::make_error<LSPError>(Args: msg.str(), Args: ErrorCode(*code));
165 return llvm::make_error<llvm::StringError>(Args: llvm::inconvertibleErrorCode(),
166 Args: msg.str());
167}
168
169void JSONTransport::notify(StringRef method, llvm::json::Value params) {
170 sendMessage(msg: llvm::json::Object{
171 {.K: "jsonrpc", .V: "2.0"},
172 {.K: "method", .V: method},
173 {.K: "params", .V: std::move(params)},
174 });
175}
176void JSONTransport::call(StringRef method, llvm::json::Value params,
177 llvm::json::Value id) {
178 sendMessage(msg: llvm::json::Object{
179 {.K: "jsonrpc", .V: "2.0"},
180 {.K: "id", .V: std::move(id)},
181 {.K: "method", .V: method},
182 {.K: "params", .V: std::move(params)},
183 });
184}
185void JSONTransport::reply(llvm::json::Value id,
186 llvm::Expected<llvm::json::Value> result) {
187 if (result) {
188 return sendMessage(msg: llvm::json::Object{
189 {.K: "jsonrpc", .V: "2.0"},
190 {.K: "id", .V: std::move(id)},
191 {.K: "result", .V: std::move(*result)},
192 });
193 }
194
195 sendMessage(msg: llvm::json::Object{
196 {.K: "jsonrpc", .V: "2.0"},
197 {.K: "id", .V: std::move(id)},
198 {.K: "error", .V: encodeError(error: result.takeError())},
199 });
200}
201
202llvm::Error JSONTransport::run(MessageHandler &handler) {
203 std::string json;
204 while (!feof(stream: in)) {
205 if (ferror(stream: in)) {
206 return llvm::errorCodeToError(
207 EC: std::error_code(errno, std::system_category()));
208 }
209
210 if (succeeded(result: readMessage(json))) {
211 if (llvm::Expected<llvm::json::Value> doc = llvm::json::parse(JSON: json)) {
212 if (!handleMessage(msg: std::move(*doc), handler))
213 return llvm::Error::success();
214 } else {
215 Logger::error(fmt: "JSON parse error: {0}", vals: llvm::toString(E: doc.takeError()));
216 }
217 }
218 }
219 return llvm::errorCodeToError(EC: std::make_error_code(e: std::errc::io_error));
220}
221
222void JSONTransport::sendMessage(llvm::json::Value msg) {
223 outputBuffer.clear();
224 llvm::raw_svector_ostream os(outputBuffer);
225 os << llvm::formatv(Fmt: prettyOutput ? "{0:2}\n" : "{0}", Vals&: msg);
226 out << "Content-Length: " << outputBuffer.size() << "\r\n\r\n"
227 << outputBuffer;
228 out.flush();
229 Logger::debug(fmt: ">>> {0}\n", vals&: outputBuffer);
230}
231
232bool JSONTransport::handleMessage(llvm::json::Value msg,
233 MessageHandler &handler) {
234 // Message must be an object with "jsonrpc":"2.0".
235 llvm::json::Object *object = msg.getAsObject();
236 if (!object ||
237 object->getString(K: "jsonrpc") != std::optional<StringRef>("2.0"))
238 return false;
239
240 // `id` may be any JSON value. If absent, this is a notification.
241 std::optional<llvm::json::Value> id;
242 if (llvm::json::Value *i = object->get(K: "id"))
243 id = std::move(*i);
244 std::optional<StringRef> method = object->getString(K: "method");
245
246 // This is a response.
247 if (!method) {
248 if (!id)
249 return false;
250 if (auto *err = object->getObject(K: "error"))
251 return handler.onReply(id: std::move(*id), result: decodeError(o: *err));
252 // result should be given, use null if not.
253 llvm::json::Value result = nullptr;
254 if (llvm::json::Value *r = object->get(K: "result"))
255 result = std::move(*r);
256 return handler.onReply(id: std::move(*id), result: std::move(result));
257 }
258
259 // Params should be given, use null if not.
260 llvm::json::Value params = nullptr;
261 if (llvm::json::Value *p = object->get(K: "params"))
262 params = std::move(*p);
263
264 if (id)
265 return handler.onCall(method: *method, params: std::move(params), id: std::move(*id));
266 return handler.onNotify(method: *method, value: std::move(params));
267}
268
269/// Tries to read a line up to and including \n.
270/// If failing, feof(), ferror(), or shutdownRequested() will be set.
271LogicalResult readLine(std::FILE *in, SmallVectorImpl<char> &out) {
272 // Big enough to hold any reasonable header line. May not fit content lines
273 // in delimited mode, but performance doesn't matter for that mode.
274 static constexpr int bufSize = 128;
275 size_t size = 0;
276 out.clear();
277 for (;;) {
278 out.resize_for_overwrite(N: size + bufSize);
279 if (!std::fgets(s: &out[size], n: bufSize, stream: in))
280 return failure();
281
282 clearerr(stream: in);
283
284 // If the line contained null bytes, anything after it (including \n) will
285 // be ignored. Fortunately this is not a legal header or JSON.
286 size_t read = std::strlen(s: &out[size]);
287 if (read > 0 && out[size + read - 1] == '\n') {
288 out.resize(N: size + read);
289 return success();
290 }
291 size += read;
292 }
293}
294
295// Returns std::nullopt when:
296// - ferror(), feof(), or shutdownRequested() are set.
297// - Content-Length is missing or empty (protocol error)
298LogicalResult JSONTransport::readStandardMessage(std::string &json) {
299 // A Language Server Protocol message starts with a set of HTTP headers,
300 // delimited by \r\n, and terminated by an empty line (\r\n).
301 unsigned long long contentLength = 0;
302 llvm::SmallString<128> line;
303 while (true) {
304 if (feof(stream: in) || ferror(stream: in) || failed(result: readLine(in, out&: line)))
305 return failure();
306
307 // Content-Length is a mandatory header, and the only one we handle.
308 StringRef lineRef = line;
309 if (lineRef.consume_front(Prefix: "Content-Length: ")) {
310 llvm::getAsUnsignedInteger(Str: lineRef.trim(), Radix: 0, Result&: contentLength);
311 } else if (!lineRef.trim().empty()) {
312 // It's another header, ignore it.
313 continue;
314 } else {
315 // An empty line indicates the end of headers. Go ahead and read the JSON.
316 break;
317 }
318 }
319
320 // The fuzzer likes crashing us by sending "Content-Length: 9999999999999999"
321 if (contentLength == 0 || contentLength > 1 << 30)
322 return failure();
323
324 json.resize(n: contentLength);
325 for (size_t pos = 0, read; pos < contentLength; pos += read) {
326 read = std::fread(ptr: &json[pos], size: 1, n: contentLength - pos, stream: in);
327 if (read == 0)
328 return failure();
329
330 // If we're done, the error was transient. If we're not done, either it was
331 // transient or we'll see it again on retry.
332 clearerr(stream: in);
333 pos += read;
334 }
335 return success();
336}
337
338/// For lit tests we support a simplified syntax:
339/// - messages are delimited by '// -----' on a line by itself
340/// - lines starting with // are ignored.
341/// This is a testing path, so favor simplicity over performance here.
342/// When returning failure: feof(), ferror(), or shutdownRequested() will be
343/// set.
344LogicalResult JSONTransport::readDelimitedMessage(std::string &json) {
345 json.clear();
346 llvm::SmallString<128> line;
347 while (succeeded(result: readLine(in, out&: line))) {
348 StringRef lineRef = line.str().trim();
349 if (lineRef.starts_with(Prefix: "//")) {
350 // Found a delimiter for the message.
351 if (lineRef == kDefaultSplitMarker)
352 break;
353 continue;
354 }
355
356 json += line;
357 }
358
359 return failure(isFailure: ferror(stream: in));
360}
361

source code of mlir/lib/Tools/lsp-server-support/Transport.cpp