//===- LSPServer.cpp - MLIR Language Server -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 | |
9 | #include "LSPServer.h" |
10 | #include "MLIRServer.h" |
11 | #include "Protocol.h" |
12 | #include "mlir/Tools/lsp-server-support/Logging.h" |
13 | #include "mlir/Tools/lsp-server-support/Transport.h" |
14 | #include "llvm/ADT/FunctionExtras.h" |
15 | #include "llvm/ADT/StringMap.h" |
16 | #include <optional> |
17 | |
18 | #define DEBUG_TYPE "mlir-lsp-server" |
19 | |
20 | using namespace mlir; |
21 | using namespace mlir::lsp; |
22 | |
23 | //===----------------------------------------------------------------------===// |
24 | // LSPServer |
25 | //===----------------------------------------------------------------------===// |
26 | |
27 | namespace { |
28 | struct LSPServer { |
29 | LSPServer(MLIRServer &server) : server(server) {} |
30 | |
31 | //===--------------------------------------------------------------------===// |
32 | // Initialization |
33 | |
34 | void onInitialize(const InitializeParams ¶ms, |
35 | Callback<llvm::json::Value> reply); |
36 | void onInitialized(const InitializedParams ¶ms); |
37 | void onShutdown(const NoParams ¶ms, Callback<std::nullptr_t> reply); |
38 | |
39 | //===--------------------------------------------------------------------===// |
40 | // Document Change |
41 | |
42 | void onDocumentDidOpen(const DidOpenTextDocumentParams ¶ms); |
43 | void onDocumentDidClose(const DidCloseTextDocumentParams ¶ms); |
44 | void onDocumentDidChange(const DidChangeTextDocumentParams ¶ms); |
45 | |
46 | //===--------------------------------------------------------------------===// |
47 | // Definitions and References |
48 | |
49 | void onGoToDefinition(const TextDocumentPositionParams ¶ms, |
50 | Callback<std::vector<Location>> reply); |
51 | void onReference(const ReferenceParams ¶ms, |
52 | Callback<std::vector<Location>> reply); |
53 | |
54 | //===--------------------------------------------------------------------===// |
55 | // Hover |
56 | |
57 | void onHover(const TextDocumentPositionParams ¶ms, |
58 | Callback<std::optional<Hover>> reply); |
59 | |
60 | //===--------------------------------------------------------------------===// |
61 | // Document Symbols |
62 | |
63 | void onDocumentSymbol(const DocumentSymbolParams ¶ms, |
64 | Callback<std::vector<DocumentSymbol>> reply); |
65 | |
66 | //===--------------------------------------------------------------------===// |
67 | // Code Completion |
68 | |
69 | void onCompletion(const CompletionParams ¶ms, |
70 | Callback<CompletionList> reply); |
71 | |
72 | //===--------------------------------------------------------------------===// |
73 | // Code Action |
74 | |
75 | void onCodeAction(const CodeActionParams ¶ms, |
76 | Callback<llvm::json::Value> reply); |
77 | |
78 | //===--------------------------------------------------------------------===// |
79 | // Bytecode |
80 | |
81 | void onConvertFromBytecode(const MLIRConvertBytecodeParams ¶ms, |
82 | Callback<MLIRConvertBytecodeResult> reply); |
83 | void onConvertToBytecode(const MLIRConvertBytecodeParams ¶ms, |
84 | Callback<MLIRConvertBytecodeResult> reply); |
85 | |
86 | //===--------------------------------------------------------------------===// |
87 | // Fields |
88 | //===--------------------------------------------------------------------===// |
89 | |
90 | MLIRServer &server; |
91 | |
92 | /// An outgoing notification used to send diagnostics to the client when they |
93 | /// are ready to be processed. |
94 | OutgoingNotification<PublishDiagnosticsParams> publishDiagnostics; |
95 | |
96 | /// Used to indicate that the 'shutdown' request was received from the |
97 | /// Language Server client. |
98 | bool shutdownRequestReceived = false; |
99 | }; |
100 | } // namespace |
101 | |
102 | //===----------------------------------------------------------------------===// |
103 | // Initialization |
104 | //===----------------------------------------------------------------------===// |
105 | |
106 | void LSPServer::onInitialize(const InitializeParams ¶ms, |
107 | Callback<llvm::json::Value> reply) { |
108 | // Send a response with the capabilities of this server. |
109 | llvm::json::Object serverCaps{ |
110 | {.K: "textDocumentSync" , |
111 | .V: llvm::json::Object{ |
112 | {.K: "openClose" , .V: true}, |
113 | {.K: "change" , .V: (int)TextDocumentSyncKind::Full}, |
114 | {.K: "save" , .V: true}, |
115 | }}, |
116 | {.K: "completionProvider" , |
117 | .V: llvm::json::Object{ |
118 | {.K: "allCommitCharacters" , |
119 | .V: { |
120 | "\t" , |
121 | ";" , |
122 | "," , |
123 | "." , |
124 | "=" , |
125 | }}, |
126 | {.K: "resolveProvider" , .V: false}, |
127 | {.K: "triggerCharacters" , |
128 | .V: {"." , "%" , "^" , "!" , "#" , "(" , "," , "<" , ":" , "[" , " " , "\"" , "/" }}, |
129 | }}, |
130 | {.K: "definitionProvider" , .V: true}, |
131 | {.K: "referencesProvider" , .V: true}, |
132 | {.K: "hoverProvider" , .V: true}, |
133 | |
134 | // For now we only support documenting symbols when the client supports |
135 | // hierarchical symbols. |
136 | {.K: "documentSymbolProvider" , |
137 | .V: params.capabilities.hierarchicalDocumentSymbol}, |
138 | }; |
139 | |
140 | // Per LSP, codeActionProvider can be either boolean or CodeActionOptions. |
141 | // CodeActionOptions is only valid if the client supports action literal |
142 | // via textDocument.codeAction.codeActionLiteralSupport. |
143 | serverCaps["codeActionProvider" ] = |
144 | params.capabilities.codeActionStructure |
145 | ? llvm::json::Object{{.K: "codeActionKinds" , |
146 | .V: {CodeAction::kQuickFix, CodeAction::kRefactor, |
147 | CodeAction::kInfo}}} |
148 | : llvm::json::Value(true); |
149 | |
150 | llvm::json::Object result{ |
151 | {{.K: "serverInfo" , |
152 | .V: llvm::json::Object{{.K: "name" , .V: "mlir-lsp-server" }, {.K: "version" , .V: "0.0.0" }}}, |
153 | {.K: "capabilities" , .V: std::move(serverCaps)}}}; |
154 | reply(std::move(result)); |
155 | } |
156 | void LSPServer::onInitialized(const InitializedParams &) {} |
157 | void LSPServer::onShutdown(const NoParams &, Callback<std::nullptr_t> reply) { |
158 | shutdownRequestReceived = true; |
159 | reply(nullptr); |
160 | } |
161 | |
162 | //===----------------------------------------------------------------------===// |
163 | // Document Change |
164 | //===----------------------------------------------------------------------===// |
165 | |
166 | void LSPServer::onDocumentDidOpen(const DidOpenTextDocumentParams ¶ms) { |
167 | PublishDiagnosticsParams diagParams(params.textDocument.uri, |
168 | params.textDocument.version); |
169 | server.addOrUpdateDocument(uri: params.textDocument.uri, contents: params.textDocument.text, |
170 | version: params.textDocument.version, |
171 | diagnostics&: diagParams.diagnostics); |
172 | |
173 | // Publish any recorded diagnostics. |
174 | publishDiagnostics(diagParams); |
175 | } |
176 | void LSPServer::onDocumentDidClose(const DidCloseTextDocumentParams ¶ms) { |
177 | std::optional<int64_t> version = |
178 | server.removeDocument(uri: params.textDocument.uri); |
179 | if (!version) |
180 | return; |
181 | |
182 | // Empty out the diagnostics shown for this document. This will clear out |
183 | // anything currently displayed by the client for this document (e.g. in the |
184 | // "Problems" pane of VSCode). |
185 | publishDiagnostics( |
186 | PublishDiagnosticsParams(params.textDocument.uri, *version)); |
187 | } |
188 | void LSPServer::onDocumentDidChange(const DidChangeTextDocumentParams ¶ms) { |
189 | // TODO: We currently only support full document updates, we should refactor |
190 | // to avoid this. |
191 | if (params.contentChanges.size() != 1) |
192 | return; |
193 | PublishDiagnosticsParams diagParams(params.textDocument.uri, |
194 | params.textDocument.version); |
195 | server.addOrUpdateDocument( |
196 | uri: params.textDocument.uri, contents: params.contentChanges.front().text, |
197 | version: params.textDocument.version, diagnostics&: diagParams.diagnostics); |
198 | |
199 | // Publish any recorded diagnostics. |
200 | publishDiagnostics(diagParams); |
201 | } |
202 | |
203 | //===----------------------------------------------------------------------===// |
204 | // Definitions and References |
205 | //===----------------------------------------------------------------------===// |
206 | |
207 | void LSPServer::onGoToDefinition(const TextDocumentPositionParams ¶ms, |
208 | Callback<std::vector<Location>> reply) { |
209 | std::vector<Location> locations; |
210 | server.getLocationsOf(uri: params.textDocument.uri, defPos: params.position, locations); |
211 | reply(std::move(locations)); |
212 | } |
213 | |
214 | void LSPServer::onReference(const ReferenceParams ¶ms, |
215 | Callback<std::vector<Location>> reply) { |
216 | std::vector<Location> locations; |
217 | server.findReferencesOf(uri: params.textDocument.uri, pos: params.position, references&: locations); |
218 | reply(std::move(locations)); |
219 | } |
220 | |
221 | //===----------------------------------------------------------------------===// |
222 | // Hover |
223 | //===----------------------------------------------------------------------===// |
224 | |
225 | void LSPServer::onHover(const TextDocumentPositionParams ¶ms, |
226 | Callback<std::optional<Hover>> reply) { |
227 | reply(server.findHover(uri: params.textDocument.uri, hoverPos: params.position)); |
228 | } |
229 | |
230 | //===----------------------------------------------------------------------===// |
231 | // Document Symbols |
232 | //===----------------------------------------------------------------------===// |
233 | |
234 | void LSPServer::onDocumentSymbol(const DocumentSymbolParams ¶ms, |
235 | Callback<std::vector<DocumentSymbol>> reply) { |
236 | std::vector<DocumentSymbol> symbols; |
237 | server.findDocumentSymbols(uri: params.textDocument.uri, symbols); |
238 | reply(std::move(symbols)); |
239 | } |
240 | |
241 | //===----------------------------------------------------------------------===// |
242 | // Code Completion |
243 | //===----------------------------------------------------------------------===// |
244 | |
245 | void LSPServer::onCompletion(const CompletionParams ¶ms, |
246 | Callback<CompletionList> reply) { |
247 | reply(server.getCodeCompletion(uri: params.textDocument.uri, completePos: params.position)); |
248 | } |
249 | |
250 | //===----------------------------------------------------------------------===// |
251 | // Code Action |
252 | //===----------------------------------------------------------------------===// |
253 | |
254 | void LSPServer::onCodeAction(const CodeActionParams ¶ms, |
255 | Callback<llvm::json::Value> reply) { |
256 | URIForFile uri = params.textDocument.uri; |
257 | |
258 | // Check whether a particular CodeActionKind is included in the response. |
259 | auto isKindAllowed = [only(params.context.only)](StringRef kind) { |
260 | if (only.empty()) |
261 | return true; |
262 | return llvm::any_of(Range: only, P: [&](StringRef base) { |
263 | return kind.consume_front(Prefix: base) && |
264 | (kind.empty() || kind.starts_with(Prefix: "." )); |
265 | }); |
266 | }; |
267 | |
268 | // We provide a code action for fixes on the specified diagnostics. |
269 | std::vector<CodeAction> actions; |
270 | if (isKindAllowed(CodeAction::kQuickFix)) |
271 | server.getCodeActions(uri, pos: params.range.start, context: params.context, actions); |
272 | reply(std::move(actions)); |
273 | } |
274 | |
275 | //===----------------------------------------------------------------------===// |
276 | // Bytecode |
277 | //===----------------------------------------------------------------------===// |
278 | |
279 | void LSPServer::onConvertFromBytecode( |
280 | const MLIRConvertBytecodeParams ¶ms, |
281 | Callback<MLIRConvertBytecodeResult> reply) { |
282 | reply(server.convertFromBytecode(uri: params.uri)); |
283 | } |
284 | |
285 | void LSPServer::onConvertToBytecode(const MLIRConvertBytecodeParams ¶ms, |
286 | Callback<MLIRConvertBytecodeResult> reply) { |
287 | reply(server.convertToBytecode(uri: params.uri)); |
288 | } |
289 | |
290 | //===----------------------------------------------------------------------===// |
291 | // Entry point |
292 | //===----------------------------------------------------------------------===// |
293 | |
294 | LogicalResult lsp::runMlirLSPServer(MLIRServer &server, |
295 | JSONTransport &transport) { |
296 | LSPServer lspServer(server); |
297 | MessageHandler messageHandler(transport); |
298 | |
299 | // Initialization |
300 | messageHandler.method(method: "initialize" , thisPtr: &lspServer, handler: &LSPServer::onInitialize); |
301 | messageHandler.notification(method: "initialized" , thisPtr: &lspServer, |
302 | handler: &LSPServer::onInitialized); |
303 | messageHandler.method(method: "shutdown" , thisPtr: &lspServer, handler: &LSPServer::onShutdown); |
304 | |
305 | // Document Changes |
306 | messageHandler.notification(method: "textDocument/didOpen" , thisPtr: &lspServer, |
307 | handler: &LSPServer::onDocumentDidOpen); |
308 | messageHandler.notification(method: "textDocument/didClose" , thisPtr: &lspServer, |
309 | handler: &LSPServer::onDocumentDidClose); |
310 | messageHandler.notification(method: "textDocument/didChange" , thisPtr: &lspServer, |
311 | handler: &LSPServer::onDocumentDidChange); |
312 | |
313 | // Definitions and References |
314 | messageHandler.method(method: "textDocument/definition" , thisPtr: &lspServer, |
315 | handler: &LSPServer::onGoToDefinition); |
316 | messageHandler.method(method: "textDocument/references" , thisPtr: &lspServer, |
317 | handler: &LSPServer::onReference); |
318 | |
319 | // Hover |
320 | messageHandler.method(method: "textDocument/hover" , thisPtr: &lspServer, handler: &LSPServer::onHover); |
321 | |
322 | // Document Symbols |
323 | messageHandler.method(method: "textDocument/documentSymbol" , thisPtr: &lspServer, |
324 | handler: &LSPServer::onDocumentSymbol); |
325 | |
326 | // Code Completion |
327 | messageHandler.method(method: "textDocument/completion" , thisPtr: &lspServer, |
328 | handler: &LSPServer::onCompletion); |
329 | |
330 | // Code Action |
331 | messageHandler.method(method: "textDocument/codeAction" , thisPtr: &lspServer, |
332 | handler: &LSPServer::onCodeAction); |
333 | |
334 | // Bytecode |
335 | messageHandler.method(method: "mlir/convertFromBytecode" , thisPtr: &lspServer, |
336 | handler: &LSPServer::onConvertFromBytecode); |
337 | messageHandler.method(method: "mlir/convertToBytecode" , thisPtr: &lspServer, |
338 | handler: &LSPServer::onConvertToBytecode); |
339 | |
340 | // Diagnostics |
341 | lspServer.publishDiagnostics = |
342 | messageHandler.outgoingNotification<PublishDiagnosticsParams>( |
343 | method: "textDocument/publishDiagnostics" ); |
344 | |
345 | // Run the main loop of the transport. |
346 | LogicalResult result = success(); |
347 | if (llvm::Error error = transport.run(handler&: messageHandler)) { |
348 | Logger::error(fmt: "Transport error: {0}" , vals&: error); |
349 | llvm::consumeError(Err: std::move(error)); |
350 | result = failure(); |
351 | } else { |
352 | result = success(IsSuccess: lspServer.shutdownRequestReceived); |
353 | } |
354 | return result; |
355 | } |
356 | |