1//===- MLIRServer.cpp - MLIR Generic Language Server ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "MLIRServer.h"
10#include "Protocol.h"
11#include "mlir/AsmParser/AsmParser.h"
12#include "mlir/AsmParser/AsmParserState.h"
13#include "mlir/AsmParser/CodeComplete.h"
14#include "mlir/Bytecode/BytecodeWriter.h"
15#include "mlir/IR/Operation.h"
16#include "mlir/Interfaces/FunctionInterfaces.h"
17#include "mlir/Parser/Parser.h"
18#include "mlir/Support/ToolUtilities.h"
19#include "mlir/Tools/lsp-server-support/Logging.h"
20#include "mlir/Tools/lsp-server-support/SourceMgrUtils.h"
21#include "llvm/ADT/StringExtras.h"
22#include "llvm/Support/Base64.h"
23#include "llvm/Support/SourceMgr.h"
24#include <optional>
25
26using namespace mlir;
27
28/// Returns the range of a lexical token given a SMLoc corresponding to the
29/// start of an token location. The range is computed heuristically, and
30/// supports identifier-like tokens, strings, etc.
31static SMRange convertTokenLocToRange(SMLoc loc) {
32 return lsp::convertTokenLocToRange(loc, identifierChars: "$-.");
33}
34
35/// Returns a language server location from the given MLIR file location.
36/// `uriScheme` is the scheme to use when building new uris.
37static std::optional<lsp::Location> getLocationFromLoc(StringRef uriScheme,
38 FileLineColLoc loc) {
39 llvm::Expected<lsp::URIForFile> sourceURI =
40 lsp::URIForFile::fromFile(absoluteFilepath: loc.getFilename(), scheme: uriScheme);
41 if (!sourceURI) {
42 lsp::Logger::error("Failed to create URI for file `{0}`: {1}",
43 loc.getFilename(),
44 llvm::toString(E: sourceURI.takeError()));
45 return std::nullopt;
46 }
47
48 lsp::Position position;
49 position.line = loc.getLine() - 1;
50 position.character = loc.getColumn() ? loc.getColumn() - 1 : 0;
51 return lsp::Location{*sourceURI, lsp::Range(position)};
52}
53
54/// Returns a language server location from the given MLIR location, or
55/// std::nullopt if one couldn't be created. `uriScheme` is the scheme to use
56/// when building new uris. `uri` is an optional additional filter that, when
57/// present, is used to filter sub locations that do not share the same uri.
58static std::optional<lsp::Location>
59getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc,
60 StringRef uriScheme, const lsp::URIForFile *uri = nullptr) {
61 std::optional<lsp::Location> location;
62 loc->walk(walkFn: [&](Location nestedLoc) {
63 FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(nestedLoc);
64 if (!fileLoc)
65 return WalkResult::advance();
66
67 std::optional<lsp::Location> sourceLoc =
68 getLocationFromLoc(uriScheme, fileLoc);
69 if (sourceLoc && (!uri || sourceLoc->uri == *uri)) {
70 location = *sourceLoc;
71 SMLoc loc = sourceMgr.FindLocForLineAndColumn(
72 BufferID: sourceMgr.getMainFileID(), LineNo: fileLoc.getLine(), ColNo: fileLoc.getColumn());
73
74 // Use range of potential identifier starting at location, else length 1
75 // range.
76 location->range.end.character += 1;
77 if (std::optional<SMRange> range = convertTokenLocToRange(loc)) {
78 auto lineCol = sourceMgr.getLineAndColumn(Loc: range->End);
79 location->range.end.character =
80 std::max(fileLoc.getColumn() + 1, lineCol.second - 1);
81 }
82 return WalkResult::interrupt();
83 }
84 return WalkResult::advance();
85 });
86 return location;
87}
88
89/// Collect all of the locations from the given MLIR location that are not
90/// contained within the given URI.
91static void collectLocationsFromLoc(Location loc,
92 std::vector<lsp::Location> &locations,
93 const lsp::URIForFile &uri) {
94 SetVector<Location> visitedLocs;
95 loc->walk(walkFn: [&](Location nestedLoc) {
96 FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(nestedLoc);
97 if (!fileLoc || !visitedLocs.insert(X: nestedLoc))
98 return WalkResult::advance();
99
100 std::optional<lsp::Location> sourceLoc =
101 getLocationFromLoc(uri.scheme(), fileLoc);
102 if (sourceLoc && sourceLoc->uri != uri)
103 locations.push_back(x: *sourceLoc);
104 return WalkResult::advance();
105 });
106}
107
108/// Returns true if the given range contains the given source location. Note
109/// that this has slightly different behavior than SMRange because it is
110/// inclusive of the end location.
111static bool contains(SMRange range, SMLoc loc) {
112 return range.Start.getPointer() <= loc.getPointer() &&
113 loc.getPointer() <= range.End.getPointer();
114}
115
116/// Returns true if the given location is contained by the definition or one of
117/// the uses of the given SMDefinition. If provided, `overlappedRange` is set to
118/// the range within `def` that the provided `loc` overlapped with.
119static bool isDefOrUse(const AsmParserState::SMDefinition &def, SMLoc loc,
120 SMRange *overlappedRange = nullptr) {
121 // Check the main definition.
122 if (contains(range: def.loc, loc)) {
123 if (overlappedRange)
124 *overlappedRange = def.loc;
125 return true;
126 }
127
128 // Check the uses.
129 const auto *useIt = llvm::find_if(
130 Range: def.uses, P: [&](const SMRange &range) { return contains(range, loc); });
131 if (useIt != def.uses.end()) {
132 if (overlappedRange)
133 *overlappedRange = *useIt;
134 return true;
135 }
136 return false;
137}
138
139/// Given a location pointing to a result, return the result number it refers
140/// to or std::nullopt if it refers to all of the results.
141static std::optional<unsigned> getResultNumberFromLoc(SMLoc loc) {
142 // Skip all of the identifier characters.
143 auto isIdentifierChar = [](char c) {
144 return isalnum(c) || c == '%' || c == '$' || c == '.' || c == '_' ||
145 c == '-';
146 };
147 const char *curPtr = loc.getPointer();
148 while (isIdentifierChar(*curPtr))
149 ++curPtr;
150
151 // Check to see if this location indexes into the result group, via `#`. If it
152 // doesn't, we can't extract a sub result number.
153 if (*curPtr != '#')
154 return std::nullopt;
155
156 // Compute the sub result number from the remaining portion of the string.
157 const char *numberStart = ++curPtr;
158 while (llvm::isDigit(C: *curPtr))
159 ++curPtr;
160 StringRef numberStr(numberStart, curPtr - numberStart);
161 unsigned resultNumber = 0;
162 return numberStr.consumeInteger(Radix: 10, Result&: resultNumber) ? std::optional<unsigned>()
163 : resultNumber;
164}
165
166/// Given a source location range, return the text covered by the given range.
167/// If the range is invalid, returns std::nullopt.
168static std::optional<StringRef> getTextFromRange(SMRange range) {
169 if (!range.isValid())
170 return std::nullopt;
171 const char *startPtr = range.Start.getPointer();
172 return StringRef(startPtr, range.End.getPointer() - startPtr);
173}
174
175/// Given a block, return its position in its parent region.
176static unsigned getBlockNumber(Block *block) {
177 return std::distance(first: block->getParent()->begin(), last: block->getIterator());
178}
179
180/// Given a block and source location, print the source name of the block to the
181/// given output stream.
182static void printDefBlockName(raw_ostream &os, Block *block, SMRange loc = {}) {
183 // Try to extract a name from the source location.
184 std::optional<StringRef> text = getTextFromRange(range: loc);
185 if (text && text->starts_with(Prefix: "^")) {
186 os << *text;
187 return;
188 }
189
190 // Otherwise, we don't have a name so print the block number.
191 os << "<Block #" << getBlockNumber(block) << ">";
192}
193static void printDefBlockName(raw_ostream &os,
194 const AsmParserState::BlockDefinition &def) {
195 printDefBlockName(os, block: def.block, loc: def.definition.loc);
196}
197
198/// Convert the given MLIR diagnostic to the LSP form.
199static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr,
200 Diagnostic &diag,
201 const lsp::URIForFile &uri) {
202 lsp::Diagnostic lspDiag;
203 lspDiag.source = "mlir";
204
205 // Note: Right now all of the diagnostics are treated as parser issues, but
206 // some are parser and some are verifier.
207 lspDiag.category = "Parse Error";
208
209 // Try to grab a file location for this diagnostic.
210 // TODO: For simplicity, we just grab the first one. It may be likely that we
211 // will need a more interesting heuristic here.'
212 StringRef uriScheme = uri.scheme();
213 std::optional<lsp::Location> lspLocation =
214 getLocationFromLoc(sourceMgr, loc: diag.getLocation(), uriScheme, uri: &uri);
215 if (lspLocation)
216 lspDiag.range = lspLocation->range;
217
218 // Convert the severity for the diagnostic.
219 switch (diag.getSeverity()) {
220 case DiagnosticSeverity::Note:
221 llvm_unreachable("expected notes to be handled separately");
222 case DiagnosticSeverity::Warning:
223 lspDiag.severity = lsp::DiagnosticSeverity::Warning;
224 break;
225 case DiagnosticSeverity::Error:
226 lspDiag.severity = lsp::DiagnosticSeverity::Error;
227 break;
228 case DiagnosticSeverity::Remark:
229 lspDiag.severity = lsp::DiagnosticSeverity::Information;
230 break;
231 }
232 lspDiag.message = diag.str();
233
234 // Attach any notes to the main diagnostic as related information.
235 std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
236 for (Diagnostic &note : diag.getNotes()) {
237 lsp::Location noteLoc;
238 if (std::optional<lsp::Location> loc =
239 getLocationFromLoc(sourceMgr, loc: note.getLocation(), uriScheme))
240 noteLoc = *loc;
241 else
242 noteLoc.uri = uri;
243 relatedDiags.emplace_back(args&: noteLoc, args: note.str());
244 }
245 if (!relatedDiags.empty())
246 lspDiag.relatedInformation = std::move(relatedDiags);
247
248 return lspDiag;
249}
250
251//===----------------------------------------------------------------------===//
252// MLIRDocument
253//===----------------------------------------------------------------------===//
254
255namespace {
256/// This class represents all of the information pertaining to a specific MLIR
257/// document.
258struct MLIRDocument {
259 MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
260 StringRef contents, std::vector<lsp::Diagnostic> &diagnostics);
261 MLIRDocument(const MLIRDocument &) = delete;
262 MLIRDocument &operator=(const MLIRDocument &) = delete;
263
264 //===--------------------------------------------------------------------===//
265 // Definitions and References
266 //===--------------------------------------------------------------------===//
267
268 void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
269 std::vector<lsp::Location> &locations);
270 void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
271 std::vector<lsp::Location> &references);
272
273 //===--------------------------------------------------------------------===//
274 // Hover
275 //===--------------------------------------------------------------------===//
276
277 std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
278 const lsp::Position &hoverPos);
279 std::optional<lsp::Hover>
280 buildHoverForOperation(SMRange hoverRange,
281 const AsmParserState::OperationDefinition &op);
282 lsp::Hover buildHoverForOperationResult(SMRange hoverRange, Operation *op,
283 unsigned resultStart,
284 unsigned resultEnd, SMLoc posLoc);
285 lsp::Hover buildHoverForBlock(SMRange hoverRange,
286 const AsmParserState::BlockDefinition &block);
287 lsp::Hover
288 buildHoverForBlockArgument(SMRange hoverRange, BlockArgument arg,
289 const AsmParserState::BlockDefinition &block);
290
291 lsp::Hover buildHoverForAttributeAlias(
292 SMRange hoverRange, const AsmParserState::AttributeAliasDefinition &attr);
293 lsp::Hover
294 buildHoverForTypeAlias(SMRange hoverRange,
295 const AsmParserState::TypeAliasDefinition &type);
296
297 //===--------------------------------------------------------------------===//
298 // Document Symbols
299 //===--------------------------------------------------------------------===//
300
301 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
302 void findDocumentSymbols(Operation *op,
303 std::vector<lsp::DocumentSymbol> &symbols);
304
305 //===--------------------------------------------------------------------===//
306 // Code Completion
307 //===--------------------------------------------------------------------===//
308
309 lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
310 const lsp::Position &completePos,
311 const DialectRegistry &registry);
312
313 //===--------------------------------------------------------------------===//
314 // Code Action
315 //===--------------------------------------------------------------------===//
316
317 void getCodeActionForDiagnostic(const lsp::URIForFile &uri,
318 lsp::Position &pos, StringRef severity,
319 StringRef message,
320 std::vector<lsp::TextEdit> &edits);
321
322 //===--------------------------------------------------------------------===//
323 // Bytecode
324 //===--------------------------------------------------------------------===//
325
326 llvm::Expected<lsp::MLIRConvertBytecodeResult> convertToBytecode();
327
328 //===--------------------------------------------------------------------===//
329 // Fields
330 //===--------------------------------------------------------------------===//
331
332 /// The high level parser state used to find definitions and references within
333 /// the source file.
334 AsmParserState asmState;
335
336 /// The container for the IR parsed from the input file.
337 Block parsedIR;
338
339 /// A collection of external resources, which we want to propagate up to the
340 /// user.
341 FallbackAsmResourceMap fallbackResourceMap;
342
343 /// The source manager containing the contents of the input file.
344 llvm::SourceMgr sourceMgr;
345};
346} // namespace
347
348MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
349 StringRef contents,
350 std::vector<lsp::Diagnostic> &diagnostics) {
351 ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) {
352 diagnostics.push_back(x: getLspDiagnoticFromDiag(sourceMgr, diag, uri));
353 });
354
355 // Try to parsed the given IR string.
356 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(InputData: contents, BufferName: uri.file());
357 if (!memBuffer) {
358 lsp::Logger::error(fmt: "Failed to create memory buffer for file", vals: uri.file());
359 return;
360 }
361
362 ParserConfig config(&context, /*verifyAfterParse=*/true,
363 &fallbackResourceMap);
364 sourceMgr.AddNewSourceBuffer(F: std::move(memBuffer), IncludeLoc: SMLoc());
365 if (failed(result: parseAsmSourceFile(sourceMgr, block: &parsedIR, config, asmState: &asmState))) {
366 // If parsing failed, clear out any of the current state.
367 parsedIR.clear();
368 asmState = AsmParserState();
369 fallbackResourceMap = FallbackAsmResourceMap();
370 return;
371 }
372}
373
374//===----------------------------------------------------------------------===//
375// MLIRDocument: Definitions and References
376//===----------------------------------------------------------------------===//
377
378void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri,
379 const lsp::Position &defPos,
380 std::vector<lsp::Location> &locations) {
381 SMLoc posLoc = defPos.getAsSMLoc(mgr&: sourceMgr);
382
383 // Functor used to check if an SM definition contains the position.
384 auto containsPosition = [&](const AsmParserState::SMDefinition &def) {
385 if (!isDefOrUse(def, loc: posLoc))
386 return false;
387 locations.emplace_back(args: uri, args&: sourceMgr, args: def.loc);
388 return true;
389 };
390
391 // Check all definitions related to operations.
392 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
393 if (contains(range: op.loc, loc: posLoc))
394 return collectLocationsFromLoc(loc: op.op->getLoc(), locations, uri);
395 for (const auto &result : op.resultGroups)
396 if (containsPosition(result.definition))
397 return collectLocationsFromLoc(loc: op.op->getLoc(), locations, uri);
398 for (const auto &symUse : op.symbolUses) {
399 if (contains(range: symUse, loc: posLoc)) {
400 locations.emplace_back(args: uri, args&: sourceMgr, args: op.loc);
401 return collectLocationsFromLoc(loc: op.op->getLoc(), locations, uri);
402 }
403 }
404 }
405
406 // Check all definitions related to blocks.
407 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
408 if (containsPosition(block.definition))
409 return;
410 for (const AsmParserState::SMDefinition &arg : block.arguments)
411 if (containsPosition(arg))
412 return;
413 }
414
415 // Check all alias definitions.
416 for (const AsmParserState::AttributeAliasDefinition &attr :
417 asmState.getAttributeAliasDefs()) {
418 if (containsPosition(attr.definition))
419 return;
420 }
421 for (const AsmParserState::TypeAliasDefinition &type :
422 asmState.getTypeAliasDefs()) {
423 if (containsPosition(type.definition))
424 return;
425 }
426}
427
428void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri,
429 const lsp::Position &pos,
430 std::vector<lsp::Location> &references) {
431 // Functor used to append all of the definitions/uses of the given SM
432 // definition to the reference list.
433 auto appendSMDef = [&](const AsmParserState::SMDefinition &def) {
434 references.emplace_back(args: uri, args&: sourceMgr, args: def.loc);
435 for (const SMRange &use : def.uses)
436 references.emplace_back(args: uri, args&: sourceMgr, args: use);
437 };
438
439 SMLoc posLoc = pos.getAsSMLoc(mgr&: sourceMgr);
440
441 // Check all definitions related to operations.
442 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
443 if (contains(range: op.loc, loc: posLoc)) {
444 for (const auto &result : op.resultGroups)
445 appendSMDef(result.definition);
446 for (const auto &symUse : op.symbolUses)
447 if (contains(range: symUse, loc: posLoc))
448 references.emplace_back(args: uri, args&: sourceMgr, args: symUse);
449 return;
450 }
451 for (const auto &result : op.resultGroups)
452 if (isDefOrUse(def: result.definition, loc: posLoc))
453 return appendSMDef(result.definition);
454 for (const auto &symUse : op.symbolUses) {
455 if (!contains(range: symUse, loc: posLoc))
456 continue;
457 for (const auto &symUse : op.symbolUses)
458 references.emplace_back(args: uri, args&: sourceMgr, args: symUse);
459 return;
460 }
461 }
462
463 // Check all definitions related to blocks.
464 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
465 if (isDefOrUse(def: block.definition, loc: posLoc))
466 return appendSMDef(block.definition);
467
468 for (const AsmParserState::SMDefinition &arg : block.arguments)
469 if (isDefOrUse(def: arg, loc: posLoc))
470 return appendSMDef(arg);
471 }
472
473 // Check all alias definitions.
474 for (const AsmParserState::AttributeAliasDefinition &attr :
475 asmState.getAttributeAliasDefs()) {
476 if (isDefOrUse(def: attr.definition, loc: posLoc))
477 return appendSMDef(attr.definition);
478 }
479 for (const AsmParserState::TypeAliasDefinition &type :
480 asmState.getTypeAliasDefs()) {
481 if (isDefOrUse(def: type.definition, loc: posLoc))
482 return appendSMDef(type.definition);
483 }
484}
485
486//===----------------------------------------------------------------------===//
487// MLIRDocument: Hover
488//===----------------------------------------------------------------------===//
489
490std::optional<lsp::Hover>
491MLIRDocument::findHover(const lsp::URIForFile &uri,
492 const lsp::Position &hoverPos) {
493 SMLoc posLoc = hoverPos.getAsSMLoc(mgr&: sourceMgr);
494 SMRange hoverRange;
495
496 // Check for Hovers on operations and results.
497 for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
498 // Check if the position points at this operation.
499 if (contains(range: op.loc, loc: posLoc))
500 return buildHoverForOperation(hoverRange: op.loc, op);
501
502 // Check if the position points at the symbol name.
503 for (auto &use : op.symbolUses)
504 if (contains(range: use, loc: posLoc))
505 return buildHoverForOperation(hoverRange: use, op);
506
507 // Check if the position points at a result group.
508 for (unsigned i = 0, e = op.resultGroups.size(); i < e; ++i) {
509 const auto &result = op.resultGroups[i];
510 if (!isDefOrUse(def: result.definition, loc: posLoc, overlappedRange: &hoverRange))
511 continue;
512
513 // Get the range of results covered by the over position.
514 unsigned resultStart = result.startIndex;
515 unsigned resultEnd = (i == e - 1) ? op.op->getNumResults()
516 : op.resultGroups[i + 1].startIndex;
517 return buildHoverForOperationResult(hoverRange, op: op.op, resultStart,
518 resultEnd, posLoc);
519 }
520 }
521
522 // Check to see if the hover is over a block argument.
523 for (const AsmParserState::BlockDefinition &block : asmState.getBlockDefs()) {
524 if (isDefOrUse(def: block.definition, loc: posLoc, overlappedRange: &hoverRange))
525 return buildHoverForBlock(hoverRange, block);
526
527 for (const auto &arg : llvm::enumerate(First: block.arguments)) {
528 if (!isDefOrUse(def: arg.value(), loc: posLoc, overlappedRange: &hoverRange))
529 continue;
530
531 return buildHoverForBlockArgument(
532 hoverRange, arg: block.block->getArgument(i: arg.index()), block);
533 }
534 }
535
536 // Check to see if the hover is over an alias.
537 for (const AsmParserState::AttributeAliasDefinition &attr :
538 asmState.getAttributeAliasDefs()) {
539 if (isDefOrUse(def: attr.definition, loc: posLoc, overlappedRange: &hoverRange))
540 return buildHoverForAttributeAlias(hoverRange, attr);
541 }
542 for (const AsmParserState::TypeAliasDefinition &type :
543 asmState.getTypeAliasDefs()) {
544 if (isDefOrUse(def: type.definition, loc: posLoc, overlappedRange: &hoverRange))
545 return buildHoverForTypeAlias(hoverRange, type);
546 }
547
548 return std::nullopt;
549}
550
551std::optional<lsp::Hover> MLIRDocument::buildHoverForOperation(
552 SMRange hoverRange, const AsmParserState::OperationDefinition &op) {
553 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
554 llvm::raw_string_ostream os(hover.contents.value);
555
556 // Add the operation name to the hover.
557 os << "\"" << op.op->getName() << "\"";
558 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op.op))
559 os << " : " << symbol.getVisibility() << " @" << symbol.getName() << "";
560 os << "\n\n";
561
562 os << "Generic Form:\n\n```mlir\n";
563
564 op.op->print(os, flags: OpPrintingFlags()
565 .printGenericOpForm()
566 .elideLargeElementsAttrs()
567 .skipRegions());
568 os << "\n```\n";
569
570 return hover;
571}
572
573lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange,
574 Operation *op,
575 unsigned resultStart,
576 unsigned resultEnd,
577 SMLoc posLoc) {
578 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
579 llvm::raw_string_ostream os(hover.contents.value);
580
581 // Add the parent operation name to the hover.
582 os << "Operation: \"" << op->getName() << "\"\n\n";
583
584 // Check to see if the location points to a specific result within the
585 // group.
586 if (std::optional<unsigned> resultNumber = getResultNumberFromLoc(loc: posLoc)) {
587 if ((resultStart + *resultNumber) < resultEnd) {
588 resultStart += *resultNumber;
589 resultEnd = resultStart + 1;
590 }
591 }
592
593 // Add the range of results and their types to the hover info.
594 if ((resultStart + 1) == resultEnd) {
595 os << "Result #" << resultStart << "\n\n"
596 << "Type: `" << op->getResult(idx: resultStart).getType() << "`\n\n";
597 } else {
598 os << "Result #[" << resultStart << ", " << (resultEnd - 1) << "]\n\n"
599 << "Types: ";
600 llvm::interleaveComma(
601 c: op->getResults().slice(n: resultStart, m: resultEnd), os,
602 each_fn: [&](Value result) { os << "`" << result.getType() << "`"; });
603 }
604
605 return hover;
606}
607
608lsp::Hover
609MLIRDocument::buildHoverForBlock(SMRange hoverRange,
610 const AsmParserState::BlockDefinition &block) {
611 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
612 llvm::raw_string_ostream os(hover.contents.value);
613
614 // Print the given block to the hover output stream.
615 auto printBlockToHover = [&](Block *newBlock) {
616 if (const auto *def = asmState.getBlockDef(block: newBlock))
617 printDefBlockName(os, def: *def);
618 else
619 printDefBlockName(os, block: newBlock);
620 };
621
622 // Display the parent operation, block number, predecessors, and successors.
623 os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
624 << "Block #" << getBlockNumber(block: block.block) << "\n\n";
625 if (!block.block->hasNoPredecessors()) {
626 os << "Predecessors: ";
627 llvm::interleaveComma(c: block.block->getPredecessors(), os,
628 each_fn: printBlockToHover);
629 os << "\n\n";
630 }
631 if (!block.block->hasNoSuccessors()) {
632 os << "Successors: ";
633 llvm::interleaveComma(c: block.block->getSuccessors(), os, each_fn: printBlockToHover);
634 os << "\n\n";
635 }
636
637 return hover;
638}
639
640lsp::Hover MLIRDocument::buildHoverForBlockArgument(
641 SMRange hoverRange, BlockArgument arg,
642 const AsmParserState::BlockDefinition &block) {
643 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
644 llvm::raw_string_ostream os(hover.contents.value);
645
646 // Display the parent operation, block, the argument number, and the type.
647 os << "Operation: \"" << block.block->getParentOp()->getName() << "\"\n\n"
648 << "Block: ";
649 printDefBlockName(os, def: block);
650 os << "\n\nArgument #" << arg.getArgNumber() << "\n\n"
651 << "Type: `" << arg.getType() << "`\n\n";
652
653 return hover;
654}
655
656lsp::Hover MLIRDocument::buildHoverForAttributeAlias(
657 SMRange hoverRange, const AsmParserState::AttributeAliasDefinition &attr) {
658 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
659 llvm::raw_string_ostream os(hover.contents.value);
660
661 os << "Attribute Alias: \"" << attr.name << "\n\n";
662 os << "Value: ```mlir\n" << attr.value << "\n```\n\n";
663
664 return hover;
665}
666
667lsp::Hover MLIRDocument::buildHoverForTypeAlias(
668 SMRange hoverRange, const AsmParserState::TypeAliasDefinition &type) {
669 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
670 llvm::raw_string_ostream os(hover.contents.value);
671
672 os << "Type Alias: \"" << type.name << "\n\n";
673 os << "Value: ```mlir\n" << type.value << "\n```\n\n";
674
675 return hover;
676}
677
678//===----------------------------------------------------------------------===//
679// MLIRDocument: Document Symbols
680//===----------------------------------------------------------------------===//
681
682void MLIRDocument::findDocumentSymbols(
683 std::vector<lsp::DocumentSymbol> &symbols) {
684 for (Operation &op : parsedIR)
685 findDocumentSymbols(op: &op, symbols);
686}
687
688void MLIRDocument::findDocumentSymbols(
689 Operation *op, std::vector<lsp::DocumentSymbol> &symbols) {
690 std::vector<lsp::DocumentSymbol> *childSymbols = &symbols;
691
692 // Check for the source information of this operation.
693 if (const AsmParserState::OperationDefinition *def = asmState.getOpDef(op)) {
694 // If this operation defines a symbol, record it.
695 if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
696 symbols.emplace_back(symbol.getName(),
697 isa<FunctionOpInterface>(Val: op)
698 ? lsp::SymbolKind::Function
699 : lsp::SymbolKind::Class,
700 lsp::Range(sourceMgr, def->scopeLoc),
701 lsp::Range(sourceMgr, def->loc));
702 childSymbols = &symbols.back().children;
703
704 } else if (op->hasTrait<OpTrait::SymbolTable>()) {
705 // Otherwise, if this is a symbol table push an anonymous document symbol.
706 symbols.emplace_back(args: "<" + op->getName().getStringRef() + ">",
707 args: lsp::SymbolKind::Namespace,
708 args: lsp::Range(sourceMgr, def->scopeLoc),
709 args: lsp::Range(sourceMgr, def->loc));
710 childSymbols = &symbols.back().children;
711 }
712 }
713
714 // Recurse into the regions of this operation.
715 if (!op->getNumRegions())
716 return;
717 for (Region &region : op->getRegions())
718 for (Operation &childOp : region.getOps())
719 findDocumentSymbols(op: &childOp, symbols&: *childSymbols);
720}
721
722//===----------------------------------------------------------------------===//
723// MLIRDocument: Code Completion
724//===----------------------------------------------------------------------===//
725
726namespace {
727class LSPCodeCompleteContext : public AsmParserCodeCompleteContext {
728public:
729 LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList,
730 MLIRContext *ctx)
731 : AsmParserCodeCompleteContext(completeLoc),
732 completionList(completionList), ctx(ctx) {}
733
734 /// Signal code completion for a dialect name, with an optional prefix.
735 void completeDialectName(StringRef prefix) final {
736 for (StringRef dialect : ctx->getAvailableDialects()) {
737 lsp::CompletionItem item(prefix + dialect,
738 lsp::CompletionItemKind::Module,
739 /*sortText=*/"3");
740 item.detail = "dialect";
741 completionList.items.emplace_back(args&: item);
742 }
743 }
744 using AsmParserCodeCompleteContext::completeDialectName;
745
746 /// Signal code completion for an operation name within the given dialect.
747 void completeOperationName(StringRef dialectName) final {
748 Dialect *dialect = ctx->getOrLoadDialect(name: dialectName);
749 if (!dialect)
750 return;
751
752 for (const auto &op : ctx->getRegisteredOperations()) {
753 if (&op.getDialect() != dialect)
754 continue;
755
756 lsp::CompletionItem item(
757 op.getStringRef().drop_front(N: dialectName.size() + 1),
758 lsp::CompletionItemKind::Field,
759 /*sortText=*/"1");
760 item.detail = "operation";
761 completionList.items.emplace_back(args&: item);
762 }
763 }
764
765 /// Append the given SSA value as a code completion result for SSA value
766 /// completions.
767 void appendSSAValueCompletion(StringRef name, std::string typeData) final {
768 // Check if we need to insert the `%` or not.
769 bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '%';
770
771 lsp::CompletionItem item(name, lsp::CompletionItemKind::Variable);
772 if (stripPrefix)
773 item.insertText = name.drop_front(N: 1).str();
774 item.detail = std::move(typeData);
775 completionList.items.emplace_back(args&: item);
776 }
777
778 /// Append the given block as a code completion result for block name
779 /// completions.
780 void appendBlockCompletion(StringRef name) final {
781 // Check if we need to insert the `^` or not.
782 bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '^';
783
784 lsp::CompletionItem item(name, lsp::CompletionItemKind::Field);
785 if (stripPrefix)
786 item.insertText = name.drop_front(N: 1).str();
787 completionList.items.emplace_back(args&: item);
788 }
789
790 /// Signal a completion for the given expected token.
791 void completeExpectedTokens(ArrayRef<StringRef> tokens, bool optional) final {
792 for (StringRef token : tokens) {
793 lsp::CompletionItem item(token, lsp::CompletionItemKind::Keyword,
794 /*sortText=*/"0");
795 item.detail = optional ? "optional" : "";
796 completionList.items.emplace_back(args&: item);
797 }
798 }
799
800 /// Signal a completion for an attribute.
801 void completeAttribute(const llvm::StringMap<Attribute> &aliases) override {
802 appendSimpleCompletions(completions: {"affine_set", "affine_map", "dense",
803 "dense_resource", "false", "loc", "sparse", "true",
804 "unit"},
805 kind: lsp::CompletionItemKind::Field,
806 /*sortText=*/"1");
807
808 completeDialectName(prefix: "#");
809 completeAliases(aliases, prefix: "#");
810 }
811 void completeDialectAttributeOrAlias(
812 const llvm::StringMap<Attribute> &aliases) override {
813 completeDialectName();
814 completeAliases(aliases);
815 }
816
817 /// Signal a completion for a type.
818 void completeType(const llvm::StringMap<Type> &aliases) override {
819 // Handle the various builtin types.
820 appendSimpleCompletions(completions: {"memref", "tensor", "complex", "tuple", "vector",
821 "bf16", "f16", "f32", "f64", "f80", "f128",
822 "index", "none"},
823 kind: lsp::CompletionItemKind::Field,
824 /*sortText=*/"1");
825
826 // Handle the builtin integer types.
827 for (StringRef type : {"i", "si", "ui"}) {
828 lsp::CompletionItem item(type + "<N>", lsp::CompletionItemKind::Field,
829 /*sortText=*/"1");
830 item.insertText = type.str();
831 completionList.items.emplace_back(args&: item);
832 }
833
834 // Insert completions for dialect types and aliases.
835 completeDialectName(prefix: "!");
836 completeAliases(aliases, prefix: "!");
837 }
838 void
839 completeDialectTypeOrAlias(const llvm::StringMap<Type> &aliases) override {
840 completeDialectName();
841 completeAliases(aliases);
842 }
843
844 /// Add completion results for the given set of aliases.
845 template <typename T>
846 void completeAliases(const llvm::StringMap<T> &aliases,
847 StringRef prefix = "") {
848 for (const auto &alias : aliases) {
849 lsp::CompletionItem item(prefix + alias.getKey(),
850 lsp::CompletionItemKind::Field,
851 /*sortText=*/"2");
852 llvm::raw_string_ostream(item.detail) << "alias: " << alias.getValue();
853 completionList.items.emplace_back(args&: item);
854 }
855 }
856
857 /// Add a set of simple completions that all have the same kind.
858 void appendSimpleCompletions(ArrayRef<StringRef> completions,
859 lsp::CompletionItemKind kind,
860 StringRef sortText = "") {
861 for (StringRef completion : completions)
862 completionList.items.emplace_back(args&: completion, args&: kind, args&: sortText);
863 }
864
865private:
866 lsp::CompletionList &completionList;
867 MLIRContext *ctx;
868};
869} // namespace
870
871lsp::CompletionList
872MLIRDocument::getCodeCompletion(const lsp::URIForFile &uri,
873 const lsp::Position &completePos,
874 const DialectRegistry &registry) {
875 SMLoc posLoc = completePos.getAsSMLoc(mgr&: sourceMgr);
876 if (!posLoc.isValid())
877 return lsp::CompletionList();
878
879 // To perform code completion, we run another parse of the module with the
880 // code completion context provided.
881 MLIRContext tmpContext(registry, MLIRContext::Threading::DISABLED);
882 tmpContext.allowUnregisteredDialects();
883 lsp::CompletionList completionList;
884 LSPCodeCompleteContext lspCompleteContext(posLoc, completionList,
885 &tmpContext);
886
887 Block tmpIR;
888 AsmParserState tmpState;
889 (void)parseAsmSourceFile(sourceMgr, block: &tmpIR, config: &tmpContext, asmState: &tmpState,
890 codeCompleteContext: &lspCompleteContext);
891 return completionList;
892}
893
894//===----------------------------------------------------------------------===//
895// MLIRDocument: Code Action
896//===----------------------------------------------------------------------===//
897
898void MLIRDocument::getCodeActionForDiagnostic(
899 const lsp::URIForFile &uri, lsp::Position &pos, StringRef severity,
900 StringRef message, std::vector<lsp::TextEdit> &edits) {
901 // Ignore diagnostics that print the current operation. These are always
902 // enabled for the language server, but not generally during normal
903 // parsing/verification.
904 if (message.starts_with(Prefix: "see current operation: "))
905 return;
906
907 // Get the start of the line containing the diagnostic.
908 const auto &buffer = sourceMgr.getBufferInfo(i: sourceMgr.getMainFileID());
909 const char *lineStart = buffer.getPointerForLineNumber(LineNo: pos.line + 1);
910 if (!lineStart)
911 return;
912 StringRef line(lineStart, pos.character);
913
914 // Add a text edit for adding an expected-* diagnostic check for this
915 // diagnostic.
916 lsp::TextEdit edit;
917 edit.range = lsp::Range(lsp::Position(pos.line, 0));
918
919 // Use the indent of the current line for the expected-* diagnostic.
920 size_t indent = line.find_first_not_of(Chars: " ");
921 if (indent == StringRef::npos)
922 indent = line.size();
923
924 edit.newText.append(n: indent, c: ' ');
925 llvm::raw_string_ostream(edit.newText)
926 << "// expected-" << severity << " @below {{" << message << "}}\n";
927 edits.emplace_back(args: std::move(edit));
928}
929
930//===----------------------------------------------------------------------===//
931// MLIRDocument: Bytecode
932//===----------------------------------------------------------------------===//
933
934llvm::Expected<lsp::MLIRConvertBytecodeResult>
935MLIRDocument::convertToBytecode() {
936 // TODO: We currently require a single top-level operation, but this could
937 // conceptually be relaxed.
938 if (!llvm::hasSingleElement(C&: parsedIR)) {
939 if (parsedIR.empty()) {
940 return llvm::make_error<lsp::LSPError>(
941 Args: "expected a single and valid top-level operation, please ensure "
942 "there are no errors",
943 Args: lsp::ErrorCode::RequestFailed);
944 }
945 return llvm::make_error<lsp::LSPError>(
946 Args: "expected a single top-level operation", Args: lsp::ErrorCode::RequestFailed);
947 }
948
949 lsp::MLIRConvertBytecodeResult result;
950 {
951 BytecodeWriterConfig writerConfig(fallbackResourceMap);
952
953 std::string rawBytecodeBuffer;
954 llvm::raw_string_ostream os(rawBytecodeBuffer);
955 // No desired bytecode version set, so no need to check for error.
956 (void)writeBytecodeToFile(op: &parsedIR.front(), os, config: writerConfig);
957 result.output = llvm::encodeBase64(Bytes: rawBytecodeBuffer);
958 }
959 return result;
960}
961
962//===----------------------------------------------------------------------===//
963// MLIRTextFileChunk
964//===----------------------------------------------------------------------===//
965
966namespace {
967/// This class represents a single chunk of an MLIR text file.
968struct MLIRTextFileChunk {
969 MLIRTextFileChunk(MLIRContext &context, uint64_t lineOffset,
970 const lsp::URIForFile &uri, StringRef contents,
971 std::vector<lsp::Diagnostic> &diagnostics)
972 : lineOffset(lineOffset), document(context, uri, contents, diagnostics) {}
973
974 /// Adjust the line number of the given range to anchor at the beginning of
975 /// the file, instead of the beginning of this chunk.
976 void adjustLocForChunkOffset(lsp::Range &range) {
977 adjustLocForChunkOffset(pos&: range.start);
978 adjustLocForChunkOffset(pos&: range.end);
979 }
980 /// Adjust the line number of the given position to anchor at the beginning of
981 /// the file, instead of the beginning of this chunk.
982 void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
983
984 /// The line offset of this chunk from the beginning of the file.
985 uint64_t lineOffset;
986 /// The document referred to by this chunk.
987 MLIRDocument document;
988};
989} // namespace
990
991//===----------------------------------------------------------------------===//
992// MLIRTextFile
993//===----------------------------------------------------------------------===//
994
995namespace {
996/// This class represents a text file containing one or more MLIR documents.
997class MLIRTextFile {
998public:
999 MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1000 int64_t version, DialectRegistry &registry,
1001 std::vector<lsp::Diagnostic> &diagnostics);
1002
1003 /// Return the current version of this text file.
1004 int64_t getVersion() const { return version; }
1005
1006 //===--------------------------------------------------------------------===//
1007 // LSP Queries
1008 //===--------------------------------------------------------------------===//
1009
1010 void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
1011 std::vector<lsp::Location> &locations);
1012 void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
1013 std::vector<lsp::Location> &references);
1014 std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
1015 lsp::Position hoverPos);
1016 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
1017 lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
1018 lsp::Position completePos);
1019 void getCodeActions(const lsp::URIForFile &uri, const lsp::Range &pos,
1020 const lsp::CodeActionContext &context,
1021 std::vector<lsp::CodeAction> &actions);
1022 llvm::Expected<lsp::MLIRConvertBytecodeResult> convertToBytecode();
1023
1024private:
1025 /// Find the MLIR document that contains the given position, and update the
1026 /// position to be anchored at the start of the found chunk instead of the
1027 /// beginning of the file.
1028 MLIRTextFileChunk &getChunkFor(lsp::Position &pos);
1029
1030 /// The context used to hold the state contained by the parsed document.
1031 MLIRContext context;
1032
1033 /// The full string contents of the file.
1034 std::string contents;
1035
1036 /// The version of this file.
1037 int64_t version;
1038
1039 /// The number of lines in the file.
1040 int64_t totalNumLines = 0;
1041
1042 /// The chunks of this file. The order of these chunks is the order in which
1043 /// they appear in the text file.
1044 std::vector<std::unique_ptr<MLIRTextFileChunk>> chunks;
1045};
1046} // namespace
1047
1048MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1049 int64_t version, DialectRegistry &registry,
1050 std::vector<lsp::Diagnostic> &diagnostics)
1051 : context(registry, MLIRContext::Threading::DISABLED),
1052 contents(fileContents.str()), version(version) {
1053 context.allowUnregisteredDialects();
1054
1055 // Split the file into separate MLIR documents.
1056 SmallVector<StringRef, 8> subContents;
1057 StringRef(contents).split(A&: subContents, Separator: kDefaultSplitMarker);
1058 chunks.emplace_back(args: std::make_unique<MLIRTextFileChunk>(
1059 args&: context, /*lineOffset=*/args: 0, args: uri, args&: subContents.front(), args&: diagnostics));
1060
1061 uint64_t lineOffset = subContents.front().count(C: '\n');
1062 for (StringRef docContents : llvm::drop_begin(RangeOrContainer&: subContents)) {
1063 unsigned currentNumDiags = diagnostics.size();
1064 auto chunk = std::make_unique<MLIRTextFileChunk>(args&: context, args&: lineOffset, args: uri,
1065 args&: docContents, args&: diagnostics);
1066 lineOffset += docContents.count(C: '\n');
1067
1068 // Adjust locations used in diagnostics to account for the offset from the
1069 // beginning of the file.
1070 for (lsp::Diagnostic &diag :
1071 llvm::drop_begin(RangeOrContainer&: diagnostics, N: currentNumDiags)) {
1072 chunk->adjustLocForChunkOffset(range&: diag.range);
1073
1074 if (!diag.relatedInformation)
1075 continue;
1076 for (auto &it : *diag.relatedInformation)
1077 if (it.location.uri == uri)
1078 chunk->adjustLocForChunkOffset(range&: it.location.range);
1079 }
1080 chunks.emplace_back(args: std::move(chunk));
1081 }
1082 totalNumLines = lineOffset;
1083}
1084
1085void MLIRTextFile::getLocationsOf(const lsp::URIForFile &uri,
1086 lsp::Position defPos,
1087 std::vector<lsp::Location> &locations) {
1088 MLIRTextFileChunk &chunk = getChunkFor(pos&: defPos);
1089 chunk.document.getLocationsOf(uri, defPos, locations);
1090
1091 // Adjust any locations within this file for the offset of this chunk.
1092 if (chunk.lineOffset == 0)
1093 return;
1094 for (lsp::Location &loc : locations)
1095 if (loc.uri == uri)
1096 chunk.adjustLocForChunkOffset(range&: loc.range);
1097}
1098
1099void MLIRTextFile::findReferencesOf(const lsp::URIForFile &uri,
1100 lsp::Position pos,
1101 std::vector<lsp::Location> &references) {
1102 MLIRTextFileChunk &chunk = getChunkFor(pos);
1103 chunk.document.findReferencesOf(uri, pos, references);
1104
1105 // Adjust any locations within this file for the offset of this chunk.
1106 if (chunk.lineOffset == 0)
1107 return;
1108 for (lsp::Location &loc : references)
1109 if (loc.uri == uri)
1110 chunk.adjustLocForChunkOffset(range&: loc.range);
1111}
1112
1113std::optional<lsp::Hover> MLIRTextFile::findHover(const lsp::URIForFile &uri,
1114 lsp::Position hoverPos) {
1115 MLIRTextFileChunk &chunk = getChunkFor(pos&: hoverPos);
1116 std::optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
1117
1118 // Adjust any locations within this file for the offset of this chunk.
1119 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1120 chunk.adjustLocForChunkOffset(range&: *hoverInfo->range);
1121 return hoverInfo;
1122}
1123
1124void MLIRTextFile::findDocumentSymbols(
1125 std::vector<lsp::DocumentSymbol> &symbols) {
1126 if (chunks.size() == 1)
1127 return chunks.front()->document.findDocumentSymbols(symbols);
1128
1129 // If there are multiple chunks in this file, we create top-level symbols for
1130 // each chunk.
1131 for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
1132 MLIRTextFileChunk &chunk = *chunks[i];
1133 lsp::Position startPos(chunk.lineOffset);
1134 lsp::Position endPos((i == e - 1) ? totalNumLines - 1
1135 : chunks[i + 1]->lineOffset);
1136 lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
1137 lsp::SymbolKind::Namespace,
1138 /*range=*/lsp::Range(startPos, endPos),
1139 /*selectionRange=*/lsp::Range(startPos));
1140 chunk.document.findDocumentSymbols(symbols&: symbol.children);
1141
1142 // Fixup the locations of document symbols within this chunk.
1143 if (i != 0) {
1144 SmallVector<lsp::DocumentSymbol *> symbolsToFix;
1145 for (lsp::DocumentSymbol &childSymbol : symbol.children)
1146 symbolsToFix.push_back(Elt: &childSymbol);
1147
1148 while (!symbolsToFix.empty()) {
1149 lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
1150 chunk.adjustLocForChunkOffset(range&: symbol->range);
1151 chunk.adjustLocForChunkOffset(range&: symbol->selectionRange);
1152
1153 for (lsp::DocumentSymbol &childSymbol : symbol->children)
1154 symbolsToFix.push_back(Elt: &childSymbol);
1155 }
1156 }
1157
1158 // Push the symbol for this chunk.
1159 symbols.emplace_back(args: std::move(symbol));
1160 }
1161}
1162
1163lsp::CompletionList MLIRTextFile::getCodeCompletion(const lsp::URIForFile &uri,
1164 lsp::Position completePos) {
1165 MLIRTextFileChunk &chunk = getChunkFor(pos&: completePos);
1166 lsp::CompletionList completionList = chunk.document.getCodeCompletion(
1167 uri, completePos, registry: context.getDialectRegistry());
1168
1169 // Adjust any completion locations.
1170 for (lsp::CompletionItem &item : completionList.items) {
1171 if (item.textEdit)
1172 chunk.adjustLocForChunkOffset(range&: item.textEdit->range);
1173 for (lsp::TextEdit &edit : item.additionalTextEdits)
1174 chunk.adjustLocForChunkOffset(range&: edit.range);
1175 }
1176 return completionList;
1177}
1178
1179void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri,
1180 const lsp::Range &pos,
1181 const lsp::CodeActionContext &context,
1182 std::vector<lsp::CodeAction> &actions) {
1183 // Create actions for any diagnostics in this file.
1184 for (auto &diag : context.diagnostics) {
1185 if (diag.source != "mlir")
1186 continue;
1187 lsp::Position diagPos = diag.range.start;
1188 MLIRTextFileChunk &chunk = getChunkFor(pos&: diagPos);
1189
1190 // Add a new code action that inserts a "expected" diagnostic check.
1191 lsp::CodeAction action;
1192 action.title = "Add expected-* diagnostic checks";
1193 action.kind = lsp::CodeAction::kQuickFix.str();
1194
1195 StringRef severity;
1196 switch (diag.severity) {
1197 case lsp::DiagnosticSeverity::Error:
1198 severity = "error";
1199 break;
1200 case lsp::DiagnosticSeverity::Warning:
1201 severity = "warning";
1202 break;
1203 default:
1204 continue;
1205 }
1206
1207 // Get edits for the diagnostic.
1208 std::vector<lsp::TextEdit> edits;
1209 chunk.document.getCodeActionForDiagnostic(uri, pos&: diagPos, severity,
1210 message: diag.message, edits);
1211
1212 // Walk the related diagnostics, this is how we encode notes.
1213 if (diag.relatedInformation) {
1214 for (auto &noteDiag : *diag.relatedInformation) {
1215 if (noteDiag.location.uri != uri)
1216 continue;
1217 diagPos = noteDiag.location.range.start;
1218 diagPos.line -= chunk.lineOffset;
1219 chunk.document.getCodeActionForDiagnostic(uri, pos&: diagPos, severity: "note",
1220 message: noteDiag.message, edits);
1221 }
1222 }
1223 // Fixup the locations for any edits.
1224 for (lsp::TextEdit &edit : edits)
1225 chunk.adjustLocForChunkOffset(range&: edit.range);
1226
1227 action.edit.emplace();
1228 action.edit->changes[uri.uri().str()] = std::move(edits);
1229 action.diagnostics = {diag};
1230
1231 actions.emplace_back(args: std::move(action));
1232 }
1233}
1234
1235llvm::Expected<lsp::MLIRConvertBytecodeResult>
1236MLIRTextFile::convertToBytecode() {
1237 // Bail out if there is more than one chunk, bytecode wants a single module.
1238 if (chunks.size() != 1) {
1239 return llvm::make_error<lsp::LSPError>(
1240 Args: "unexpected split file, please remove all `// -----`",
1241 Args: lsp::ErrorCode::RequestFailed);
1242 }
1243 return chunks.front()->document.convertToBytecode();
1244}
1245
1246MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
1247 if (chunks.size() == 1)
1248 return *chunks.front();
1249
1250 // Search for the first chunk with a greater line offset, the previous chunk
1251 // is the one that contains `pos`.
1252 auto it = llvm::upper_bound(
1253 Range&: chunks, Value&: pos, C: [](const lsp::Position &pos, const auto &chunk) {
1254 return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
1255 });
1256 MLIRTextFileChunk &chunk = it == chunks.end() ? *chunks.back() : **(--it);
1257 pos.line -= chunk.lineOffset;
1258 return chunk;
1259}
1260
1261//===----------------------------------------------------------------------===//
1262// MLIRServer::Impl
1263//===----------------------------------------------------------------------===//
1264
1265struct lsp::MLIRServer::Impl {
1266 Impl(DialectRegistry &registry) : registry(registry) {}
1267
1268 /// The registry containing dialects that can be recognized in parsed .mlir
1269 /// files.
1270 DialectRegistry &registry;
1271
1272 /// The files held by the server, mapped by their URI file name.
1273 llvm::StringMap<std::unique_ptr<MLIRTextFile>> files;
1274};
1275
1276//===----------------------------------------------------------------------===//
1277// MLIRServer
1278//===----------------------------------------------------------------------===//
1279
1280lsp::MLIRServer::MLIRServer(DialectRegistry &registry)
1281 : impl(std::make_unique<Impl>(args&: registry)) {}
1282lsp::MLIRServer::~MLIRServer() = default;
1283
1284void lsp::MLIRServer::addOrUpdateDocument(
1285 const URIForFile &uri, StringRef contents, int64_t version,
1286 std::vector<Diagnostic> &diagnostics) {
1287 impl->files[uri.file()] = std::make_unique<MLIRTextFile>(
1288 args: uri, args&: contents, args&: version, args&: impl->registry, args&: diagnostics);
1289}
1290
1291std::optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) {
1292 auto it = impl->files.find(Key: uri.file());
1293 if (it == impl->files.end())
1294 return std::nullopt;
1295
1296 int64_t version = it->second->getVersion();
1297 impl->files.erase(I: it);
1298 return version;
1299}
1300
1301void lsp::MLIRServer::getLocationsOf(const URIForFile &uri,
1302 const Position &defPos,
1303 std::vector<Location> &locations) {
1304 auto fileIt = impl->files.find(Key: uri.file());
1305 if (fileIt != impl->files.end())
1306 fileIt->second->getLocationsOf(uri, defPos, locations);
1307}
1308
1309void lsp::MLIRServer::findReferencesOf(const URIForFile &uri,
1310 const Position &pos,
1311 std::vector<Location> &references) {
1312 auto fileIt = impl->files.find(Key: uri.file());
1313 if (fileIt != impl->files.end())
1314 fileIt->second->findReferencesOf(uri, pos, references);
1315}
1316
1317std::optional<lsp::Hover> lsp::MLIRServer::findHover(const URIForFile &uri,
1318 const Position &hoverPos) {
1319 auto fileIt = impl->files.find(Key: uri.file());
1320 if (fileIt != impl->files.end())
1321 return fileIt->second->findHover(uri, hoverPos);
1322 return std::nullopt;
1323}
1324
1325void lsp::MLIRServer::findDocumentSymbols(
1326 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1327 auto fileIt = impl->files.find(Key: uri.file());
1328 if (fileIt != impl->files.end())
1329 fileIt->second->findDocumentSymbols(symbols);
1330}
1331
1332lsp::CompletionList
1333lsp::MLIRServer::getCodeCompletion(const URIForFile &uri,
1334 const Position &completePos) {
1335 auto fileIt = impl->files.find(Key: uri.file());
1336 if (fileIt != impl->files.end())
1337 return fileIt->second->getCodeCompletion(uri, completePos);
1338 return CompletionList();
1339}
1340
1341void lsp::MLIRServer::getCodeActions(const URIForFile &uri, const Range &pos,
1342 const CodeActionContext &context,
1343 std::vector<CodeAction> &actions) {
1344 auto fileIt = impl->files.find(Key: uri.file());
1345 if (fileIt != impl->files.end())
1346 fileIt->second->getCodeActions(uri, pos, context, actions);
1347}
1348
1349llvm::Expected<lsp::MLIRConvertBytecodeResult>
1350lsp::MLIRServer::convertFromBytecode(const URIForFile &uri) {
1351 MLIRContext tempContext(impl->registry);
1352 tempContext.allowUnregisteredDialects();
1353
1354 // Collect any errors during parsing.
1355 std::string errorMsg;
1356 ScopedDiagnosticHandler diagHandler(
1357 &tempContext,
1358 [&](mlir::Diagnostic &diag) { errorMsg += diag.str() + "\n"; });
1359
1360 // Handling for external resources, which we want to propagate up to the user.
1361 FallbackAsmResourceMap fallbackResourceMap;
1362
1363 // Setup the parser config.
1364 ParserConfig parserConfig(&tempContext, /*verifyAfterParse=*/true,
1365 &fallbackResourceMap);
1366
1367 // Try to parse the given source file.
1368 Block parsedBlock;
1369 if (failed(result: parseSourceFile(filename: uri.file(), block: &parsedBlock, config: parserConfig))) {
1370 return llvm::make_error<lsp::LSPError>(
1371 Args: "failed to parse bytecode source file: " + errorMsg,
1372 Args: lsp::ErrorCode::RequestFailed);
1373 }
1374
1375 // TODO: We currently expect a single top-level operation, but this could
1376 // conceptually be relaxed.
1377 if (!llvm::hasSingleElement(C&: parsedBlock)) {
1378 return llvm::make_error<lsp::LSPError>(
1379 Args: "expected bytecode to contain a single top-level operation",
1380 Args: lsp::ErrorCode::RequestFailed);
1381 }
1382
1383 // Print the module to a buffer.
1384 lsp::MLIRConvertBytecodeResult result;
1385 {
1386 // Extract the top-level op so that aliases get printed.
1387 // FIXME: We should be able to enable aliases without having to do this!
1388 OwningOpRef<Operation *> topOp = &parsedBlock.front();
1389 topOp->remove();
1390
1391 AsmState state(*topOp, OpPrintingFlags().enableDebugInfo().assumeVerified(),
1392 /*locationMap=*/nullptr, &fallbackResourceMap);
1393
1394 llvm::raw_string_ostream os(result.output);
1395 topOp->print(os, state);
1396 }
1397 return std::move(result);
1398}
1399
1400llvm::Expected<lsp::MLIRConvertBytecodeResult>
1401lsp::MLIRServer::convertToBytecode(const URIForFile &uri) {
1402 auto fileIt = impl->files.find(Key: uri.file());
1403 if (fileIt == impl->files.end()) {
1404 return llvm::make_error<lsp::LSPError>(
1405 Args: "language server does not contain an entry for this source file",
1406 Args: lsp::ErrorCode::RequestFailed);
1407 }
1408 return fileIt->second->convertToBytecode();
1409}
1410

source code of mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp