1//===- LSPServer.cpp - MLIR Language Server -------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "LSPServer.h"
10#include "MLIRServer.h"
11#include "Protocol.h"
12#include "mlir/Tools/lsp-server-support/Logging.h"
13#include "mlir/Tools/lsp-server-support/Transport.h"
14#include "llvm/ADT/FunctionExtras.h"
15#include "llvm/ADT/StringMap.h"
16#include <optional>
17
18#define DEBUG_TYPE "mlir-lsp-server"
19
20using namespace mlir;
21using namespace mlir::lsp;
22
23//===----------------------------------------------------------------------===//
24// LSPServer
25//===----------------------------------------------------------------------===//
26
27namespace {
28struct LSPServer {
29 LSPServer(MLIRServer &server) : server(server) {}
30
31 //===--------------------------------------------------------------------===//
32 // Initialization
33
34 void onInitialize(const InitializeParams &params,
35 Callback<llvm::json::Value> reply);
36 void onInitialized(const InitializedParams &params);
37 void onShutdown(const NoParams &params, Callback<std::nullptr_t> reply);
38
39 //===--------------------------------------------------------------------===//
40 // Document Change
41
42 void onDocumentDidOpen(const DidOpenTextDocumentParams &params);
43 void onDocumentDidClose(const DidCloseTextDocumentParams &params);
44 void onDocumentDidChange(const DidChangeTextDocumentParams &params);
45
46 //===--------------------------------------------------------------------===//
47 // Definitions and References
48
49 void onGoToDefinition(const TextDocumentPositionParams &params,
50 Callback<std::vector<Location>> reply);
51 void onReference(const ReferenceParams &params,
52 Callback<std::vector<Location>> reply);
53
54 //===--------------------------------------------------------------------===//
55 // Hover
56
57 void onHover(const TextDocumentPositionParams &params,
58 Callback<std::optional<Hover>> reply);
59
60 //===--------------------------------------------------------------------===//
61 // Document Symbols
62
63 void onDocumentSymbol(const DocumentSymbolParams &params,
64 Callback<std::vector<DocumentSymbol>> reply);
65
66 //===--------------------------------------------------------------------===//
67 // Code Completion
68
69 void onCompletion(const CompletionParams &params,
70 Callback<CompletionList> reply);
71
72 //===--------------------------------------------------------------------===//
73 // Code Action
74
75 void onCodeAction(const CodeActionParams &params,
76 Callback<llvm::json::Value> reply);
77
78 //===--------------------------------------------------------------------===//
79 // Bytecode
80
81 void onConvertFromBytecode(const MLIRConvertBytecodeParams &params,
82 Callback<MLIRConvertBytecodeResult> reply);
83 void onConvertToBytecode(const MLIRConvertBytecodeParams &params,
84 Callback<MLIRConvertBytecodeResult> reply);
85
86 //===--------------------------------------------------------------------===//
87 // Fields
88 //===--------------------------------------------------------------------===//
89
90 MLIRServer &server;
91
92 /// An outgoing notification used to send diagnostics to the client when they
93 /// are ready to be processed.
94 OutgoingNotification<PublishDiagnosticsParams> publishDiagnostics;
95
96 /// Used to indicate that the 'shutdown' request was received from the
97 /// Language Server client.
98 bool shutdownRequestReceived = false;
99};
100} // namespace
101
102//===----------------------------------------------------------------------===//
103// Initialization
104
105void LSPServer::onInitialize(const InitializeParams &params,
106 Callback<llvm::json::Value> reply) {
107 // Send a response with the capabilities of this server.
108 llvm::json::Object serverCaps{
109 {.K: "textDocumentSync",
110 .V: llvm::json::Object{
111 {.K: "openClose", .V: true},
112 {.K: "change", .V: (int)TextDocumentSyncKind::Full},
113 {.K: "save", .V: true},
114 }},
115 {.K: "completionProvider",
116 .V: llvm::json::Object{
117 {.K: "allCommitCharacters",
118 .V: {
119 "\t",
120 ";",
121 ",",
122 ".",
123 "=",
124 }},
125 {.K: "resolveProvider", .V: false},
126 {.K: "triggerCharacters",
127 .V: {".", "%", "^", "!", "#", "(", ",", "<", ":", "[", " ", "\"", "/"}},
128 }},
129 {.K: "definitionProvider", .V: true},
130 {.K: "referencesProvider", .V: true},
131 {.K: "hoverProvider", .V: true},
132
133 // For now we only support documenting symbols when the client supports
134 // hierarchical symbols.
135 {.K: "documentSymbolProvider",
136 .V: params.capabilities.hierarchicalDocumentSymbol},
137 };
138
139 // Per LSP, codeActionProvider can be either boolean or CodeActionOptions.
140 // CodeActionOptions is only valid if the client supports action literal
141 // via textDocument.codeAction.codeActionLiteralSupport.
142 serverCaps["codeActionProvider"] =
143 params.capabilities.codeActionStructure
144 ? llvm::json::Object{{.K: "codeActionKinds",
145 .V: {CodeAction::kQuickFix, CodeAction::kRefactor,
146 CodeAction::kInfo}}}
147 : llvm::json::Value(true);
148
149 llvm::json::Object result{
150 {{.K: "serverInfo",
151 .V: llvm::json::Object{{.K: "name", .V: "mlir-lsp-server"}, {.K: "version", .V: "0.0.0"}}},
152 {.K: "capabilities", .V: std::move(serverCaps)}}};
153 reply(std::move(result));
154}
155void LSPServer::onInitialized(const InitializedParams &) {}
156void LSPServer::onShutdown(const NoParams &, Callback<std::nullptr_t> reply) {
157 shutdownRequestReceived = true;
158 reply(nullptr);
159}
160
161//===----------------------------------------------------------------------===//
162// Document Change
163
164void LSPServer::onDocumentDidOpen(const DidOpenTextDocumentParams &params) {
165 PublishDiagnosticsParams diagParams(params.textDocument.uri,
166 params.textDocument.version);
167 server.addOrUpdateDocument(uri: params.textDocument.uri, contents: params.textDocument.text,
168 version: params.textDocument.version,
169 diagnostics&: diagParams.diagnostics);
170
171 // Publish any recorded diagnostics.
172 publishDiagnostics(diagParams);
173}
174void LSPServer::onDocumentDidClose(const DidCloseTextDocumentParams &params) {
175 std::optional<int64_t> version =
176 server.removeDocument(uri: params.textDocument.uri);
177 if (!version)
178 return;
179
180 // Empty out the diagnostics shown for this document. This will clear out
181 // anything currently displayed by the client for this document (e.g. in the
182 // "Problems" pane of VSCode).
183 publishDiagnostics(
184 PublishDiagnosticsParams(params.textDocument.uri, *version));
185}
186void LSPServer::onDocumentDidChange(const DidChangeTextDocumentParams &params) {
187 // TODO: We currently only support full document updates, we should refactor
188 // to avoid this.
189 if (params.contentChanges.size() != 1)
190 return;
191 PublishDiagnosticsParams diagParams(params.textDocument.uri,
192 params.textDocument.version);
193 server.addOrUpdateDocument(
194 uri: params.textDocument.uri, contents: params.contentChanges.front().text,
195 version: params.textDocument.version, diagnostics&: diagParams.diagnostics);
196
197 // Publish any recorded diagnostics.
198 publishDiagnostics(diagParams);
199}
200
201//===----------------------------------------------------------------------===//
202// Definitions and References
203
204void LSPServer::onGoToDefinition(const TextDocumentPositionParams &params,
205 Callback<std::vector<Location>> reply) {
206 std::vector<Location> locations;
207 server.getLocationsOf(uri: params.textDocument.uri, defPos: params.position, locations);
208 reply(std::move(locations));
209}
210
211void LSPServer::onReference(const ReferenceParams &params,
212 Callback<std::vector<Location>> reply) {
213 std::vector<Location> locations;
214 server.findReferencesOf(uri: params.textDocument.uri, pos: params.position, references&: locations);
215 reply(std::move(locations));
216}
217
218//===----------------------------------------------------------------------===//
219// Hover
220
221void LSPServer::onHover(const TextDocumentPositionParams &params,
222 Callback<std::optional<Hover>> reply) {
223 reply(server.findHover(uri: params.textDocument.uri, hoverPos: params.position));
224}
225
226//===----------------------------------------------------------------------===//
227// Document Symbols
228
229void LSPServer::onDocumentSymbol(const DocumentSymbolParams &params,
230 Callback<std::vector<DocumentSymbol>> reply) {
231 std::vector<DocumentSymbol> symbols;
232 server.findDocumentSymbols(uri: params.textDocument.uri, symbols);
233 reply(std::move(symbols));
234}
235
236//===----------------------------------------------------------------------===//
237// Code Completion
238
239void LSPServer::onCompletion(const CompletionParams &params,
240 Callback<CompletionList> reply) {
241 reply(server.getCodeCompletion(uri: params.textDocument.uri, completePos: params.position));
242}
243
244//===----------------------------------------------------------------------===//
245// Code Action
246
247void LSPServer::onCodeAction(const CodeActionParams &params,
248 Callback<llvm::json::Value> reply) {
249 URIForFile uri = params.textDocument.uri;
250
251 // Check whether a particular CodeActionKind is included in the response.
252 auto isKindAllowed = [only(params.context.only)](StringRef kind) {
253 if (only.empty())
254 return true;
255 return llvm::any_of(Range: only, P: [&](StringRef base) {
256 return kind.consume_front(Prefix: base) &&
257 (kind.empty() || kind.starts_with(Prefix: "."));
258 });
259 };
260
261 // We provide a code action for fixes on the specified diagnostics.
262 std::vector<CodeAction> actions;
263 if (isKindAllowed(CodeAction::kQuickFix))
264 server.getCodeActions(uri, pos: params.range.start, context: params.context, actions);
265 reply(std::move(actions));
266}
267
268//===----------------------------------------------------------------------===//
269// Bytecode
270
271void LSPServer::onConvertFromBytecode(
272 const MLIRConvertBytecodeParams &params,
273 Callback<MLIRConvertBytecodeResult> reply) {
274 reply(server.convertFromBytecode(uri: params.uri));
275}
276
277void LSPServer::onConvertToBytecode(const MLIRConvertBytecodeParams &params,
278 Callback<MLIRConvertBytecodeResult> reply) {
279 reply(server.convertToBytecode(uri: params.uri));
280}
281
282//===----------------------------------------------------------------------===//
283// Entry point
284//===----------------------------------------------------------------------===//
285
286LogicalResult lsp::runMlirLSPServer(MLIRServer &server,
287 JSONTransport &transport) {
288 LSPServer lspServer(server);
289 MessageHandler messageHandler(transport);
290
291 // Initialization
292 messageHandler.method(method: "initialize", thisPtr: &lspServer, handler: &LSPServer::onInitialize);
293 messageHandler.notification(method: "initialized", thisPtr: &lspServer,
294 handler: &LSPServer::onInitialized);
295 messageHandler.method(method: "shutdown", thisPtr: &lspServer, handler: &LSPServer::onShutdown);
296
297 // Document Changes
298 messageHandler.notification(method: "textDocument/didOpen", thisPtr: &lspServer,
299 handler: &LSPServer::onDocumentDidOpen);
300 messageHandler.notification(method: "textDocument/didClose", thisPtr: &lspServer,
301 handler: &LSPServer::onDocumentDidClose);
302 messageHandler.notification(method: "textDocument/didChange", thisPtr: &lspServer,
303 handler: &LSPServer::onDocumentDidChange);
304
305 // Definitions and References
306 messageHandler.method(method: "textDocument/definition", thisPtr: &lspServer,
307 handler: &LSPServer::onGoToDefinition);
308 messageHandler.method(method: "textDocument/references", thisPtr: &lspServer,
309 handler: &LSPServer::onReference);
310
311 // Hover
312 messageHandler.method(method: "textDocument/hover", thisPtr: &lspServer, handler: &LSPServer::onHover);
313
314 // Document Symbols
315 messageHandler.method(method: "textDocument/documentSymbol", thisPtr: &lspServer,
316 handler: &LSPServer::onDocumentSymbol);
317
318 // Code Completion
319 messageHandler.method(method: "textDocument/completion", thisPtr: &lspServer,
320 handler: &LSPServer::onCompletion);
321
322 // Code Action
323 messageHandler.method(method: "textDocument/codeAction", thisPtr: &lspServer,
324 handler: &LSPServer::onCodeAction);
325
326 // Bytecode
327 messageHandler.method(method: "mlir/convertFromBytecode", thisPtr: &lspServer,
328 handler: &LSPServer::onConvertFromBytecode);
329 messageHandler.method(method: "mlir/convertToBytecode", thisPtr: &lspServer,
330 handler: &LSPServer::onConvertToBytecode);
331
332 // Diagnostics
333 lspServer.publishDiagnostics =
334 messageHandler.outgoingNotification<PublishDiagnosticsParams>(
335 method: "textDocument/publishDiagnostics");
336
337 // Run the main loop of the transport.
338 LogicalResult result = success();
339 if (llvm::Error error = transport.run(handler&: messageHandler)) {
340 Logger::error(fmt: "Transport error: {0}", vals&: error);
341 llvm::consumeError(Err: std::move(error));
342 result = failure();
343 } else {
344 result = success(isSuccess: lspServer.shutdownRequestReceived);
345 }
346 return result;
347}
348

// source code of mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp