1 | //===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "MLIRServer.h" |
10 | #include "Protocol.h" |
11 | #include "mlir/AsmParser/AsmParser.h" |
12 | #include "mlir/AsmParser/AsmParserState.h" |
13 | #include "mlir/AsmParser/CodeComplete.h" |
14 | #include "mlir/Bytecode/BytecodeWriter.h" |
15 | #include "mlir/IR/Operation.h" |
16 | #include "mlir/Interfaces/FunctionInterfaces.h" |
17 | #include "mlir/Parser/Parser.h" |
18 | #include "mlir/Support/ToolUtilities.h" |
19 | #include "mlir/Tools/lsp-server-support/Logging.h" |
20 | #include "mlir/Tools/lsp-server-support/SourceMgrUtils.h" |
21 | #include "llvm/ADT/StringExtras.h" |
22 | #include "llvm/Support/Base64.h" |
23 | #include "llvm/Support/SourceMgr.h" |
24 | #include <optional> |
25 | |
26 | using namespace mlir; |
27 | |
28 | /// Returns the range of a lexical token given a SMLoc corresponding to the |
29 | /// start of an token location. The range is computed heuristically, and |
30 | /// supports identifier-like tokens, strings, etc. |
31 | static SMRange convertTokenLocToRange(SMLoc loc) { |
32 | return lsp::convertTokenLocToRange(loc, identifierChars: "$-." ); |
33 | } |
34 | |
35 | /// Returns a language server location from the given MLIR file location. |
36 | /// `uriScheme` is the scheme to use when building new uris. |
37 | static std::optional<lsp::Location> getLocationFromLoc(StringRef uriScheme, |
38 | FileLineColLoc loc) { |
39 | llvm::Expected<lsp::URIForFile> sourceURI = |
40 | lsp::URIForFile::fromFile(absoluteFilepath: loc.getFilename(), scheme: uriScheme); |
41 | if (!sourceURI) { |
42 | lsp::Logger::error("Failed to create URI for file `{0}`: {1}" , |
43 | loc.getFilename(), |
44 | llvm::toString(E: sourceURI.takeError())); |
45 | return std::nullopt; |
46 | } |
47 | |
48 | lsp::Position position; |
49 | position.line = loc.getLine() - 1; |
50 | position.character = loc.getColumn() ? loc.getColumn() - 1 : 0; |
51 | return lsp::Location{*sourceURI, lsp::Range(position)}; |
52 | } |
53 | |
54 | /// Returns a language server location from the given MLIR location, or |
55 | /// std::nullopt if one couldn't be created. `uriScheme` is the scheme to use |
56 | /// when building new uris. `uri` is an optional additional filter that, when |
57 | /// present, is used to filter sub locations that do not share the same uri. |
58 | static std::optional<lsp::Location> |
59 | getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc, |
60 | StringRef uriScheme, const lsp::URIForFile *uri = nullptr) { |
61 | std::optional<lsp::Location> location; |
62 | loc->walk(walkFn: [&](Location nestedLoc) { |
63 | FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(nestedLoc); |
64 | if (!fileLoc) |
65 | return WalkResult::advance(); |
66 | |
67 | std::optional<lsp::Location> sourceLoc = |
68 | getLocationFromLoc(uriScheme, fileLoc); |
69 | if (sourceLoc && (!uri || sourceLoc->uri == *uri)) { |
70 | location = *sourceLoc; |
71 | SMLoc loc = sourceMgr.FindLocForLineAndColumn( |
72 | BufferID: sourceMgr.getMainFileID(), LineNo: fileLoc.getLine(), ColNo: fileLoc.getColumn()); |
73 | |
74 | // Use range of potential identifier starting at location, else length 1 |
75 | // range. |
76 | location->range.end.character += 1; |
77 | if (std::optional<SMRange> range = convertTokenLocToRange(loc)) { |
78 | auto lineCol = sourceMgr.getLineAndColumn(Loc: range->End); |
79 | location->range.end.character = |
80 | std::max(fileLoc.getColumn() + 1, lineCol.second - 1); |
81 | } |
82 | return WalkResult::interrupt(); |
83 | } |
84 | return WalkResult::advance(); |
85 | }); |
86 | return location; |
87 | } |
88 | |
89 | /// Collect all of the locations from the given MLIR location that are not |
90 | /// contained within the given URI. |
91 | static void collectLocationsFromLoc(Location loc, |
92 | std::vector<lsp::Location> &locations, |
93 | const lsp::URIForFile &uri) { |
94 | SetVector<Location> visitedLocs; |
95 | loc->walk(walkFn: [&](Location nestedLoc) { |
96 | FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(nestedLoc); |
97 | if (!fileLoc || !visitedLocs.insert(X: nestedLoc)) |
98 | return WalkResult::advance(); |
99 | |
100 | std::optional<lsp::Location> sourceLoc = |
101 | getLocationFromLoc(uri.scheme(), fileLoc); |
102 | if (sourceLoc && sourceLoc->uri != uri) |
103 | locations.push_back(x: *sourceLoc); |
104 | return WalkResult::advance(); |
105 | }); |
106 | } |
107 | |
108 | /// Returns true if the given range contains the given source location. Note |
109 | /// that this has slightly different behavior than SMRange because it is |
110 | /// inclusive of the end location. |
111 | static bool contains(SMRange range, SMLoc loc) { |
112 | return range.Start.getPointer() <= loc.getPointer() && |
113 | loc.getPointer() <= range.End.getPointer(); |
114 | } |
115 | |
116 | /// Returns true if the given location is contained by the definition or one of |
117 | /// the uses of the given SMDefinition. If provided, `overlappedRange` is set to |
118 | /// the range within `def` that the provided `loc` overlapped with. |
119 | static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc, |
120 | SMRange *overlappedRange = nullptr) { |
121 | // Check the main definition. |
122 | if (contains(range: def.loc, loc)) { |
123 | if (overlappedRange) |
124 | *overlappedRange = def.loc; |
125 | return true; |
126 | } |
127 | |
128 | // Check the uses. |
129 | const auto *useIt = llvm::find_if( |
130 | Range: def.uses, P: [&](const SMRange &range) { return contains(range, loc); }); |
131 | if (useIt != def.uses.end()) { |
132 | if (overlappedRange) |
133 | *overlappedRange = *useIt; |
134 | return true; |
135 | } |
136 | return false; |
137 | } |
138 | |
139 | /// Given a location pointing to a result, return the result number it refers |
140 | /// to or std::nullopt if it refers to all of the results. |
141 | static std::optional<unsigned> getResultNumberFromLoc(SMLoc loc) { |
142 | // Skip all of the identifier characters. |
143 | auto isIdentifierChar = [](char c) { |
144 | return isalnum(c) || c == '%' || c == '$' || c == '.' || c == '_' || |
145 | c == '-'; |
146 | }; |
147 | const char *curPtr = loc.getPointer(); |
148 | while (isIdentifierChar(*curPtr)) |
149 | ++curPtr; |
150 | |
151 | // Check to see if this location indexes into the result group, via `#`. If it |
152 | // doesn't, we can't extract a sub result number. |
153 | if (*curPtr != '#') |
154 | return std::nullopt; |
155 | |
156 | // Compute the sub result number from the remaining portion of the string. |
157 | const char *numberStart = ++curPtr; |
158 | while (llvm::isDigit(C: *curPtr)) |
159 | ++curPtr; |
160 | StringRef numberStr(numberStart, curPtr - numberStart); |
161 | unsigned resultNumber = 0; |
162 | return numberStr.consumeInteger(Radix: 10, Result&: resultNumber) ? std::optional<unsigned>() |
163 | : resultNumber; |
164 | } |
165 | |
166 | /// Given a source location range, return the text covered by the given range. |
167 | /// If the range is invalid, returns std::nullopt. |
168 | static std::optional<StringRef> getTextFromRange(SMRange range) { |
169 | if (!range.isValid()) |
170 | return std::nullopt; |
171 | const char *startPtr = range.Start.getPointer(); |
172 | return StringRef(startPtr, range.End.getPointer() - startPtr); |
173 | } |
174 | |
175 | /// Given a block, return its position in its parent region. |
176 | static unsigned getBlockNumber(Block *block) { |
177 | return std::distance(first: block->getParent()->begin(), last: block->getIterator()); |
178 | } |
179 | |
180 | /// Given a block and source location, print the source name of the block to the |
181 | /// given output stream. |
182 | static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc = {}) { |
183 | // Try to extract a name from the source location. |
184 | std::optional<StringRef> text = getTextFromRange(range: loc); |
185 | if (text && text->starts_with(Prefix: "^" )) { |
186 | os << *text; |
187 | return; |
188 | } |
189 | |
190 | // Otherwise, we don't have a name so print the block number. |
191 | os << "<Block #" << getBlockNumber(block) << ">" ; |
192 | } |
193 | static void printDefBlockName(raw_ostream &os, |
194 | const AsmParserState::BlockDefinition &def) { |
195 | printDefBlockName(os, block: def.block, loc: def.definition.loc); |
196 | } |
197 | |
198 | /// Convert the given MLIR diagnostic to the LSP form. |
199 | static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, |
200 | Diagnostic &diag, |
201 | const lsp::URIForFile &uri) { |
202 | lsp::Diagnostic lspDiag; |
203 | lspDiag.source = "mlir" ; |
204 | |
205 | // Note: Right now all of the diagnostics are treated as parser issues, but |
206 | // some are parser and some are verifier. |
207 | lspDiag.category = "Parse Error" ; |
208 | |
209 | // Try to grab a file location for this diagnostic. |
210 | // TODO: For simplicity, we just grab the first one. It may be likely that we |
211 | // will need a more interesting heuristic here.' |
212 | StringRef uriScheme = uri.scheme(); |
213 | std::optional<lsp::Location> lspLocation = |
214 | getLocationFromLoc(sourceMgr, loc: diag.getLocation(), uriScheme, uri: &uri); |
215 | if (lspLocation) |
216 | lspDiag.range = lspLocation->range; |
217 | |
218 | // Convert the severity for the diagnostic. |
219 | switch (diag.getSeverity()) { |
220 | case DiagnosticSeverity::Note: |
221 | llvm_unreachable("expected notes to be handled separately" ); |
222 | case DiagnosticSeverity::Warning: |
223 | lspDiag.severity = lsp::DiagnosticSeverity::Warning; |
224 | break; |
225 | case DiagnosticSeverity::Error: |
226 | lspDiag.severity = lsp::DiagnosticSeverity::Error; |
227 | break; |
228 | case DiagnosticSeverity::Remark: |
229 | lspDiag.severity = lsp::DiagnosticSeverity::Information; |
230 | break; |
231 | } |
232 | lspDiag.message = diag.str(); |
233 | |
234 | // Attach any notes to the main diagnostic as related information. |
235 | std::vector<lsp::DiagnosticRelatedInformation> relatedDiags; |
236 | for (Diagnostic ¬e : diag.getNotes()) { |
237 | lsp::Location noteLoc; |
238 | if (std::optional<lsp::Location> loc = |
239 | getLocationFromLoc(sourceMgr, loc: note.getLocation(), uriScheme)) |
240 | noteLoc = *loc; |
241 | else |
242 | noteLoc.uri = uri; |
243 | relatedDiags.emplace_back(args&: noteLoc, args: note.str()); |
244 | } |
245 | if (!relatedDiags.empty()) |
246 | lspDiag.relatedInformation = std::move(relatedDiags); |
247 | |
248 | return lspDiag; |
249 | } |
250 | |
251 | //===----------------------------------------------------------------------===// |
252 | // MLIRDocument |
253 | //===----------------------------------------------------------------------===// |
254 | |
255 | namespace { |
256 | /// This class represents all of the information pertaining to a specific MLIR |
257 | /// document. |
258 | struct MLIRDocument { |
259 | MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri, |
260 | StringRef contents, std::vector<lsp::Diagnostic> &diagnostics); |
261 | MLIRDocument(const MLIRDocument &) = delete; |
262 | MLIRDocument &operator=(const MLIRDocument &) = delete; |
263 | |
264 | //===--------------------------------------------------------------------===// |
265 | // Definitions and References |
266 | //===--------------------------------------------------------------------===// |
267 | |
268 | void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos, |
269 | std::vector<lsp::Location> &locations); |
270 | void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos, |
271 | std::vector<lsp::Location> &references); |
272 | |
273 | //===--------------------------------------------------------------------===// |
274 | // Hover |
275 | //===--------------------------------------------------------------------===// |
276 | |
277 | std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri, |
278 | const lsp::Position &hoverPos); |
279 | std::optional<lsp::Hover> |
280 | buildHoverForOperation(SMRange hoverRange, |
281 | const AsmParserState::OperationDefinition &op); |
282 | lsp::Hover buildHoverForOperationResult(SMRange hoverRange, Operation *op, |
283 | unsigned resultStart, |
284 | unsigned resultEnd, SMLoc posLoc); |
285 | lsp::Hover buildHoverForBlock(SMRange hoverRange, |
286 | const AsmParserState::BlockDefinition &block); |
287 | lsp::Hover |
288 | buildHoverForBlockArgument(SMRange hoverRange, BlockArgument arg, |
289 | const AsmParserState::BlockDefinition &block); |
290 | |
291 | lsp::Hover buildHoverForAttributeAlias( |
292 | SMRange hoverRange, const AsmParserState::AttributeAliasDefinition &attr); |
293 | lsp::Hover |
294 | buildHoverForTypeAlias(SMRange hoverRange, |
295 | const AsmParserState::TypeAliasDefinition &type); |
296 | |
297 | //===--------------------------------------------------------------------===// |
298 | // Document Symbols |
299 | //===--------------------------------------------------------------------===// |
300 | |
301 | void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols); |
302 | void findDocumentSymbols(Operation *op, |
303 | std::vector<lsp::DocumentSymbol> &symbols); |
304 | |
305 | //===--------------------------------------------------------------------===// |
306 | // Code Completion |
307 | //===--------------------------------------------------------------------===// |
308 | |
309 | lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri, |
310 | const lsp::Position &completePos, |
311 | const DialectRegistry ®istry); |
312 | |
313 | //===--------------------------------------------------------------------===// |
314 | // Code Action |
315 | //===--------------------------------------------------------------------===// |
316 | |
317 | void getCodeActionForDiagnostic(const lsp::URIForFile &uri, |
318 | lsp::Position &pos, StringRef severity, |
319 | StringRef message, |
320 | std::vector<lsp::TextEdit> &edits); |
321 | |
322 | //===--------------------------------------------------------------------===// |
323 | // Bytecode |
324 | //===--------------------------------------------------------------------===// |
325 | |
326 | llvm::Expected<lsp::MLIRConvertBytecodeResult> convertToBytecode(); |
327 | |
328 | //===--------------------------------------------------------------------===// |
329 | // Fields |
330 | //===--------------------------------------------------------------------===// |
331 | |
332 | /// The high level parser state used to find definitions and references within |
333 | /// the source file. |
334 | AsmParserState asmState; |
335 | |
336 | /// The container for the IR parsed from the input file. |
337 | Block parsedIR; |
338 | |
339 | /// A collection of external resources, which we want to propagate up to the |
340 | /// user. |
341 | FallbackAsmResourceMap fallbackResourceMap; |
342 | |
343 | /// The source manager containing the contents of the input file. |
344 | llvm::SourceMgr sourceMgr; |
345 | }; |
346 | } // namespace |
347 | |
348 | MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri, |
349 | StringRef contents, |
350 | std::vector<lsp::Diagnostic> &diagnostics) { |
351 | ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) { |
352 | diagnostics.push_back(x: getLspDiagnoticFromDiag(sourceMgr, diag, uri)); |
353 | }); |
354 | |
355 | // Try to parsed the given IR string. |
356 | auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(InputData: contents, BufferName: uri.file()); |
357 | if (!memBuffer) { |
358 | lsp::Logger::error(fmt: "Failed to create memory buffer for file" , vals: uri.file()); |
359 | return; |
360 | } |
361 | |
362 | ParserConfig config(&context, /*verifyAfterParse=*/true, |
363 | &fallbackResourceMap); |
364 | sourceMgr.AddNewSourceBuffer(F: std::move(memBuffer), IncludeLoc: SMLoc()); |
365 | if (failed(result: parseAsmSourceFile(sourceMgr, block: &parsedIR, config, asmState: &asmState))) { |
366 | // If parsing failed, clear out any of the current state. |
367 | parsedIR.clear(); |
368 | asmState = AsmParserState(); |
369 | fallbackResourceMap = FallbackAsmResourceMap(); |
370 | return; |
371 | } |
372 | } |
373 | |
374 | //===----------------------------------------------------------------------===// |
375 | // MLIRDocument: Definitions and References |
376 | //===----------------------------------------------------------------------===// |
377 | |
378 | void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri, |
379 | const lsp::Position &defPos, |
380 | std::vector<lsp::Location> &locations) { |
381 | SMLoc posLoc = defPos.getAsSMLoc(mgr&: sourceMgr); |
382 | |
383 | // Functor used to check if an SM definition contains the position. |
384 | auto containsPosition = [&](const AsmParserState::SMDefinition &def) { |
385 | if (!isDefOrUse(def, loc: posLoc)) |
386 | return false; |
387 | locations.emplace_back(args: uri, args&: sourceMgr, args: def.loc); |
388 | return true; |
389 | }; |
390 | |
391 | // Check all definitions related to operations. |
392 | for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) { |
393 | if (contains(range: op.loc, loc: posLoc)) |
394 | return collectLocationsFromLoc(loc: op.op->getLoc(), locations, uri); |
395 | for (const auto &result : op.resultGroups) |
396 | if (containsPosition(result.definition)) |
397 | return collectLocationsFromLoc(loc: op.op->getLoc(), locations, uri); |
398 | for (const auto &symUse : op.symbolUses) { |
399 | if (contains(range: symUse, loc: posLoc)) { |
400 | locations.emplace_back(args: uri, args&: sourceMgr, args: op.loc); |
401 | return collectLocationsFromLoc(loc: op.op->getLoc(), locations, uri); |
402 | } |
403 | } |
404 | } |
405 | |
406 | // Check all definitions related to blocks. |
407 | for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) { |
408 | if (containsPosition(block.definition)) |
409 | return; |
410 | for (const AsmParserState::SMDefinition &arg : block.arguments) |
411 | if (containsPosition(arg)) |
412 | return; |
413 | } |
414 | |
415 | // Check all alias definitions. |
416 | for (const AsmParserState::AttributeAliasDefinition &attr : |
417 | asmState.getAttributeAliasDefs()) { |
418 | if (containsPosition(attr.definition)) |
419 | return; |
420 | } |
421 | for (const AsmParserState::TypeAliasDefinition &type : |
422 | asmState.getTypeAliasDefs()) { |
423 | if (containsPosition(type.definition)) |
424 | return; |
425 | } |
426 | } |
427 | |
428 | void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri, |
429 | const lsp::Position &pos, |
430 | std::vector<lsp::Location> &references) { |
431 | // Functor used to append all of the definitions/uses of the given SM |
432 | // definition to the reference list. |
433 | auto appendSMDef = [&](const AsmParserState::SMDefinition &def) { |
434 | references.emplace_back(args: uri, args&: sourceMgr, args: def.loc); |
435 | for (const SMRange &use : def.uses) |
436 | references.emplace_back(args: uri, args&: sourceMgr, args: use); |
437 | }; |
438 | |
439 | SMLoc posLoc = pos.getAsSMLoc(mgr&: sourceMgr); |
440 | |
441 | // Check all definitions related to operations. |
442 | for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) { |
443 | if (contains(range: op.loc, loc: posLoc)) { |
444 | for (const auto &result : op.resultGroups) |
445 | appendSMDef(result.definition); |
446 | for (const auto &symUse : op.symbolUses) |
447 | if (contains(range: symUse, loc: posLoc)) |
448 | references.emplace_back(args: uri, args&: sourceMgr, args: symUse); |
449 | return; |
450 | } |
451 | for (const auto &result : op.resultGroups) |
452 | if (isDefOrUse(def: result.definition, loc: posLoc)) |
453 | return appendSMDef(result.definition); |
454 | for (const auto &symUse : op.symbolUses) { |
455 | if (!contains(range: symUse, loc: posLoc)) |
456 | continue; |
457 | for (const auto &symUse : op.symbolUses) |
458 | references.emplace_back(args: uri, args&: sourceMgr, args: symUse); |
459 | return; |
460 | } |
461 | } |
462 | |
463 | // Check all definitions related to blocks. |
464 | for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) { |
465 | if (isDefOrUse(def: block.definition, loc: posLoc)) |
466 | return appendSMDef(block.definition); |
467 | |
468 | for (const AsmParserState::SMDefinition &arg : block.arguments) |
469 | if (isDefOrUse(def: arg, loc: posLoc)) |
470 | return appendSMDef(arg); |
471 | } |
472 | |
473 | // Check all alias definitions. |
474 | for (const AsmParserState::AttributeAliasDefinition &attr : |
475 | asmState.getAttributeAliasDefs()) { |
476 | if (isDefOrUse(def: attr.definition, loc: posLoc)) |
477 | return appendSMDef(attr.definition); |
478 | } |
479 | for (const AsmParserState::TypeAliasDefinition &type : |
480 | asmState.getTypeAliasDefs()) { |
481 | if (isDefOrUse(def: type.definition, loc: posLoc)) |
482 | return appendSMDef(type.definition); |
483 | } |
484 | } |
485 | |
486 | //===----------------------------------------------------------------------===// |
487 | // MLIRDocument: Hover |
488 | //===----------------------------------------------------------------------===// |
489 | |
490 | std::optional<lsp::Hover> |
491 | MLIRDocument::findHover(const lsp::URIForFile &uri, |
492 | const lsp::Position &hoverPos) { |
493 | SMLoc posLoc = hoverPos.getAsSMLoc(mgr&: sourceMgr); |
494 | SMRange hoverRange; |
495 | |
496 | // Check for Hovers on operations and results. |
497 | for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) { |
498 | // Check if the position points at this operation. |
499 | if (contains(range: op.loc, loc: posLoc)) |
500 | return buildHoverForOperation(hoverRange: op.loc, op); |
501 | |
502 | // Check if the position points at the symbol name. |
503 | for (auto &use : op.symbolUses) |
504 | if (contains(range: use, loc: posLoc)) |
505 | return buildHoverForOperation(hoverRange: use, op); |
506 | |
507 | // Check if the position points at a result group. |
508 | for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) { |
509 | const auto &result = op.resultGroups[i]; |
510 | if (!isDefOrUse(def: result.definition, loc: posLoc, overlappedRange: &hoverRange)) |
511 | continue; |
512 | |
513 | // Get the range of results covered by the over position. |
514 | unsigned resultStart = result.startIndex; |
515 | unsigned resultEnd = (i == e - 1) ? op.op->getNumResults() |
516 | : op.resultGroups[i + 1].startIndex; |
517 | return buildHoverForOperationResult(hoverRange, op: op.op, resultStart, |
518 | resultEnd, posLoc); |
519 | } |
520 | } |
521 | |
522 | // Check to see if the hover is over a block argument. |
523 | for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) { |
524 | if (isDefOrUse(def: block.definition, loc: posLoc, overlappedRange: &hoverRange)) |
525 | return buildHoverForBlock(hoverRange, block); |
526 | |
527 | for (const auto &arg : llvm::enumerate(First: block.arguments)) { |
528 | if (!isDefOrUse(def: arg.value(), loc: posLoc, overlappedRange: &hoverRange)) |
529 | continue; |
530 | |
531 | return buildHoverForBlockArgument( |
532 | hoverRange, arg: block.block->getArgument(i: arg.index()), block); |
533 | } |
534 | } |
535 | |
536 | // Check to see if the hover is over an alias. |
537 | for (const AsmParserState::AttributeAliasDefinition &attr : |
538 | asmState.getAttributeAliasDefs()) { |
539 | if (isDefOrUse(def: attr.definition, loc: posLoc, overlappedRange: &hoverRange)) |
540 | return buildHoverForAttributeAlias(hoverRange, attr); |
541 | } |
542 | for (const AsmParserState::TypeAliasDefinition &type : |
543 | asmState.getTypeAliasDefs()) { |
544 | if (isDefOrUse(def: type.definition, loc: posLoc, overlappedRange: &hoverRange)) |
545 | return buildHoverForTypeAlias(hoverRange, type); |
546 | } |
547 | |
548 | return std::nullopt; |
549 | } |
550 | |
551 | std::optional<lsp::Hover> MLIRDocument::buildHoverForOperation( |
552 | SMRange hoverRange, const AsmParserState::OperationDefinition &op) { |
553 | lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); |
554 | llvm::raw_string_ostream os(hover.contents.value); |
555 | |
556 | // Add the operation name to the hover. |
557 | os << "\"" << op.op->getName() << "\"" ; |
558 | if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.op)) |
559 | os << " : " << symbol.getVisibility() << " @" << symbol.getName() << "" ; |
560 | os << "\n\n" ; |
561 | |
562 | os << "Generic Form:\n\n```mlir\n" ; |
563 | |
564 | op.op->print(os, flags: OpPrintingFlags() |
565 | .printGenericOpForm() |
566 | .elideLargeElementsAttrs() |
567 | .skipRegions()); |
568 | os << "\n```\n" ; |
569 | |
570 | return hover; |
571 | } |
572 | |
573 | lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange, |
574 | Operation *op, |
575 | unsigned resultStart, |
576 | unsigned resultEnd, |
577 | SMLoc posLoc) { |
578 | lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); |
579 | llvm::raw_string_ostream os(hover.contents.value); |
580 | |
581 | // Add the parent operation name to the hover. |
582 | os << "Operation: \"" << op->getName() << "\"\n\n" ; |
583 | |
584 | // Check to see if the location points to a specific result within the |
585 | // group. |
586 | if (std::optional<unsigned> resultNumber = getResultNumberFromLoc(loc: posLoc)) { |
587 | if ((resultStart + *resultNumber) < resultEnd) { |
588 | resultStart += *resultNumber; |
589 | resultEnd = resultStart + 1; |
590 | } |
591 | } |
592 | |
593 | // Add the range of results and their types to the hover info. |
594 | if ((resultStart + 1) == resultEnd) { |
595 | os << "Result #" << resultStart << "\n\n" |
596 | << "Type: `" << op->getResult(idx: resultStart).getType() << "`\n\n" ; |
597 | } else { |
598 | os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n" |
599 | << "Types: " ; |
600 | llvm::interleaveComma( |
601 | c: op->getResults().slice(n: resultStart, m: resultEnd), os, |
602 | each_fn: [&](Value result) { os << "`" << result.getType() << "`" ; }); |
603 | } |
604 | |
605 | return hover; |
606 | } |
607 | |
608 | lsp::Hover |
609 | MLIRDocument::buildHoverForBlock(SMRange hoverRange, |
610 | const AsmParserState::BlockDefinition &block) { |
611 | lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); |
612 | llvm::raw_string_ostream os(hover.contents.value); |
613 | |
614 | // Print the given block to the hover output stream. |
615 | auto printBlockToHover = [&](Block *newBlock) { |
616 | if (const auto *def = asmState.getBlockDef(block: newBlock)) |
617 | printDefBlockName(os, def: *def); |
618 | else |
619 | printDefBlockName(os, block: newBlock); |
620 | }; |
621 | |
622 | // Display the parent operation, block number, predecessors, and successors. |
623 | os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n" |
624 | << "Block #" << getBlockNumber(block: block.block) << "\n\n" ; |
625 | if (!block.block->hasNoPredecessors()) { |
626 | os << "Predecessors: " ; |
627 | llvm::interleaveComma(c: block.block->getPredecessors(), os, |
628 | each_fn: printBlockToHover); |
629 | os << "\n\n" ; |
630 | } |
631 | if (!block.block->hasNoSuccessors()) { |
632 | os << "Successors: " ; |
633 | llvm::interleaveComma(c: block.block->getSuccessors(), os, each_fn: printBlockToHover); |
634 | os << "\n\n" ; |
635 | } |
636 | |
637 | return hover; |
638 | } |
639 | |
640 | lsp::Hover MLIRDocument::buildHoverForBlockArgument( |
641 | SMRange hoverRange, BlockArgument arg, |
642 | const AsmParserState::BlockDefinition &block) { |
643 | lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); |
644 | llvm::raw_string_ostream os(hover.contents.value); |
645 | |
646 | // Display the parent operation, block, the argument number, and the type. |
647 | os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n" |
648 | << "Block: " ; |
649 | printDefBlockName(os, def: block); |
650 | os << "\n\nArgument #" << arg.getArgNumber() << "\n\n" |
651 | << "Type: `" << arg.getType() << "`\n\n" ; |
652 | |
653 | return hover; |
654 | } |
655 | |
656 | lsp::Hover MLIRDocument::buildHoverForAttributeAlias( |
657 | SMRange hoverRange, const AsmParserState::AttributeAliasDefinition &attr) { |
658 | lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); |
659 | llvm::raw_string_ostream os(hover.contents.value); |
660 | |
661 | os << "Attribute Alias: \"" << attr.name << "\n\n" ; |
662 | os << "Value: ```mlir\n" << attr.value << "\n```\n\n" ; |
663 | |
664 | return hover; |
665 | } |
666 | |
667 | lsp::Hover MLIRDocument::buildHoverForTypeAlias( |
668 | SMRange hoverRange, const AsmParserState::TypeAliasDefinition &type) { |
669 | lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); |
670 | llvm::raw_string_ostream os(hover.contents.value); |
671 | |
672 | os << "Type Alias: \"" << type.name << "\n\n" ; |
673 | os << "Value: ```mlir\n" << type.value << "\n```\n\n" ; |
674 | |
675 | return hover; |
676 | } |
677 | |
678 | //===----------------------------------------------------------------------===// |
679 | // MLIRDocument: Document Symbols |
680 | //===----------------------------------------------------------------------===// |
681 | |
682 | void MLIRDocument::findDocumentSymbols( |
683 | std::vector<lsp::DocumentSymbol> &symbols) { |
684 | for (Operation &op : parsedIR) |
685 | findDocumentSymbols(op: &op, symbols); |
686 | } |
687 | |
688 | void MLIRDocument::findDocumentSymbols( |
689 | Operation *op, std::vector<lsp::DocumentSymbol> &symbols) { |
690 | std::vector<lsp::DocumentSymbol> *childSymbols = &symbols; |
691 | |
692 | // Check for the source information of this operation. |
693 | if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) { |
694 | // If this operation defines a symbol, record it. |
695 | if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) { |
696 | symbols.emplace_back(symbol.getName(), |
697 | isa<FunctionOpInterface>(Val: op) |
698 | ? lsp::SymbolKind::Function |
699 | : lsp::SymbolKind::Class, |
700 | lsp::Range(sourceMgr, def->scopeLoc), |
701 | lsp::Range(sourceMgr, def->loc)); |
702 | childSymbols = &symbols.back().children; |
703 | |
704 | } else if (op->hasTrait<OpTrait::SymbolTable>()) { |
705 | // Otherwise, if this is a symbol table push an anonymous document symbol. |
706 | symbols.emplace_back(args: "<" + op->getName().getStringRef() + ">" , |
707 | args: lsp::SymbolKind::Namespace, |
708 | args: lsp::Range(sourceMgr, def->scopeLoc), |
709 | args: lsp::Range(sourceMgr, def->loc)); |
710 | childSymbols = &symbols.back().children; |
711 | } |
712 | } |
713 | |
714 | // Recurse into the regions of this operation. |
715 | if (!op->getNumRegions()) |
716 | return; |
717 | for (Region ®ion : op->getRegions()) |
718 | for (Operation &childOp : region.getOps()) |
719 | findDocumentSymbols(op: &childOp, symbols&: *childSymbols); |
720 | } |
721 | |
722 | //===----------------------------------------------------------------------===// |
723 | // MLIRDocument: Code Completion |
724 | //===----------------------------------------------------------------------===// |
725 | |
726 | namespace { |
727 | class LSPCodeCompleteContext : public AsmParserCodeCompleteContext { |
728 | public: |
729 | LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList, |
730 | MLIRContext *ctx) |
731 | : AsmParserCodeCompleteContext(completeLoc), |
732 | completionList(completionList), ctx(ctx) {} |
733 | |
734 | /// Signal code completion for a dialect name, with an optional prefix. |
735 | void completeDialectName(StringRef prefix) final { |
736 | for (StringRef dialect : ctx->getAvailableDialects()) { |
737 | lsp::CompletionItem item(prefix + dialect, |
738 | lsp::CompletionItemKind::Module, |
739 | /*sortText=*/"3" ); |
740 | item.detail = "dialect" ; |
741 | completionList.items.emplace_back(args&: item); |
742 | } |
743 | } |
744 | using AsmParserCodeCompleteContext::completeDialectName; |
745 | |
746 | /// Signal code completion for an operation name within the given dialect. |
747 | void completeOperationName(StringRef dialectName) final { |
748 | Dialect *dialect = ctx->getOrLoadDialect(name: dialectName); |
749 | if (!dialect) |
750 | return; |
751 | |
752 | for (const auto &op : ctx->getRegisteredOperations()) { |
753 | if (&op.getDialect() != dialect) |
754 | continue; |
755 | |
756 | lsp::CompletionItem item( |
757 | op.getStringRef().drop_front(N: dialectName.size() + 1), |
758 | lsp::CompletionItemKind::Field, |
759 | /*sortText=*/"1" ); |
760 | item.detail = "operation" ; |
761 | completionList.items.emplace_back(args&: item); |
762 | } |
763 | } |
764 | |
765 | /// Append the given SSA value as a code completion result for SSA value |
766 | /// completions. |
767 | void appendSSAValueCompletion(StringRef name, std::string typeData) final { |
768 | // Check if we need to insert the `%` or not. |
769 | bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '%'; |
770 | |
771 | lsp::CompletionItem item(name, lsp::CompletionItemKind::Variable); |
772 | if (stripPrefix) |
773 | item.insertText = name.drop_front(N: 1).str(); |
774 | item.detail = std::move(typeData); |
775 | completionList.items.emplace_back(args&: item); |
776 | } |
777 | |
778 | /// Append the given block as a code completion result for block name |
779 | /// completions. |
780 | void appendBlockCompletion(StringRef name) final { |
781 | // Check if we need to insert the `^` or not. |
782 | bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '^'; |
783 | |
784 | lsp::CompletionItem item(name, lsp::CompletionItemKind::Field); |
785 | if (stripPrefix) |
786 | item.insertText = name.drop_front(N: 1).str(); |
787 | completionList.items.emplace_back(args&: item); |
788 | } |
789 | |
790 | /// Signal a completion for the given expected token. |
791 | void completeExpectedTokens(ArrayRef<StringRef> tokens, bool optional) final { |
792 | for (StringRef token : tokens) { |
793 | lsp::CompletionItem item(token, lsp::CompletionItemKind::Keyword, |
794 | /*sortText=*/"0" ); |
795 | item.detail = optional ? "optional" : "" ; |
796 | completionList.items.emplace_back(args&: item); |
797 | } |
798 | } |
799 | |
800 | /// Signal a completion for an attribute. |
801 | void completeAttribute(const llvm::StringMap<Attribute> &aliases) override { |
802 | appendSimpleCompletions(completions: {"affine_set" , "affine_map" , "dense" , |
803 | "dense_resource" , "false" , "loc" , "sparse" , "true" , |
804 | "unit" }, |
805 | kind: lsp::CompletionItemKind::Field, |
806 | /*sortText=*/"1" ); |
807 | |
808 | completeDialectName(prefix: "#" ); |
809 | completeAliases(aliases, prefix: "#" ); |
810 | } |
811 | void completeDialectAttributeOrAlias( |
812 | const llvm::StringMap<Attribute> &aliases) override { |
813 | completeDialectName(); |
814 | completeAliases(aliases); |
815 | } |
816 | |
817 | /// Signal a completion for a type. |
818 | void completeType(const llvm::StringMap<Type> &aliases) override { |
819 | // Handle the various builtin types. |
820 | appendSimpleCompletions(completions: {"memref" , "tensor" , "complex" , "tuple" , "vector" , |
821 | "bf16" , "f16" , "f32" , "f64" , "f80" , "f128" , |
822 | "index" , "none" }, |
823 | kind: lsp::CompletionItemKind::Field, |
824 | /*sortText=*/"1" ); |
825 | |
826 | // Handle the builtin integer types. |
827 | for (StringRef type : {"i" , "si" , "ui" }) { |
828 | lsp::CompletionItem item(type + "<N>" , lsp::CompletionItemKind::Field, |
829 | /*sortText=*/"1" ); |
830 | item.insertText = type.str(); |
831 | completionList.items.emplace_back(args&: item); |
832 | } |
833 | |
834 | // Insert completions for dialect types and aliases. |
835 | completeDialectName(prefix: "!" ); |
836 | completeAliases(aliases, prefix: "!" ); |
837 | } |
838 | void |
839 | completeDialectTypeOrAlias(const llvm::StringMap<Type> &aliases) override { |
840 | completeDialectName(); |
841 | completeAliases(aliases); |
842 | } |
843 | |
844 | /// Add completion results for the given set of aliases. |
845 | template <typename T> |
846 | void completeAliases(const llvm::StringMap<T> &aliases, |
847 | StringRef prefix = "" ) { |
848 | for (const auto &alias : aliases) { |
849 | lsp::CompletionItem item(prefix + alias.getKey(), |
850 | lsp::CompletionItemKind::Field, |
851 | /*sortText=*/"2" ); |
852 | llvm::raw_string_ostream(item.detail) << "alias: " << alias.getValue(); |
853 | completionList.items.emplace_back(args&: item); |
854 | } |
855 | } |
856 | |
857 | /// Add a set of simple completions that all have the same kind. |
858 | void appendSimpleCompletions(ArrayRef<StringRef> completions, |
859 | lsp::CompletionItemKind kind, |
860 | StringRef sortText = "" ) { |
861 | for (StringRef completion : completions) |
862 | completionList.items.emplace_back(args&: completion, args&: kind, args&: sortText); |
863 | } |
864 | |
865 | private: |
866 | lsp::CompletionList &completionList; |
867 | MLIRContext *ctx; |
868 | }; |
869 | } // namespace |
870 | |
871 | lsp::CompletionList |
872 | MLIRDocument::getCodeCompletion(const lsp::URIForFile &uri, |
873 | const lsp::Position &completePos, |
874 | const DialectRegistry ®istry) { |
875 | SMLoc posLoc = completePos.getAsSMLoc(mgr&: sourceMgr); |
876 | if (!posLoc.isValid()) |
877 | return lsp::CompletionList(); |
878 | |
879 | // To perform code completion, we run another parse of the module with the |
880 | // code completion context provided. |
881 | MLIRContext tmpContext(registry, MLIRContext::Threading::DISABLED); |
882 | tmpContext.allowUnregisteredDialects(); |
883 | lsp::CompletionList completionList; |
884 | LSPCodeCompleteContext lspCompleteContext(posLoc, completionList, |
885 | &tmpContext); |
886 | |
887 | Block tmpIR; |
888 | AsmParserState tmpState; |
889 | (void)parseAsmSourceFile(sourceMgr, block: &tmpIR, config: &tmpContext, asmState: &tmpState, |
890 | codeCompleteContext: &lspCompleteContext); |
891 | return completionList; |
892 | } |
893 | |
894 | //===----------------------------------------------------------------------===// |
895 | // MLIRDocument: Code Action |
896 | //===----------------------------------------------------------------------===// |
897 | |
898 | void MLIRDocument::getCodeActionForDiagnostic( |
899 | const lsp::URIForFile &uri, lsp::Position &pos, StringRef severity, |
900 | StringRef message, std::vector<lsp::TextEdit> &edits) { |
901 | // Ignore diagnostics that print the current operation. These are always |
902 | // enabled for the language server, but not generally during normal |
903 | // parsing/verification. |
904 | if (message.starts_with(Prefix: "see current operation: " )) |
905 | return; |
906 | |
907 | // Get the start of the line containing the diagnostic. |
908 | const auto &buffer = sourceMgr.getBufferInfo(i: sourceMgr.getMainFileID()); |
909 | const char *lineStart = buffer.getPointerForLineNumber(LineNo: pos.line + 1); |
910 | if (!lineStart) |
911 | return; |
912 | StringRef line(lineStart, pos.character); |
913 | |
914 | // Add a text edit for adding an expected-* diagnostic check for this |
915 | // diagnostic. |
916 | lsp::TextEdit edit; |
917 | edit.range = lsp::Range(lsp::Position(pos.line, 0)); |
918 | |
919 | // Use the indent of the current line for the expected-* diagnostic. |
920 | size_t indent = line.find_first_not_of(Chars: " " ); |
921 | if (indent == StringRef::npos) |
922 | indent = line.size(); |
923 | |
924 | edit.newText.append(n: indent, c: ' '); |
925 | llvm::raw_string_ostream(edit.newText) |
926 | << "// expected-" << severity << " @below {{" << message << "}}\n" ; |
927 | edits.emplace_back(args: std::move(edit)); |
928 | } |
929 | |
930 | //===----------------------------------------------------------------------===// |
931 | // MLIRDocument: Bytecode |
932 | //===----------------------------------------------------------------------===// |
933 | |
934 | llvm::Expected<lsp::MLIRConvertBytecodeResult> |
935 | MLIRDocument::convertToBytecode() { |
936 | // TODO: We currently require a single top-level operation, but this could |
937 | // conceptually be relaxed. |
938 | if (!llvm::hasSingleElement(C&: parsedIR)) { |
939 | if (parsedIR.empty()) { |
940 | return llvm::make_error<lsp::LSPError>( |
941 | Args: "expected a single and valid top-level operation, please ensure " |
942 | "there are no errors" , |
943 | Args: lsp::ErrorCode::RequestFailed); |
944 | } |
945 | return llvm::make_error<lsp::LSPError>( |
946 | Args: "expected a single top-level operation" , Args: lsp::ErrorCode::RequestFailed); |
947 | } |
948 | |
949 | lsp::MLIRConvertBytecodeResult result; |
950 | { |
951 | BytecodeWriterConfig writerConfig(fallbackResourceMap); |
952 | |
953 | std::string rawBytecodeBuffer; |
954 | llvm::raw_string_ostream os(rawBytecodeBuffer); |
955 | // No desired bytecode version set, so no need to check for error. |
956 | (void)writeBytecodeToFile(op: &parsedIR.front(), os, config: writerConfig); |
957 | result.output = llvm::encodeBase64(Bytes: rawBytecodeBuffer); |
958 | } |
959 | return result; |
960 | } |
961 | |
962 | //===----------------------------------------------------------------------===// |
963 | // MLIRTextFileChunk |
964 | //===----------------------------------------------------------------------===// |
965 | |
966 | namespace { |
967 | /// This class represents a single chunk of an MLIR text file. |
968 | struct MLIRTextFileChunk { |
969 | MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset, |
970 | const lsp::URIForFile &uri, StringRef contents, |
971 | std::vector<lsp::Diagnostic> &diagnostics) |
972 | : lineOffset(lineOffset), document(context, uri, contents, diagnostics) {} |
973 | |
974 | /// Adjust the line number of the given range to anchor at the beginning of |
975 | /// the file, instead of the beginning of this chunk. |
976 | void adjustLocForChunkOffset(lsp::Range &range) { |
977 | adjustLocForChunkOffset(pos&: range.start); |
978 | adjustLocForChunkOffset(pos&: range.end); |
979 | } |
980 | /// Adjust the line number of the given position to anchor at the beginning of |
981 | /// the file, instead of the beginning of this chunk. |
982 | void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; } |
983 | |
984 | /// The line offset of this chunk from the beginning of the file. |
985 | uint64_t lineOffset; |
986 | /// The document referred to by this chunk. |
987 | MLIRDocument document; |
988 | }; |
989 | } // namespace |
990 | |
991 | //===----------------------------------------------------------------------===// |
992 | // MLIRTextFile |
993 | //===----------------------------------------------------------------------===// |
994 | |
namespace {
/// This class represents a text file containing one or more MLIR documents.
///
/// The file is cut at every split marker into chunks, each parsed as a
/// separate MLIRDocument against one shared MLIRContext. The public queries
/// accept and return file-relative positions. Each query is routed to the
/// chunk that contains the position, and the results are rebased onto the
/// whole file.
class MLIRTextFile {
public:
  /// Parse `fileContents` into chunks. Diagnostics from every chunk are
  /// appended to `diagnostics` with file-relative line numbers.
  MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
               int64_t version, DialectRegistry &registry,
               std::vector<lsp::Diagnostic> &diagnostics);

  /// Return the current version of this text file.
  int64_t getVersion() const { return version; }

  //===--------------------------------------------------------------------===//
  // LSP Queries
  //===--------------------------------------------------------------------===//
  // Positions passed in and ranges handed back are relative to the start of
  // the whole file, not to an individual chunk.

  void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
                      std::vector<lsp::Location> &locations);
  void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
                        std::vector<lsp::Location> &references);
  std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
                                      lsp::Position hoverPos);
  /// With more than one chunk, each chunk's symbols are nested under a
  /// synthetic `<file-split-N>` namespace symbol.
  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
  lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
                                        lsp::Position completePos);
  /// Offers "Add expected-* diagnostic checks" quick fixes for the `mlir`
  /// error/warning diagnostics in `context`.
  void getCodeActions(const lsp::URIForFile &uri, const lsp::Range &pos,
                      const lsp::CodeActionContext &context,
                      std::vector<lsp::CodeAction> &actions);
  /// Fails unless the file consists of exactly one chunk.
  llvm::Expected<lsp::MLIRConvertBytecodeResult> convertToBytecode();

private:
  /// Find the MLIR document that contains the given position, and update the
  /// position to be anchored at the start of the found chunk instead of the
  /// beginning of the file.
  MLIRTextFileChunk &getChunkFor(lsp::Position &pos);

  /// The context used to hold the state contained by the parsed document.
  /// It is shared by all chunks. It is declared before `chunks`, so the
  /// chunks are destroyed first.
  MLIRContext context;

  /// The full string contents of the file. The chunks are constructed from
  /// StringRefs into this string.
  /// NOTE(review): confirm whether MLIRDocument copies its text or relies on
  /// this buffer staying alive.
  std::string contents;

  /// The version of this file.
  int64_t version;

  /// The number of lines in the file. This is the count of '\n' characters,
  /// so a final line without a trailing newline is not counted.
  int64_t totalNumLines = 0;

  /// The chunks of this file. The order of these chunks is the order in which
  /// they appear in the text file. The constructor always creates at least
  /// one, and the first has a line offset of 0.
  std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks;
};
} // namespace
1047 | |
1048 | MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents, |
1049 | int64_t version, DialectRegistry ®istry, |
1050 | std::vector<lsp::Diagnostic> &diagnostics) |
1051 | : context(registry, MLIRContext::Threading::DISABLED), |
1052 | contents(fileContents.str()), version(version) { |
1053 | context.allowUnregisteredDialects(); |
1054 | |
1055 | // Split the file into separate MLIR documents. |
1056 | SmallVector<StringRef, 8> subContents; |
1057 | StringRef(contents).split(A&: subContents, Separator: kDefaultSplitMarker); |
1058 | chunks.emplace_back(args: std::make_unique<MLIRTextFileChunk>( |
1059 | args&: context, /*lineOffset=*/args: 0, args: uri, args&: subContents.front(), args&: diagnostics)); |
1060 | |
1061 | uint64_t lineOffset = subContents.front().count(C: '\n'); |
1062 | for (StringRef docContents : llvm::drop_begin(RangeOrContainer&: subContents)) { |
1063 | unsigned currentNumDiags = diagnostics.size(); |
1064 | auto chunk = std::make_unique<MLIRTextFileChunk>(args&: context, args&: lineOffset, args: uri, |
1065 | args&: docContents, args&: diagnostics); |
1066 | lineOffset += docContents.count(C: '\n'); |
1067 | |
1068 | // Adjust locations used in diagnostics to account for the offset from the |
1069 | // beginning of the file. |
1070 | for (lsp::Diagnostic &diag : |
1071 | llvm::drop_begin(RangeOrContainer&: diagnostics, N: currentNumDiags)) { |
1072 | chunk->adjustLocForChunkOffset(range&: diag.range); |
1073 | |
1074 | if (!diag.relatedInformation) |
1075 | continue; |
1076 | for (auto &it : *diag.relatedInformation) |
1077 | if (it.location.uri == uri) |
1078 | chunk->adjustLocForChunkOffset(range&: it.location.range); |
1079 | } |
1080 | chunks.emplace_back(args: std::move(chunk)); |
1081 | } |
1082 | totalNumLines = lineOffset; |
1083 | } |
1084 | |
1085 | void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri, |
1086 | lsp::Position defPos, |
1087 | std::vector<lsp::Location> &locations) { |
1088 | MLIRTextFileChunk &chunk = getChunkFor(pos&: defPos); |
1089 | chunk.document.getLocationsOf(uri, defPos, locations); |
1090 | |
1091 | // Adjust any locations within this file for the offset of this chunk. |
1092 | if (chunk.lineOffset == 0) |
1093 | return; |
1094 | for (lsp::Location &loc : locations) |
1095 | if (loc.uri == uri) |
1096 | chunk.adjustLocForChunkOffset(range&: loc.range); |
1097 | } |
1098 | |
1099 | void MLIRTextFile::findReferencesOf(const lsp::URIForFile &uri, |
1100 | lsp::Position pos, |
1101 | std::vector<lsp::Location> &references) { |
1102 | MLIRTextFileChunk &chunk = getChunkFor(pos); |
1103 | chunk.document.findReferencesOf(uri, pos, references); |
1104 | |
1105 | // Adjust any locations within this file for the offset of this chunk. |
1106 | if (chunk.lineOffset == 0) |
1107 | return; |
1108 | for (lsp::Location &loc : references) |
1109 | if (loc.uri == uri) |
1110 | chunk.adjustLocForChunkOffset(range&: loc.range); |
1111 | } |
1112 | |
1113 | std::optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri, |
1114 | lsp::Position hoverPos) { |
1115 | MLIRTextFileChunk &chunk = getChunkFor(pos&: hoverPos); |
1116 | std::optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos); |
1117 | |
1118 | // Adjust any locations within this file for the offset of this chunk. |
1119 | if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range) |
1120 | chunk.adjustLocForChunkOffset(range&: *hoverInfo->range); |
1121 | return hoverInfo; |
1122 | } |
1123 | |
1124 | void MLIRTextFile::findDocumentSymbols( |
1125 | std::vector<lsp::DocumentSymbol> &symbols) { |
1126 | if (chunks.size() == 1) |
1127 | return chunks.front()->document.findDocumentSymbols(symbols); |
1128 | |
1129 | // If there are multiple chunks in this file, we create top-level symbols for |
1130 | // each chunk. |
1131 | for (unsigned i = 0, e = chunks.size(); i < e; ++i) { |
1132 | MLIRTextFileChunk &chunk = *chunks[i]; |
1133 | lsp::Position startPos(chunk.lineOffset); |
1134 | lsp::Position endPos((i == e - 1) ? totalNumLines - 1 |
1135 | : chunks[i + 1]->lineOffset); |
1136 | lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">" , |
1137 | lsp::SymbolKind::Namespace, |
1138 | /*range=*/lsp::Range(startPos, endPos), |
1139 | /*selectionRange=*/lsp::Range(startPos)); |
1140 | chunk.document.findDocumentSymbols(symbols&: symbol.children); |
1141 | |
1142 | // Fixup the locations of document symbols within this chunk. |
1143 | if (i != 0) { |
1144 | SmallVector<lsp::DocumentSymbol *> symbolsToFix; |
1145 | for (lsp::DocumentSymbol &childSymbol : symbol.children) |
1146 | symbolsToFix.push_back(Elt: &childSymbol); |
1147 | |
1148 | while (!symbolsToFix.empty()) { |
1149 | lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val(); |
1150 | chunk.adjustLocForChunkOffset(range&: symbol->range); |
1151 | chunk.adjustLocForChunkOffset(range&: symbol->selectionRange); |
1152 | |
1153 | for (lsp::DocumentSymbol &childSymbol : symbol->children) |
1154 | symbolsToFix.push_back(Elt: &childSymbol); |
1155 | } |
1156 | } |
1157 | |
1158 | // Push the symbol for this chunk. |
1159 | symbols.emplace_back(args: std::move(symbol)); |
1160 | } |
1161 | } |
1162 | |
1163 | lsp::CompletionList MLIRTextFile::getCodeCompletion(const lsp::URIForFile &uri, |
1164 | lsp::Position completePos) { |
1165 | MLIRTextFileChunk &chunk = getChunkFor(pos&: completePos); |
1166 | lsp::CompletionList completionList = chunk.document.getCodeCompletion( |
1167 | uri, completePos, registry: context.getDialectRegistry()); |
1168 | |
1169 | // Adjust any completion locations. |
1170 | for (lsp::CompletionItem &item : completionList.items) { |
1171 | if (item.textEdit) |
1172 | chunk.adjustLocForChunkOffset(range&: item.textEdit->range); |
1173 | for (lsp::TextEdit &edit : item.additionalTextEdits) |
1174 | chunk.adjustLocForChunkOffset(range&: edit.range); |
1175 | } |
1176 | return completionList; |
1177 | } |
1178 | |
1179 | void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri, |
1180 | const lsp::Range &pos, |
1181 | const lsp::CodeActionContext &context, |
1182 | std::vector<lsp::CodeAction> &actions) { |
1183 | // Create actions for any diagnostics in this file. |
1184 | for (auto &diag : context.diagnostics) { |
1185 | if (diag.source != "mlir" ) |
1186 | continue; |
1187 | lsp::Position diagPos = diag.range.start; |
1188 | MLIRTextFileChunk &chunk = getChunkFor(pos&: diagPos); |
1189 | |
1190 | // Add a new code action that inserts a "expected" diagnostic check. |
1191 | lsp::CodeAction action; |
1192 | action.title = "Add expected-* diagnostic checks" ; |
1193 | action.kind = lsp::CodeAction::kQuickFix.str(); |
1194 | |
1195 | StringRef severity; |
1196 | switch (diag.severity) { |
1197 | case lsp::DiagnosticSeverity::Error: |
1198 | severity = "error" ; |
1199 | break; |
1200 | case lsp::DiagnosticSeverity::Warning: |
1201 | severity = "warning" ; |
1202 | break; |
1203 | default: |
1204 | continue; |
1205 | } |
1206 | |
1207 | // Get edits for the diagnostic. |
1208 | std::vector<lsp::TextEdit> edits; |
1209 | chunk.document.getCodeActionForDiagnostic(uri, pos&: diagPos, severity, |
1210 | message: diag.message, edits); |
1211 | |
1212 | // Walk the related diagnostics, this is how we encode notes. |
1213 | if (diag.relatedInformation) { |
1214 | for (auto ¬eDiag : *diag.relatedInformation) { |
1215 | if (noteDiag.location.uri != uri) |
1216 | continue; |
1217 | diagPos = noteDiag.location.range.start; |
1218 | diagPos.line -= chunk.lineOffset; |
1219 | chunk.document.getCodeActionForDiagnostic(uri, pos&: diagPos, severity: "note" , |
1220 | message: noteDiag.message, edits); |
1221 | } |
1222 | } |
1223 | // Fixup the locations for any edits. |
1224 | for (lsp::TextEdit &edit : edits) |
1225 | chunk.adjustLocForChunkOffset(range&: edit.range); |
1226 | |
1227 | action.edit.emplace(); |
1228 | action.edit->changes[uri.uri().str()] = std::move(edits); |
1229 | action.diagnostics = {diag}; |
1230 | |
1231 | actions.emplace_back(args: std::move(action)); |
1232 | } |
1233 | } |
1234 | |
1235 | llvm::Expected<lsp::MLIRConvertBytecodeResult> |
1236 | MLIRTextFile::convertToBytecode() { |
1237 | // Bail out if there is more than one chunk, bytecode wants a single module. |
1238 | if (chunks.size() != 1) { |
1239 | return llvm::make_error<lsp::LSPError>( |
1240 | Args: "unexpected split file, please remove all `// -----`" , |
1241 | Args: lsp::ErrorCode::RequestFailed); |
1242 | } |
1243 | return chunks.front()->document.convertToBytecode(); |
1244 | } |
1245 | |
1246 | MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) { |
1247 | if (chunks.size() == 1) |
1248 | return *chunks.front(); |
1249 | |
1250 | // Search for the first chunk with a greater line offset, the previous chunk |
1251 | // is the one that contains `pos`. |
1252 | auto it = llvm::upper_bound( |
1253 | Range&: chunks, Value&: pos, C: [](const lsp::Position &pos, const auto &chunk) { |
1254 | return static_cast<uint64_t>(pos.line) < chunk->lineOffset; |
1255 | }); |
1256 | MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it); |
1257 | pos.line -= chunk.lineOffset; |
1258 | return chunk; |
1259 | } |
1260 | |
1261 | //===----------------------------------------------------------------------===// |
1262 | // MLIRServer::Impl |
1263 | //===----------------------------------------------------------------------===// |
1264 | |
1265 | struct lsp::MLIRServer::Impl { |
1266 | Impl(DialectRegistry ®istry) : registry(registry) {} |
1267 | |
1268 | /// The registry containing dialects that can be recognized in parsed .mlir |
1269 | /// files. |
1270 | DialectRegistry ®istry; |
1271 | |
1272 | /// The files held by the server, mapped by their URI file name. |
1273 | llvm::StringMap<std::unique_ptr<MLIRTextFile>> files; |
1274 | }; |
1275 | |
1276 | //===----------------------------------------------------------------------===// |
1277 | // MLIRServer |
1278 | //===----------------------------------------------------------------------===// |
1279 | |
1280 | lsp::MLIRServer::MLIRServer(DialectRegistry ®istry) |
1281 | : impl(std::make_unique<Impl>(args&: registry)) {} |
1282 | lsp::MLIRServer::~MLIRServer() = default; |
1283 | |
1284 | void lsp::MLIRServer::addOrUpdateDocument( |
1285 | const URIForFile &uri, StringRef contents, int64_t version, |
1286 | std::vector<Diagnostic> &diagnostics) { |
1287 | impl->files[uri.file()] = std::make_unique<MLIRTextFile>( |
1288 | args: uri, args&: contents, args&: version, args&: impl->registry, args&: diagnostics); |
1289 | } |
1290 | |
1291 | std::optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) { |
1292 | auto it = impl->files.find(Key: uri.file()); |
1293 | if (it == impl->files.end()) |
1294 | return std::nullopt; |
1295 | |
1296 | int64_t version = it->second->getVersion(); |
1297 | impl->files.erase(I: it); |
1298 | return version; |
1299 | } |
1300 | |
1301 | void lsp::MLIRServer::getLocationsOf(const URIForFile &uri, |
1302 | const Position &defPos, |
1303 | std::vector<Location> &locations) { |
1304 | auto fileIt = impl->files.find(Key: uri.file()); |
1305 | if (fileIt != impl->files.end()) |
1306 | fileIt->second->getLocationsOf(uri, defPos, locations); |
1307 | } |
1308 | |
1309 | void lsp::MLIRServer::findReferencesOf(const URIForFile &uri, |
1310 | const Position &pos, |
1311 | std::vector<Location> &references) { |
1312 | auto fileIt = impl->files.find(Key: uri.file()); |
1313 | if (fileIt != impl->files.end()) |
1314 | fileIt->second->findReferencesOf(uri, pos, references); |
1315 | } |
1316 | |
1317 | std::optional<lsp::Hover> lsp::MLIRServer::findHover(const URIForFile &uri, |
1318 | const Position &hoverPos) { |
1319 | auto fileIt = impl->files.find(Key: uri.file()); |
1320 | if (fileIt != impl->files.end()) |
1321 | return fileIt->second->findHover(uri, hoverPos); |
1322 | return std::nullopt; |
1323 | } |
1324 | |
1325 | void lsp::MLIRServer::findDocumentSymbols( |
1326 | const URIForFile &uri, std::vector<DocumentSymbol> &symbols) { |
1327 | auto fileIt = impl->files.find(Key: uri.file()); |
1328 | if (fileIt != impl->files.end()) |
1329 | fileIt->second->findDocumentSymbols(symbols); |
1330 | } |
1331 | |
1332 | lsp::CompletionList |
1333 | lsp::MLIRServer::getCodeCompletion(const URIForFile &uri, |
1334 | const Position &completePos) { |
1335 | auto fileIt = impl->files.find(Key: uri.file()); |
1336 | if (fileIt != impl->files.end()) |
1337 | return fileIt->second->getCodeCompletion(uri, completePos); |
1338 | return CompletionList(); |
1339 | } |
1340 | |
1341 | void lsp::MLIRServer::getCodeActions(const URIForFile &uri, const Range &pos, |
1342 | const CodeActionContext &context, |
1343 | std::vector<CodeAction> &actions) { |
1344 | auto fileIt = impl->files.find(Key: uri.file()); |
1345 | if (fileIt != impl->files.end()) |
1346 | fileIt->second->getCodeActions(uri, pos, context, actions); |
1347 | } |
1348 | |
1349 | llvm::Expected<lsp::MLIRConvertBytecodeResult> |
1350 | lsp::MLIRServer::convertFromBytecode(const URIForFile &uri) { |
1351 | MLIRContext tempContext(impl->registry); |
1352 | tempContext.allowUnregisteredDialects(); |
1353 | |
1354 | // Collect any errors during parsing. |
1355 | std::string errorMsg; |
1356 | ScopedDiagnosticHandler diagHandler( |
1357 | &tempContext, |
1358 | [&](mlir::Diagnostic &diag) { errorMsg += diag.str() + "\n" ; }); |
1359 | |
1360 | // Handling for external resources, which we want to propagate up to the user. |
1361 | FallbackAsmResourceMap fallbackResourceMap; |
1362 | |
1363 | // Setup the parser config. |
1364 | ParserConfig parserConfig(&tempContext, /*verifyAfterParse=*/true, |
1365 | &fallbackResourceMap); |
1366 | |
1367 | // Try to parse the given source file. |
1368 | Block parsedBlock; |
1369 | if (failed(result: parseSourceFile(filename: uri.file(), block: &parsedBlock, config: parserConfig))) { |
1370 | return llvm::make_error<lsp::LSPError>( |
1371 | Args: "failed to parse bytecode source file: " + errorMsg, |
1372 | Args: lsp::ErrorCode::RequestFailed); |
1373 | } |
1374 | |
1375 | // TODO: We currently expect a single top-level operation, but this could |
1376 | // conceptually be relaxed. |
1377 | if (!llvm::hasSingleElement(C&: parsedBlock)) { |
1378 | return llvm::make_error<lsp::LSPError>( |
1379 | Args: "expected bytecode to contain a single top-level operation" , |
1380 | Args: lsp::ErrorCode::RequestFailed); |
1381 | } |
1382 | |
1383 | // Print the module to a buffer. |
1384 | lsp::MLIRConvertBytecodeResult result; |
1385 | { |
1386 | // Extract the top-level op so that aliases get printed. |
1387 | // FIXME: We should be able to enable aliases without having to do this! |
1388 | OwningOpRef<Operation *> topOp = &parsedBlock.front(); |
1389 | topOp->remove(); |
1390 | |
1391 | AsmState state(*topOp, OpPrintingFlags().enableDebugInfo().assumeVerified(), |
1392 | /*locationMap=*/nullptr, &fallbackResourceMap); |
1393 | |
1394 | llvm::raw_string_ostream os(result.output); |
1395 | topOp->print(os, state); |
1396 | } |
1397 | return std::move(result); |
1398 | } |
1399 | |
1400 | llvm::Expected<lsp::MLIRConvertBytecodeResult> |
1401 | lsp::MLIRServer::convertToBytecode(const URIForFile &uri) { |
1402 | auto fileIt = impl->files.find(Key: uri.file()); |
1403 | if (fileIt == impl->files.end()) { |
1404 | return llvm::make_error<lsp::LSPError>( |
1405 | Args: "language server does not contain an entry for this source file" , |
1406 | Args: lsp::ErrorCode::RequestFailed); |
1407 | } |
1408 | return fileIt->second->convertToBytecode(); |
1409 | } |
1410 | |