1//===- PDLLServer.cpp - PDLL Language Server ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "PDLLServer.h"
10
11#include "Protocol.h"
12#include "mlir/IR/BuiltinOps.h"
13#include "mlir/Support/ToolUtilities.h"
14#include "mlir/Tools/PDLL/AST/Context.h"
15#include "mlir/Tools/PDLL/AST/Nodes.h"
16#include "mlir/Tools/PDLL/AST/Types.h"
17#include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
18#include "mlir/Tools/PDLL/CodeGen/MLIRGen.h"
19#include "mlir/Tools/PDLL/ODS/Constraint.h"
20#include "mlir/Tools/PDLL/ODS/Context.h"
21#include "mlir/Tools/PDLL/ODS/Dialect.h"
22#include "mlir/Tools/PDLL/ODS/Operation.h"
23#include "mlir/Tools/PDLL/Parser/CodeComplete.h"
24#include "mlir/Tools/PDLL/Parser/Parser.h"
25#include "mlir/Tools/lsp-server-support/CompilationDatabase.h"
26#include "mlir/Tools/lsp-server-support/Logging.h"
27#include "mlir/Tools/lsp-server-support/SourceMgrUtils.h"
28#include "llvm/ADT/IntervalMap.h"
29#include "llvm/ADT/StringMap.h"
30#include "llvm/ADT/StringSet.h"
31#include "llvm/ADT/TypeSwitch.h"
32#include "llvm/Support/FileSystem.h"
33#include "llvm/Support/Path.h"
34#include <optional>
35
36using namespace mlir;
37using namespace mlir::pdll;
38
39/// Returns a language server uri for the given source location. `mainFileURI`
40/// corresponds to the uri for the main file of the source manager.
41static lsp::URIForFile getURIFromLoc(llvm::SourceMgr &mgr, SMRange loc,
42 const lsp::URIForFile &mainFileURI) {
43 int bufferId = mgr.FindBufferContainingLoc(Loc: loc.Start);
44 if (bufferId == 0 || bufferId == static_cast<int>(mgr.getMainFileID()))
45 return mainFileURI;
46 llvm::Expected<lsp::URIForFile> fileForLoc = lsp::URIForFile::fromFile(
47 absoluteFilepath: mgr.getBufferInfo(i: bufferId).Buffer->getBufferIdentifier());
48 if (fileForLoc)
49 return *fileForLoc;
50 lsp::Logger::error(fmt: "Failed to create URI for include file: {0}",
51 vals: llvm::toString(E: fileForLoc.takeError()));
52 return mainFileURI;
53}
54
55/// Returns true if the given location is in the main file of the source
56/// manager.
57static bool isMainFileLoc(llvm::SourceMgr &mgr, SMRange loc) {
58 return mgr.FindBufferContainingLoc(Loc: loc.Start) == mgr.getMainFileID();
59}
60
61/// Returns a language server location from the given source range.
62static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMRange range,
63 const lsp::URIForFile &uri) {
64 return lsp::Location(getURIFromLoc(mgr, loc: range, mainFileURI: uri), lsp::Range(mgr, range));
65}
66
67/// Convert the given MLIR diagnostic to the LSP form.
68static std::optional<lsp::Diagnostic>
69getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag,
70 const lsp::URIForFile &uri) {
71 lsp::Diagnostic lspDiag;
72 lspDiag.source = "pdll";
73
74 // FIXME: Right now all of the diagnostics are treated as parser issues, but
75 // some are parser and some are verifier.
76 lspDiag.category = "Parse Error";
77
78 // Try to grab a file location for this diagnostic.
79 lsp::Location loc = getLocationFromLoc(mgr&: sourceMgr, range: diag.getLocation(), uri);
80 lspDiag.range = loc.range;
81
82 // Skip diagnostics that weren't emitted within the main file.
83 if (loc.uri != uri)
84 return std::nullopt;
85
86 // Convert the severity for the diagnostic.
87 switch (diag.getSeverity()) {
88 case ast::Diagnostic::Severity::DK_Note:
89 llvm_unreachable("expected notes to be handled separately");
90 case ast::Diagnostic::Severity::DK_Warning:
91 lspDiag.severity = lsp::DiagnosticSeverity::Warning;
92 break;
93 case ast::Diagnostic::Severity::DK_Error:
94 lspDiag.severity = lsp::DiagnosticSeverity::Error;
95 break;
96 case ast::Diagnostic::Severity::DK_Remark:
97 lspDiag.severity = lsp::DiagnosticSeverity::Information;
98 break;
99 }
100 lspDiag.message = diag.getMessage().str();
101
102 // Attach any notes to the main diagnostic as related information.
103 std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
104 for (const ast::Diagnostic &note : diag.getNotes()) {
105 relatedDiags.emplace_back(
106 args: getLocationFromLoc(mgr&: sourceMgr, range: note.getLocation(), uri),
107 args: note.getMessage().str());
108 }
109 if (!relatedDiags.empty())
110 lspDiag.relatedInformation = std::move(relatedDiags);
111
112 return lspDiag;
113}
114
115/// Get or extract the documentation for the given decl.
116static std::optional<std::string>
117getDocumentationFor(llvm::SourceMgr &sourceMgr, const ast::Decl *decl) {
118 // If the decl already had documentation set, use it.
119 if (std::optional<StringRef> doc = decl->getDocComment())
120 return doc->str();
121
122 // If the decl doesn't yet have documentation, try to extract it from the
123 // source file.
124 return lsp::extractSourceDocComment(sourceMgr, loc: decl->getLoc().Start);
125}
126
127//===----------------------------------------------------------------------===//
128// PDLIndex
129//===----------------------------------------------------------------------===//
130
131namespace {
132struct PDLIndexSymbol {
133 explicit PDLIndexSymbol(const ast::Decl *definition)
134 : definition(definition) {}
135 explicit PDLIndexSymbol(const ods::Operation *definition)
136 : definition(definition) {}
137
138 /// Return the location of the definition of this symbol.
139 SMRange getDefLoc() const {
140 if (const ast::Decl *decl = llvm::dyn_cast_if_present<const ast::Decl *>(Val: definition)) {
141 const ast::Name *declName = decl->getName();
142 return declName ? declName->getLoc() : decl->getLoc();
143 }
144 return definition.get<const ods::Operation *>()->getLoc();
145 }
146
147 /// The main definition of the symbol.
148 PointerUnion<const ast::Decl *, const ods::Operation *> definition;
149 /// The set of references to the symbol.
150 std::vector<SMRange> references;
151};
152
153/// This class provides an index for definitions/uses within a PDL document.
154/// It provides efficient lookup of a definition given an input source range.
155class PDLIndex {
156public:
157 PDLIndex() : intervalMap(allocator) {}
158
159 /// Initialize the index with the given ast::Module.
160 void initialize(const ast::Module &module, const ods::Context &odsContext);
161
162 /// Lookup a symbol for the given location. Returns nullptr if no symbol could
163 /// be found. If provided, `overlappedRange` is set to the range that the
164 /// provided `loc` overlapped with.
165 const PDLIndexSymbol *lookup(SMLoc loc,
166 SMRange *overlappedRange = nullptr) const;
167
168private:
169 /// The type of interval map used to store source references. SMRange is
170 /// half-open, so we also need to use a half-open interval map.
171 using MapT =
172 llvm::IntervalMap<const char *, const PDLIndexSymbol *,
173 llvm::IntervalMapImpl::NodeSizer<
174 const char *, const PDLIndexSymbol *>::LeafSize,
175 llvm::IntervalMapHalfOpenInfo<const char *>>;
176
177 /// An allocator for the interval map.
178 MapT::Allocator allocator;
179
180 /// An interval map containing a corresponding definition mapped to a source
181 /// interval.
182 MapT intervalMap;
183
184 /// A mapping between definitions and their corresponding symbol.
185 DenseMap<const void *, std::unique_ptr<PDLIndexSymbol>> defToSymbol;
186};
187} // namespace
188
189void PDLIndex::initialize(const ast::Module &module,
190 const ods::Context &odsContext) {
191 auto getOrInsertDef = [&](const auto *def) -> PDLIndexSymbol * {
192 auto it = defToSymbol.try_emplace(def, nullptr);
193 if (it.second)
194 it.first->second = std::make_unique<PDLIndexSymbol>(def);
195 return &*it.first->second;
196 };
197 auto insertDeclRef = [&](PDLIndexSymbol *sym, SMRange refLoc,
198 bool isDef = false) {
199 const char *startLoc = refLoc.Start.getPointer();
200 const char *endLoc = refLoc.End.getPointer();
201 if (!intervalMap.overlaps(a: startLoc, b: endLoc)) {
202 intervalMap.insert(a: startLoc, b: endLoc, y: sym);
203 if (!isDef)
204 sym->references.push_back(x: refLoc);
205 }
206 };
207 auto insertODSOpRef = [&](StringRef opName, SMRange refLoc) {
208 const ods::Operation *odsOp = odsContext.lookupOperation(name: opName);
209 if (!odsOp)
210 return;
211
212 PDLIndexSymbol *symbol = getOrInsertDef(odsOp);
213 insertDeclRef(symbol, odsOp->getLoc(), /*isDef=*/true);
214 insertDeclRef(symbol, refLoc);
215 };
216
217 module.walk(walkFn: [&](const ast::Node *node) {
218 // Handle references to PDL decls.
219 if (const auto *decl = dyn_cast<ast::OpNameDecl>(Val: node)) {
220 if (std::optional<StringRef> name = decl->getName())
221 insertODSOpRef(*name, decl->getLoc());
222 } else if (const ast::Decl *decl = dyn_cast<ast::Decl>(Val: node)) {
223 const ast::Name *name = decl->getName();
224 if (!name)
225 return;
226 PDLIndexSymbol *declSym = getOrInsertDef(decl);
227 insertDeclRef(declSym, name->getLoc(), /*isDef=*/true);
228
229 if (const auto *varDecl = dyn_cast<ast::VariableDecl>(Val: decl)) {
230 // Record references to any constraints.
231 for (const auto &it : varDecl->getConstraints())
232 insertDeclRef(getOrInsertDef(it.constraint), it.referenceLoc);
233 }
234 } else if (const auto *expr = dyn_cast<ast::DeclRefExpr>(Val: node)) {
235 insertDeclRef(getOrInsertDef(expr->getDecl()), expr->getLoc());
236 }
237 });
238}
239
240const PDLIndexSymbol *PDLIndex::lookup(SMLoc loc,
241 SMRange *overlappedRange) const {
242 auto it = intervalMap.find(x: loc.getPointer());
243 if (!it.valid() || loc.getPointer() < it.start())
244 return nullptr;
245
246 if (overlappedRange) {
247 *overlappedRange = SMRange(SMLoc::getFromPointer(Ptr: it.start()),
248 SMLoc::getFromPointer(Ptr: it.stop()));
249 }
250 return it.value();
251}
252
253//===----------------------------------------------------------------------===//
254// PDLDocument
255//===----------------------------------------------------------------------===//
256
namespace {
/// This class represents all of the information pertaining to a specific PDL
/// document: its source manager, ODS/AST contexts, parsed module, symbol
/// index, and the includes it references.
struct PDLDocument {
  /// Parses `contents` as the document identified by `uri`. Includes are
  /// resolved against the directory of `uri` followed by `extraDirs`.
  /// Main-file diagnostics produced while parsing are appended to
  /// `diagnostics`.
  PDLDocument(const lsp::URIForFile &uri, StringRef contents,
              const std::vector<std::string> &extraDirs,
              std::vector<lsp::Diagnostic> &diagnostics);
  PDLDocument(const PDLDocument &) = delete;
  PDLDocument &operator=(const PDLDocument &) = delete;

  //===--------------------------------------------------------------------===//
  // Definitions and References
  //===--------------------------------------------------------------------===//

  /// Appends to `locations` the definition of the symbol at `defPos`, if any.
  void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
                      std::vector<lsp::Location> &locations);
  /// Appends to `references` the definition of the symbol at `pos` followed by
  /// each of its recorded references, if a symbol is found.
  void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
                        std::vector<lsp::Location> &references);

  //===--------------------------------------------------------------------===//
  // Document Links
  //===--------------------------------------------------------------------===//

  /// Appends a document link for each include gathered from this document.
  void getDocumentLinks(const lsp::URIForFile &uri,
                        std::vector<lsp::DocumentLink> &links);

  //===--------------------------------------------------------------------===//
  // Hover
  //===--------------------------------------------------------------------===//

  /// Returns hover information for an include or indexed symbol at
  /// `hoverPos`, or std::nullopt if there is nothing to show.
  std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
                                      const lsp::Position &hoverPos);
  /// Returns hover information for `decl`, or std::nullopt for decl kinds
  /// without hover support.
  std::optional<lsp::Hover> findHover(const ast::Decl *decl,
                                      const SMRange &hoverRange);
  /// Builders for the hover contents of each supported kind of symbol. The
  /// resulting hover spans `hoverRange`.
  lsp::Hover buildHoverForOpName(const ods::Operation *op,
                                 const SMRange &hoverRange);
  lsp::Hover buildHoverForVariable(const ast::VariableDecl *varDecl,
                                   const SMRange &hoverRange);
  lsp::Hover buildHoverForPattern(const ast::PatternDecl *decl,
                                  const SMRange &hoverRange);
  lsp::Hover buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl,
                                         const SMRange &hoverRange);
  template <typename T>
  lsp::Hover buildHoverForUserConstraintOrRewrite(StringRef typeName,
                                                  const T *decl,
                                                  const SMRange &hoverRange);

  //===--------------------------------------------------------------------===//
  // Document Symbols
  //===--------------------------------------------------------------------===//

  /// Appends symbols for the patterns, user constraints, and user rewrites
  /// declared at the top level of the main file.
  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);

  //===--------------------------------------------------------------------===//
  // Code Completion
  //===--------------------------------------------------------------------===//

  lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
                                        const lsp::Position &completePos);

  //===--------------------------------------------------------------------===//
  // Signature Help
  //===--------------------------------------------------------------------===//

  lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri,
                                      const lsp::Position &helpPos);

  //===--------------------------------------------------------------------===//
  // Inlay Hints
  //===--------------------------------------------------------------------===//

  void getInlayHints(const lsp::URIForFile &uri, const lsp::Range &range,
                     std::vector<lsp::InlayHint> &inlayHints);
  void getInlayHintsFor(const ast::VariableDecl *decl,
                        const lsp::URIForFile &uri,
                        std::vector<lsp::InlayHint> &inlayHints);
  void getInlayHintsFor(const ast::CallExpr *expr, const lsp::URIForFile &uri,
                        std::vector<lsp::InlayHint> &inlayHints);
  void getInlayHintsFor(const ast::OperationExpr *expr,
                        const lsp::URIForFile &uri,
                        std::vector<lsp::InlayHint> &inlayHints);

  /// Add a parameter hint for the given expression using `label`.
  void addParameterHintFor(std::vector<lsp::InlayHint> &inlayHints,
                           const ast::Expr *expr, StringRef label);

  //===--------------------------------------------------------------------===//
  // PDLL ViewOutput
  //===--------------------------------------------------------------------===//

  void getPDLLViewOutput(raw_ostream &os, lsp::PDLLViewOutputKind kind);

  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//
  // NOTE: The declaration order of these fields matters. `astContext` is
  // constructed from `odsContext` in the constructor's initializer list.

  /// The include directories for this file: the directory containing the file
  /// itself, followed by any extra directories given at construction.
  std::vector<std::string> includeDirs;

  /// The source manager containing the contents of the input file.
  llvm::SourceMgr sourceMgr;

  /// The ODS and AST contexts.
  ods::Context odsContext;
  ast::Context astContext;

  /// The parsed AST module, or failure if the file wasn't valid.
  FailureOr<ast::Module *> astModule;

  /// The index of the parsed module. It is only initialized when parsing
  /// succeeded.
  PDLIndex index;

  /// The set of includes of the parsed module.
  SmallVector<lsp::SourceMgrInclude> parsedIncludes;
};
} // namespace
373
374PDLDocument::PDLDocument(const lsp::URIForFile &uri, StringRef contents,
375 const std::vector<std::string> &extraDirs,
376 std::vector<lsp::Diagnostic> &diagnostics)
377 : astContext(odsContext) {
378 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(InputData: contents, BufferName: uri.file());
379 if (!memBuffer) {
380 lsp::Logger::error(fmt: "Failed to create memory buffer for file", vals: uri.file());
381 return;
382 }
383
384 // Build the set of include directories for this file.
385 llvm::SmallString<32> uriDirectory(uri.file());
386 llvm::sys::path::remove_filename(path&: uriDirectory);
387 includeDirs.push_back(x: uriDirectory.str().str());
388 includeDirs.insert(position: includeDirs.end(), first: extraDirs.begin(), last: extraDirs.end());
389
390 sourceMgr.setIncludeDirs(includeDirs);
391 sourceMgr.AddNewSourceBuffer(F: std::move(memBuffer), IncludeLoc: SMLoc());
392
393 astContext.getDiagEngine().setHandlerFn([&](const ast::Diagnostic &diag) {
394 if (auto lspDiag = getLspDiagnoticFromDiag(sourceMgr, diag, uri))
395 diagnostics.push_back(x: std::move(*lspDiag));
396 });
397 astModule = parsePDLLAST(ctx&: astContext, sourceMgr, /*enableDocumentation=*/true);
398
399 // Initialize the set of parsed includes.
400 lsp::gatherIncludeFiles(sourceMgr, includes&: parsedIncludes);
401
402 // If we failed to parse the module, there is nothing left to initialize.
403 if (failed(result: astModule))
404 return;
405
406 // Prepare the AST index with the parsed module.
407 index.initialize(module: **astModule, odsContext);
408}
409
410//===----------------------------------------------------------------------===//
411// PDLDocument: Definitions and References
412//===----------------------------------------------------------------------===//
413
414void PDLDocument::getLocationsOf(const lsp::URIForFile &uri,
415 const lsp::Position &defPos,
416 std::vector<lsp::Location> &locations) {
417 SMLoc posLoc = defPos.getAsSMLoc(mgr&: sourceMgr);
418 const PDLIndexSymbol *symbol = index.lookup(loc: posLoc);
419 if (!symbol)
420 return;
421
422 locations.push_back(x: getLocationFromLoc(mgr&: sourceMgr, range: symbol->getDefLoc(), uri));
423}
424
425void PDLDocument::findReferencesOf(const lsp::URIForFile &uri,
426 const lsp::Position &pos,
427 std::vector<lsp::Location> &references) {
428 SMLoc posLoc = pos.getAsSMLoc(mgr&: sourceMgr);
429 const PDLIndexSymbol *symbol = index.lookup(loc: posLoc);
430 if (!symbol)
431 return;
432
433 references.push_back(x: getLocationFromLoc(mgr&: sourceMgr, range: symbol->getDefLoc(), uri));
434 for (SMRange refLoc : symbol->references)
435 references.push_back(x: getLocationFromLoc(mgr&: sourceMgr, range: refLoc, uri));
436}
437
438//===--------------------------------------------------------------------===//
439// PDLDocument: Document Links
440//===--------------------------------------------------------------------===//
441
442void PDLDocument::getDocumentLinks(const lsp::URIForFile &uri,
443 std::vector<lsp::DocumentLink> &links) {
444 for (const lsp::SourceMgrInclude &include : parsedIncludes)
445 links.emplace_back(args: include.range, args: include.uri);
446}
447
448//===----------------------------------------------------------------------===//
449// PDLDocument: Hover
450//===----------------------------------------------------------------------===//
451
452std::optional<lsp::Hover>
453PDLDocument::findHover(const lsp::URIForFile &uri,
454 const lsp::Position &hoverPos) {
455 SMLoc posLoc = hoverPos.getAsSMLoc(mgr&: sourceMgr);
456
457 // Check for a reference to an include.
458 for (const lsp::SourceMgrInclude &include : parsedIncludes)
459 if (include.range.contains(pos: hoverPos))
460 return include.buildHover();
461
462 // Find the symbol at the given location.
463 SMRange hoverRange;
464 const PDLIndexSymbol *symbol = index.lookup(loc: posLoc, overlappedRange: &hoverRange);
465 if (!symbol)
466 return std::nullopt;
467
468 // Add hover for operation names.
469 if (const auto *op = llvm::dyn_cast_if_present<const ods::Operation *>(Val: symbol->definition))
470 return buildHoverForOpName(op, hoverRange);
471 const auto *decl = symbol->definition.get<const ast::Decl *>();
472 return findHover(decl, hoverRange);
473}
474
475std::optional<lsp::Hover> PDLDocument::findHover(const ast::Decl *decl,
476 const SMRange &hoverRange) {
477 // Add hover for variables.
478 if (const auto *varDecl = dyn_cast<ast::VariableDecl>(Val: decl))
479 return buildHoverForVariable(varDecl, hoverRange);
480
481 // Add hover for patterns.
482 if (const auto *patternDecl = dyn_cast<ast::PatternDecl>(Val: decl))
483 return buildHoverForPattern(decl: patternDecl, hoverRange);
484
485 // Add hover for core constraints.
486 if (const auto *cst = dyn_cast<ast::CoreConstraintDecl>(Val: decl))
487 return buildHoverForCoreConstraint(decl: cst, hoverRange);
488
489 // Add hover for user constraints.
490 if (const auto *cst = dyn_cast<ast::UserConstraintDecl>(Val: decl))
491 return buildHoverForUserConstraintOrRewrite(typeName: "Constraint", decl: cst, hoverRange);
492
493 // Add hover for user rewrites.
494 if (const auto *rewrite = dyn_cast<ast::UserRewriteDecl>(Val: decl))
495 return buildHoverForUserConstraintOrRewrite(typeName: "Rewrite", decl: rewrite, hoverRange);
496
497 return std::nullopt;
498}
499
500lsp::Hover PDLDocument::buildHoverForOpName(const ods::Operation *op,
501 const SMRange &hoverRange) {
502 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
503 {
504 llvm::raw_string_ostream hoverOS(hover.contents.value);
505 hoverOS << "**OpName**: `" << op->getName() << "`\n***\n"
506 << op->getSummary() << "\n***\n"
507 << op->getDescription();
508 }
509 return hover;
510}
511
512lsp::Hover PDLDocument::buildHoverForVariable(const ast::VariableDecl *varDecl,
513 const SMRange &hoverRange) {
514 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
515 {
516 llvm::raw_string_ostream hoverOS(hover.contents.value);
517 hoverOS << "**Variable**: `" << varDecl->getName().getName() << "`\n***\n"
518 << "Type: `" << varDecl->getType() << "`\n";
519 }
520 return hover;
521}
522
523lsp::Hover PDLDocument::buildHoverForPattern(const ast::PatternDecl *decl,
524 const SMRange &hoverRange) {
525 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
526 {
527 llvm::raw_string_ostream hoverOS(hover.contents.value);
528 hoverOS << "**Pattern**";
529 if (const ast::Name *name = decl->getName())
530 hoverOS << ": `" << name->getName() << "`";
531 hoverOS << "\n***\n";
532 if (std::optional<uint16_t> benefit = decl->getBenefit())
533 hoverOS << "Benefit: " << *benefit << "\n";
534 if (decl->hasBoundedRewriteRecursion())
535 hoverOS << "HasBoundedRewriteRecursion\n";
536 hoverOS << "RootOp: `"
537 << decl->getRootRewriteStmt()->getRootOpExpr()->getType() << "`\n";
538
539 // Format the documentation for the decl.
540 if (std::optional<std::string> doc = getDocumentationFor(sourceMgr, decl))
541 hoverOS << "\n" << *doc << "\n";
542 }
543 return hover;
544}
545
546lsp::Hover
547PDLDocument::buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl,
548 const SMRange &hoverRange) {
549 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
550 {
551 llvm::raw_string_ostream hoverOS(hover.contents.value);
552 hoverOS << "**Constraint**: `";
553 TypeSwitch<const ast::Decl *>(decl)
554 .Case(caseFn: [&](const ast::AttrConstraintDecl *) { hoverOS << "Attr"; })
555 .Case(caseFn: [&](const ast::OpConstraintDecl *opCst) {
556 hoverOS << "Op";
557 if (std::optional<StringRef> name = opCst->getName())
558 hoverOS << "<" << *name << ">";
559 })
560 .Case(caseFn: [&](const ast::TypeConstraintDecl *) { hoverOS << "Type"; })
561 .Case(caseFn: [&](const ast::TypeRangeConstraintDecl *) {
562 hoverOS << "TypeRange";
563 })
564 .Case(caseFn: [&](const ast::ValueConstraintDecl *) { hoverOS << "Value"; })
565 .Case(caseFn: [&](const ast::ValueRangeConstraintDecl *) {
566 hoverOS << "ValueRange";
567 });
568 hoverOS << "`\n";
569 }
570 return hover;
571}
572
573template <typename T>
574lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite(
575 StringRef typeName, const T *decl, const SMRange &hoverRange) {
576 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
577 {
578 llvm::raw_string_ostream hoverOS(hover.contents.value);
579 hoverOS << "**" << typeName << "**: `" << decl->getName().getName()
580 << "`\n***\n";
581 ArrayRef<ast::VariableDecl *> inputs = decl->getInputs();
582 if (!inputs.empty()) {
583 hoverOS << "Parameters:\n";
584 for (const ast::VariableDecl *input : inputs)
585 hoverOS << "* " << input->getName().getName() << ": `"
586 << input->getType() << "`\n";
587 hoverOS << "***\n";
588 }
589 ast::Type resultType = decl->getResultType();
590 if (auto resultTupleTy = resultType.dyn_cast<ast::TupleType>()) {
591 if (!resultTupleTy.empty()) {
592 hoverOS << "Results:\n";
593 for (auto it : llvm::zip(t: resultTupleTy.getElementNames(),
594 u: resultTupleTy.getElementTypes())) {
595 StringRef name = std::get<0>(t&: it);
596 hoverOS << "* " << (name.empty() ? "" : (name + ": ")) << "`"
597 << std::get<1>(t&: it) << "`\n";
598 }
599 hoverOS << "***\n";
600 }
601 } else {
602 hoverOS << "Results:\n* `" << resultType << "`\n";
603 hoverOS << "***\n";
604 }
605
606 // Format the documentation for the decl.
607 if (std::optional<std::string> doc = getDocumentationFor(sourceMgr, decl))
608 hoverOS << "\n" << *doc << "\n";
609 }
610 return hover;
611}
612
613//===----------------------------------------------------------------------===//
614// PDLDocument: Document Symbols
615//===----------------------------------------------------------------------===//
616
617void PDLDocument::findDocumentSymbols(
618 std::vector<lsp::DocumentSymbol> &symbols) {
619 if (failed(result: astModule))
620 return;
621
622 for (const ast::Decl *decl : (*astModule)->getChildren()) {
623 if (!isMainFileLoc(mgr&: sourceMgr, loc: decl->getLoc()))
624 continue;
625
626 if (const auto *patternDecl = dyn_cast<ast::PatternDecl>(Val: decl)) {
627 const ast::Name *name = patternDecl->getName();
628
629 SMRange nameLoc = name ? name->getLoc() : patternDecl->getLoc();
630 SMRange bodyLoc(nameLoc.Start, patternDecl->getBody()->getLoc().End);
631
632 symbols.emplace_back(
633 args: name ? name->getName() : "<pattern>", args: lsp::SymbolKind::Class,
634 args: lsp::Range(sourceMgr, bodyLoc), args: lsp::Range(sourceMgr, nameLoc));
635 } else if (const auto *cDecl = dyn_cast<ast::UserConstraintDecl>(Val: decl)) {
636 // TODO: Add source information for the code block body.
637 SMRange nameLoc = cDecl->getName().getLoc();
638 SMRange bodyLoc = nameLoc;
639
640 symbols.emplace_back(
641 args: cDecl->getName().getName(), args: lsp::SymbolKind::Function,
642 args: lsp::Range(sourceMgr, bodyLoc), args: lsp::Range(sourceMgr, nameLoc));
643 } else if (const auto *cDecl = dyn_cast<ast::UserRewriteDecl>(Val: decl)) {
644 // TODO: Add source information for the code block body.
645 SMRange nameLoc = cDecl->getName().getLoc();
646 SMRange bodyLoc = nameLoc;
647
648 symbols.emplace_back(
649 args: cDecl->getName().getName(), args: lsp::SymbolKind::Function,
650 args: lsp::Range(sourceMgr, bodyLoc), args: lsp::Range(sourceMgr, nameLoc));
651 }
652 }
653}
654
655//===----------------------------------------------------------------------===//
656// PDLDocument: Code Completion
657//===----------------------------------------------------------------------===//
658
659namespace {
660class LSPCodeCompleteContext : public CodeCompleteContext {
661public:
662 LSPCodeCompleteContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
663 lsp::CompletionList &completionList,
664 ods::Context &odsContext,
665 ArrayRef<std::string> includeDirs)
666 : CodeCompleteContext(completeLoc), sourceMgr(sourceMgr),
667 completionList(completionList), odsContext(odsContext),
668 includeDirs(includeDirs) {}
669
670 void codeCompleteTupleMemberAccess(ast::TupleType tupleType) final {
671 ArrayRef<ast::Type> elementTypes = tupleType.getElementTypes();
672 ArrayRef<StringRef> elementNames = tupleType.getElementNames();
673 for (unsigned i = 0, e = tupleType.size(); i < e; ++i) {
674 // Push back a completion item that uses the result index.
675 lsp::CompletionItem item;
676 item.label = llvm::formatv(Fmt: "{0} (field #{0})", Vals&: i).str();
677 item.insertText = Twine(i).str();
678 item.filterText = item.sortText = item.insertText;
679 item.kind = lsp::CompletionItemKind::Field;
680 item.detail = llvm::formatv(Fmt: "{0}: {1}", Vals&: i, Vals: elementTypes[i]);
681 item.insertTextFormat = lsp::InsertTextFormat::PlainText;
682 completionList.items.emplace_back(args&: item);
683
684 // If the element has a name, push back a completion item with that name.
685 if (!elementNames[i].empty()) {
686 item.label =
687 llvm::formatv(Fmt: "{1} (field #{0})", Vals&: i, Vals: elementNames[i]).str();
688 item.filterText = item.label;
689 item.insertText = elementNames[i].str();
690 completionList.items.emplace_back(args&: item);
691 }
692 }
693 }
694
695 void codeCompleteOperationMemberAccess(ast::OperationType opType) final {
696 const ods::Operation *odsOp = opType.getODSOperation();
697 if (!odsOp)
698 return;
699
700 ArrayRef<ods::OperandOrResult> results = odsOp->getResults();
701 for (const auto &it : llvm::enumerate(First&: results)) {
702 const ods::OperandOrResult &result = it.value();
703 const ods::TypeConstraint &constraint = result.getConstraint();
704
705 // Push back a completion item that uses the result index.
706 lsp::CompletionItem item;
707 item.label = llvm::formatv(Fmt: "{0} (field #{0})", Vals: it.index()).str();
708 item.insertText = Twine(it.index()).str();
709 item.filterText = item.sortText = item.insertText;
710 item.kind = lsp::CompletionItemKind::Field;
711 switch (result.getVariableLengthKind()) {
712 case ods::VariableLengthKind::Single:
713 item.detail = llvm::formatv(Fmt: "{0}: Value", Vals: it.index()).str();
714 break;
715 case ods::VariableLengthKind::Optional:
716 item.detail = llvm::formatv(Fmt: "{0}: Value?", Vals: it.index()).str();
717 break;
718 case ods::VariableLengthKind::Variadic:
719 item.detail = llvm::formatv(Fmt: "{0}: ValueRange", Vals: it.index()).str();
720 break;
721 }
722 item.documentation = lsp::MarkupContent{
723 .kind: lsp::MarkupKind::Markdown,
724 .value: llvm::formatv(Fmt: "{0}\n\n```c++\n{1}\n```\n", Vals: constraint.getSummary(),
725 Vals: constraint.getCppClass())
726 .str()};
727 item.insertTextFormat = lsp::InsertTextFormat::PlainText;
728 completionList.items.emplace_back(args&: item);
729
730 // If the result has a name, push back a completion item with the result
731 // name.
732 if (!result.getName().empty()) {
733 item.label =
734 llvm::formatv(Fmt: "{1} (field #{0})", Vals: it.index(), Vals: result.getName())
735 .str();
736 item.filterText = item.label;
737 item.insertText = result.getName().str();
738 completionList.items.emplace_back(args&: item);
739 }
740 }
741 }
742
743 void codeCompleteOperationAttributeName(StringRef opName) final {
744 const ods::Operation *odsOp = odsContext.lookupOperation(name: opName);
745 if (!odsOp)
746 return;
747
748 for (const ods::Attribute &attr : odsOp->getAttributes()) {
749 const ods::AttributeConstraint &constraint = attr.getConstraint();
750
751 lsp::CompletionItem item;
752 item.label = attr.getName().str();
753 item.kind = lsp::CompletionItemKind::Field;
754 item.detail = attr.isOptional() ? "optional" : "";
755 item.documentation = lsp::MarkupContent{
756 .kind: lsp::MarkupKind::Markdown,
757 .value: llvm::formatv(Fmt: "{0}\n\n```c++\n{1}\n```\n", Vals: constraint.getSummary(),
758 Vals: constraint.getCppClass())
759 .str()};
760 item.insertTextFormat = lsp::InsertTextFormat::PlainText;
761 completionList.items.emplace_back(args&: item);
762 }
763 }
764
765 void codeCompleteConstraintName(ast::Type currentType,
766 bool allowInlineTypeConstraints,
767 const ast::DeclScope *scope) final {
768 auto addCoreConstraint = [&](StringRef constraint, StringRef mlirType,
769 StringRef snippetText = "") {
770 lsp::CompletionItem item;
771 item.label = constraint.str();
772 item.kind = lsp::CompletionItemKind::Class;
773 item.detail = (constraint + " constraint").str();
774 item.documentation = lsp::MarkupContent{
775 .kind: lsp::MarkupKind::Markdown,
776 .value: ("A single entity core constraint of type `" + mlirType + "`").str()};
777 item.sortText = "0";
778 item.insertText = snippetText.str();
779 item.insertTextFormat = snippetText.empty()
780 ? lsp::InsertTextFormat::PlainText
781 : lsp::InsertTextFormat::Snippet;
782 completionList.items.emplace_back(args&: item);
783 };
784
785 // Insert completions for the core constraints. Some core constraints have
786 // additional characteristics, so we may add then even if a type has been
787 // inferred.
788 if (!currentType) {
789 addCoreConstraint("Attr", "mlir::Attribute");
790 addCoreConstraint("Op", "mlir::Operation *");
791 addCoreConstraint("Value", "mlir::Value");
792 addCoreConstraint("ValueRange", "mlir::ValueRange");
793 addCoreConstraint("Type", "mlir::Type");
794 addCoreConstraint("TypeRange", "mlir::TypeRange");
795 }
796 if (allowInlineTypeConstraints) {
797 /// Attr<Type>.
798 if (!currentType || currentType.isa<ast::AttributeType>())
799 addCoreConstraint("Attr<type>", "mlir::Attribute", "Attr<$1>");
800 /// Value<Type>.
801 if (!currentType || currentType.isa<ast::ValueType>())
802 addCoreConstraint("Value<type>", "mlir::Value", "Value<$1>");
803 /// ValueRange<TypeRange>.
804 if (!currentType || currentType.isa<ast::ValueRangeType>())
805 addCoreConstraint("ValueRange<type>", "mlir::ValueRange",
806 "ValueRange<$1>");
807 }
808
809 // If a scope was provided, check it for potential constraints.
810 while (scope) {
811 for (const ast::Decl *decl : scope->getDecls()) {
812 if (const auto *cst = dyn_cast<ast::UserConstraintDecl>(Val: decl)) {
813 lsp::CompletionItem item;
814 item.label = cst->getName().getName().str();
815 item.kind = lsp::CompletionItemKind::Interface;
816 item.sortText = "2_" + item.label;
817
818 // Skip constraints that are not single-arg. We currently only
819 // complete variable constraints.
820 if (cst->getInputs().size() != 1)
821 continue;
822
823 // Ensure the input type matched the given type.
824 ast::Type constraintType = cst->getInputs()[0]->getType();
825 if (currentType && !currentType.refineWith(other: constraintType))
826 continue;
827
828 // Format the constraint signature.
829 {
830 llvm::raw_string_ostream strOS(item.detail);
831 strOS << "(";
832 llvm::interleaveComma(
833 c: cst->getInputs(), os&: strOS, each_fn: [&](const ast::VariableDecl *var) {
834 strOS << var->getName().getName() << ": " << var->getType();
835 });
836 strOS << ") -> " << cst->getResultType();
837 }
838
839 // Format the documentation for the constraint.
840 if (std::optional<std::string> doc =
841 getDocumentationFor(sourceMgr, decl: cst)) {
842 item.documentation =
843 lsp::MarkupContent{.kind: lsp::MarkupKind::Markdown, .value: std::move(*doc)};
844 }
845
846 completionList.items.emplace_back(args&: item);
847 }
848 }
849
850 scope = scope->getParentScope();
851 }
852 }
853
854 void codeCompleteDialectName() final {
855 // Code complete known dialects.
856 for (const ods::Dialect &dialect : odsContext.getDialects()) {
857 lsp::CompletionItem item;
858 item.label = dialect.getName().str();
859 item.kind = lsp::CompletionItemKind::Class;
860 item.insertTextFormat = lsp::InsertTextFormat::PlainText;
861 completionList.items.emplace_back(args&: item);
862 }
863 }
864
865 void codeCompleteOperationName(StringRef dialectName) final {
866 const ods::Dialect *dialect = odsContext.lookupDialect(name: dialectName);
867 if (!dialect)
868 return;
869
870 for (const auto &it : dialect->getOperations()) {
871 const ods::Operation &op = *it.second;
872
873 lsp::CompletionItem item;
874 item.label = op.getName().drop_front(N: dialectName.size() + 1).str();
875 item.kind = lsp::CompletionItemKind::Field;
876 item.insertTextFormat = lsp::InsertTextFormat::PlainText;
877 completionList.items.emplace_back(args&: item);
878 }
879 }
880
881 void codeCompletePatternMetadata() final {
882 auto addSimpleConstraint = [&](StringRef constraint, StringRef desc,
883 StringRef snippetText = "") {
884 lsp::CompletionItem item;
885 item.label = constraint.str();
886 item.kind = lsp::CompletionItemKind::Class;
887 item.detail = "pattern metadata";
888 item.documentation =
889 lsp::MarkupContent{.kind: lsp::MarkupKind::Markdown, .value: desc.str()};
890 item.insertText = snippetText.str();
891 item.insertTextFormat = snippetText.empty()
892 ? lsp::InsertTextFormat::PlainText
893 : lsp::InsertTextFormat::Snippet;
894 completionList.items.emplace_back(args&: item);
895 };
896
897 addSimpleConstraint("benefit", "The `benefit` of matching the pattern.",
898 "benefit($1)");
899 addSimpleConstraint("recursion",
900 "The pattern properly handles recursive application.");
901 }
902
903 void codeCompleteIncludeFilename(StringRef curPath) final {
904 // Normalize the path to allow for interacting with the file system
905 // utilities.
906 SmallString<128> nativeRelDir(llvm::sys::path::convert_to_slash(path: curPath));
907 llvm::sys::path::native(path&: nativeRelDir);
908
909 // Set of already included completion paths.
910 StringSet<> seenResults;
911
912 // Functor used to add a single include completion item.
913 auto addIncludeCompletion = [&](StringRef path, bool isDirectory) {
914 lsp::CompletionItem item;
915 item.label = path.str();
916 item.kind = isDirectory ? lsp::CompletionItemKind::Folder
917 : lsp::CompletionItemKind::File;
918 if (seenResults.insert(key: item.label).second)
919 completionList.items.emplace_back(args&: item);
920 };
921
922 // Process the include directories for this file, adding any potential
923 // nested include files or directories.
924 for (StringRef includeDir : includeDirs) {
925 llvm::SmallString<128> dir = includeDir;
926 if (!nativeRelDir.empty())
927 llvm::sys::path::append(path&: dir, a: nativeRelDir);
928
929 std::error_code errorCode;
930 for (auto it = llvm::sys::fs::directory_iterator(dir, errorCode),
931 e = llvm::sys::fs::directory_iterator();
932 !errorCode && it != e; it.increment(ec&: errorCode)) {
933 StringRef filename = llvm::sys::path::filename(path: it->path());
934
935 // To know whether a symlink should be treated as file or a directory,
936 // we have to stat it. This should be cheap enough as there shouldn't be
937 // many symlinks.
938 llvm::sys::fs::file_type fileType = it->type();
939 if (fileType == llvm::sys::fs::file_type::symlink_file) {
940 if (auto fileStatus = it->status())
941 fileType = fileStatus->type();
942 }
943
944 switch (fileType) {
945 case llvm::sys::fs::file_type::directory_file:
946 addIncludeCompletion(filename, /*isDirectory=*/true);
947 break;
948 case llvm::sys::fs::file_type::regular_file: {
949 // Only consider concrete files that can actually be included by PDLL.
950 if (filename.ends_with(Suffix: ".pdll") || filename.ends_with(Suffix: ".td"))
951 addIncludeCompletion(filename, /*isDirectory=*/false);
952 break;
953 }
954 default:
955 break;
956 }
957 }
958 }
959
960 // Sort the completion results to make sure the output is deterministic in
961 // the face of different iteration schemes for different platforms.
962 llvm::sort(C&: completionList.items, Comp: [](const lsp::CompletionItem &lhs,
963 const lsp::CompletionItem &rhs) {
964 return lhs.label < rhs.label;
965 });
966 }
967
968private:
969 llvm::SourceMgr &sourceMgr;
970 lsp::CompletionList &completionList;
971 ods::Context &odsContext;
972 ArrayRef<std::string> includeDirs;
973};
974} // namespace
975
976lsp::CompletionList
977PDLDocument::getCodeCompletion(const lsp::URIForFile &uri,
978 const lsp::Position &completePos) {
979 SMLoc posLoc = completePos.getAsSMLoc(mgr&: sourceMgr);
980 if (!posLoc.isValid())
981 return lsp::CompletionList();
982
983 // To perform code completion, we run another parse of the module with the
984 // code completion context provided.
985 ods::Context tmpODSContext;
986 lsp::CompletionList completionList;
987 LSPCodeCompleteContext lspCompleteContext(posLoc, sourceMgr, completionList,
988 tmpODSContext,
989 sourceMgr.getIncludeDirs());
990
991 ast::Context tmpContext(tmpODSContext);
992 (void)parsePDLLAST(ctx&: tmpContext, sourceMgr, /*enableDocumentation=*/true,
993 codeCompleteContext: &lspCompleteContext);
994
995 return completionList;
996}
997
998//===----------------------------------------------------------------------===//
999// PDLDocument: Signature Help
1000//===----------------------------------------------------------------------===//
1001
1002namespace {
1003class LSPSignatureHelpContext : public CodeCompleteContext {
1004public:
1005 LSPSignatureHelpContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
1006 lsp::SignatureHelp &signatureHelp,
1007 ods::Context &odsContext)
1008 : CodeCompleteContext(completeLoc), sourceMgr(sourceMgr),
1009 signatureHelp(signatureHelp), odsContext(odsContext) {}
1010
1011 void codeCompleteCallSignature(const ast::CallableDecl *callable,
1012 unsigned currentNumArgs) final {
1013 signatureHelp.activeParameter = currentNumArgs;
1014
1015 lsp::SignatureInformation signatureInfo;
1016 {
1017 llvm::raw_string_ostream strOS(signatureInfo.label);
1018 strOS << callable->getName()->getName() << "(";
1019 auto formatParamFn = [&](const ast::VariableDecl *var) {
1020 unsigned paramStart = strOS.str().size();
1021 strOS << var->getName().getName() << ": " << var->getType();
1022 unsigned paramEnd = strOS.str().size();
1023 signatureInfo.parameters.emplace_back(args: lsp::ParameterInformation{
1024 .labelString: StringRef(strOS.str()).slice(Start: paramStart, End: paramEnd).str(),
1025 .labelOffsets: std::make_pair(x&: paramStart, y&: paramEnd), /*paramDoc*/ .documentation: std::string()});
1026 };
1027 llvm::interleaveComma(c: callable->getInputs(), os&: strOS, each_fn: formatParamFn);
1028 strOS << ") -> " << callable->getResultType();
1029 }
1030
1031 // Format the documentation for the callable.
1032 if (std::optional<std::string> doc =
1033 getDocumentationFor(sourceMgr, decl: callable))
1034 signatureInfo.documentation = std::move(*doc);
1035
1036 signatureHelp.signatures.emplace_back(args: std::move(signatureInfo));
1037 }
1038
1039 void
1040 codeCompleteOperationOperandsSignature(std::optional<StringRef> opName,
1041 unsigned currentNumOperands) final {
1042 const ods::Operation *odsOp =
1043 opName ? odsContext.lookupOperation(name: *opName) : nullptr;
1044 codeCompleteOperationOperandOrResultSignature(
1045 opName, odsOp, values: odsOp ? odsOp->getOperands() : std::nullopt,
1046 currentValue: currentNumOperands, label: "operand", dataType: "Value");
1047 }
1048
1049 void codeCompleteOperationResultsSignature(std::optional<StringRef> opName,
1050 unsigned currentNumResults) final {
1051 const ods::Operation *odsOp =
1052 opName ? odsContext.lookupOperation(name: *opName) : nullptr;
1053 codeCompleteOperationOperandOrResultSignature(
1054 opName, odsOp, values: odsOp ? odsOp->getResults() : std::nullopt,
1055 currentValue: currentNumResults, label: "result", dataType: "Type");
1056 }
1057
1058 void codeCompleteOperationOperandOrResultSignature(
1059 std::optional<StringRef> opName, const ods::Operation *odsOp,
1060 ArrayRef<ods::OperandOrResult> values, unsigned currentValue,
1061 StringRef label, StringRef dataType) {
1062 signatureHelp.activeParameter = currentValue;
1063
1064 // If we have ODS information for the operation, add in the ODS signature
1065 // for the operation. We also verify that the current number of values is
1066 // not more than what is defined in ODS, as this will result in an error
1067 // anyways.
1068 if (odsOp && currentValue < values.size()) {
1069 lsp::SignatureInformation signatureInfo;
1070
1071 // Build the signature label.
1072 {
1073 llvm::raw_string_ostream strOS(signatureInfo.label);
1074 strOS << "(";
1075 auto formatFn = [&](const ods::OperandOrResult &value) {
1076 unsigned paramStart = strOS.str().size();
1077
1078 strOS << value.getName() << ": ";
1079
1080 StringRef constraintDoc = value.getConstraint().getSummary();
1081 std::string paramDoc;
1082 switch (value.getVariableLengthKind()) {
1083 case ods::VariableLengthKind::Single:
1084 strOS << dataType;
1085 paramDoc = constraintDoc.str();
1086 break;
1087 case ods::VariableLengthKind::Optional:
1088 strOS << dataType << "?";
1089 paramDoc = ("optional: " + constraintDoc).str();
1090 break;
1091 case ods::VariableLengthKind::Variadic:
1092 strOS << dataType << "Range";
1093 paramDoc = ("variadic: " + constraintDoc).str();
1094 break;
1095 }
1096
1097 unsigned paramEnd = strOS.str().size();
1098 signatureInfo.parameters.emplace_back(args: lsp::ParameterInformation{
1099 .labelString: StringRef(strOS.str()).slice(Start: paramStart, End: paramEnd).str(),
1100 .labelOffsets: std::make_pair(x&: paramStart, y&: paramEnd), .documentation: paramDoc});
1101 };
1102 llvm::interleaveComma(c: values, os&: strOS, each_fn: formatFn);
1103 strOS << ")";
1104 }
1105 signatureInfo.documentation =
1106 llvm::formatv(Fmt: "`op<{0}>` ODS {1} specification", Vals&: *opName, Vals&: label)
1107 .str();
1108 signatureHelp.signatures.emplace_back(args: std::move(signatureInfo));
1109 }
1110
1111 // If there aren't any arguments yet, we also add the generic signature.
1112 if (currentValue == 0 && (!odsOp || !values.empty())) {
1113 lsp::SignatureInformation signatureInfo;
1114 signatureInfo.label =
1115 llvm::formatv(Fmt: "(<{0}s>: {1}Range)", Vals&: label, Vals&: dataType).str();
1116 signatureInfo.documentation =
1117 ("Generic operation " + label + " specification").str();
1118 signatureInfo.parameters.emplace_back(args: lsp::ParameterInformation{
1119 .labelString: StringRef(signatureInfo.label).drop_front().drop_back().str(),
1120 .labelOffsets: std::pair<unsigned, unsigned>(1, signatureInfo.label.size() - 1),
1121 .documentation: ("All of the " + label + "s of the operation.").str()});
1122 signatureHelp.signatures.emplace_back(args: std::move(signatureInfo));
1123 }
1124 }
1125
1126private:
1127 llvm::SourceMgr &sourceMgr;
1128 lsp::SignatureHelp &signatureHelp;
1129 ods::Context &odsContext;
1130};
1131} // namespace
1132
1133lsp::SignatureHelp PDLDocument::getSignatureHelp(const lsp::URIForFile &uri,
1134 const lsp::Position &helpPos) {
1135 SMLoc posLoc = helpPos.getAsSMLoc(mgr&: sourceMgr);
1136 if (!posLoc.isValid())
1137 return lsp::SignatureHelp();
1138
1139 // To perform code completion, we run another parse of the module with the
1140 // code completion context provided.
1141 ods::Context tmpODSContext;
1142 lsp::SignatureHelp signatureHelp;
1143 LSPSignatureHelpContext completeContext(posLoc, sourceMgr, signatureHelp,
1144 tmpODSContext);
1145
1146 ast::Context tmpContext(tmpODSContext);
1147 (void)parsePDLLAST(ctx&: tmpContext, sourceMgr, /*enableDocumentation=*/true,
1148 codeCompleteContext: &completeContext);
1149
1150 return signatureHelp;
1151}
1152
1153//===----------------------------------------------------------------------===//
1154// PDLDocument: Inlay Hints
1155//===----------------------------------------------------------------------===//
1156
1157/// Returns true if the given name should be added as a hint for `expr`.
1158static bool shouldAddHintFor(const ast::Expr *expr, StringRef name) {
1159 if (name.empty())
1160 return false;
1161
1162 // If the argument is a reference of the same name, don't add it as a hint.
1163 if (auto *ref = dyn_cast<ast::DeclRefExpr>(Val: expr)) {
1164 const ast::Name *declName = ref->getDecl()->getName();
1165 if (declName && declName->getName() == name)
1166 return false;
1167 }
1168
1169 return true;
1170}
1171
1172void PDLDocument::getInlayHints(const lsp::URIForFile &uri,
1173 const lsp::Range &range,
1174 std::vector<lsp::InlayHint> &inlayHints) {
1175 if (failed(result: astModule))
1176 return;
1177 SMRange rangeLoc = range.getAsSMRange(mgr&: sourceMgr);
1178 if (!rangeLoc.isValid())
1179 return;
1180 (*astModule)->walk(walkFn: [&](const ast::Node *node) {
1181 SMRange loc = node->getLoc();
1182
1183 // Check that the location of this node is within the input range.
1184 if (!lsp::contains(range: rangeLoc, loc: loc.Start) &&
1185 !lsp::contains(range: rangeLoc, loc: loc.End))
1186 return;
1187
1188 // Handle hints for various types of nodes.
1189 llvm::TypeSwitch<const ast::Node *>(node)
1190 .Case<ast::VariableDecl, ast::CallExpr, ast::OperationExpr>(
1191 caseFn: [&](const auto *node) {
1192 this->getInlayHintsFor(node, uri, inlayHints);
1193 });
1194 });
1195}
1196
1197void PDLDocument::getInlayHintsFor(const ast::VariableDecl *decl,
1198 const lsp::URIForFile &uri,
1199 std::vector<lsp::InlayHint> &inlayHints) {
1200 // Check to see if the variable has a constraint list, if it does we don't
1201 // provide initializer hints.
1202 if (!decl->getConstraints().empty())
1203 return;
1204
1205 // Check to see if the variable has an initializer.
1206 if (const ast::Expr *expr = decl->getInitExpr()) {
1207 // Don't add hints for operation expression initialized variables given that
1208 // the type of the variable is easily inferred by the expression operation
1209 // name.
1210 if (isa<ast::OperationExpr>(Val: expr))
1211 return;
1212 }
1213
1214 lsp::InlayHint hint(lsp::InlayHintKind::Type,
1215 lsp::Position(sourceMgr, decl->getLoc().End));
1216 {
1217 llvm::raw_string_ostream labelOS(hint.label);
1218 labelOS << ": " << decl->getType();
1219 }
1220
1221 inlayHints.emplace_back(args: std::move(hint));
1222}
1223
1224void PDLDocument::getInlayHintsFor(const ast::CallExpr *expr,
1225 const lsp::URIForFile &uri,
1226 std::vector<lsp::InlayHint> &inlayHints) {
1227 // Try to extract the callable of this call.
1228 const auto *callableRef = dyn_cast<ast::DeclRefExpr>(Val: expr->getCallableExpr());
1229 const auto *callable =
1230 callableRef ? dyn_cast<ast::CallableDecl>(Val: callableRef->getDecl())
1231 : nullptr;
1232 if (!callable)
1233 return;
1234
1235 // Add hints for the arguments to the call.
1236 for (const auto &it : llvm::zip(t: expr->getArguments(), u: callable->getInputs()))
1237 addParameterHintFor(inlayHints, expr: std::get<0>(t: it),
1238 label: std::get<1>(t: it)->getName().getName());
1239}
1240
1241void PDLDocument::getInlayHintsFor(const ast::OperationExpr *expr,
1242 const lsp::URIForFile &uri,
1243 std::vector<lsp::InlayHint> &inlayHints) {
1244 // Check for ODS information.
1245 ast::OperationType opType = expr->getType().dyn_cast<ast::OperationType>();
1246 const auto *odsOp = opType ? opType.getODSOperation() : nullptr;
1247
1248 auto addOpHint = [&](const ast::Expr *valueExpr, StringRef label) {
1249 // If the value expression used the same location as the operation, don't
1250 // add a hint. This expression was materialized during parsing.
1251 if (expr->getLoc().Start == valueExpr->getLoc().Start)
1252 return;
1253 addParameterHintFor(inlayHints, expr: valueExpr, label);
1254 };
1255
1256 // Functor used to process hints for the operands and results of the
1257 // operation. They effectively have the same format, and thus can be processed
1258 // using the same logic.
1259 auto addOperandOrResultHints = [&](ArrayRef<ast::Expr *> values,
1260 ArrayRef<ods::OperandOrResult> odsValues,
1261 StringRef allValuesName) {
1262 if (values.empty())
1263 return;
1264
1265 // The values should either map to a single range, or be equivalent to the
1266 // ODS values.
1267 if (values.size() != odsValues.size()) {
1268 // Handle the case of a single element that covers the full range.
1269 if (values.size() == 1)
1270 return addOpHint(values.front(), allValuesName);
1271 return;
1272 }
1273
1274 for (const auto &it : llvm::zip(t&: values, u&: odsValues))
1275 addOpHint(std::get<0>(t: it), std::get<1>(t: it).getName());
1276 };
1277
1278 // Add hints for the operands and results of the operation.
1279 addOperandOrResultHints(expr->getOperands(),
1280 odsOp ? odsOp->getOperands()
1281 : ArrayRef<ods::OperandOrResult>(),
1282 "operands");
1283 addOperandOrResultHints(expr->getResultTypes(),
1284 odsOp ? odsOp->getResults()
1285 : ArrayRef<ods::OperandOrResult>(),
1286 "results");
1287}
1288
1289void PDLDocument::addParameterHintFor(std::vector<lsp::InlayHint> &inlayHints,
1290 const ast::Expr *expr, StringRef label) {
1291 if (!shouldAddHintFor(expr, name: label))
1292 return;
1293
1294 lsp::InlayHint hint(lsp::InlayHintKind::Parameter,
1295 lsp::Position(sourceMgr, expr->getLoc().Start));
1296 hint.label = (label + ":").str();
1297 hint.paddingRight = true;
1298 inlayHints.emplace_back(args: std::move(hint));
1299}
1300
1301//===----------------------------------------------------------------------===//
1302// PDLL ViewOutput
1303//===----------------------------------------------------------------------===//
1304
1305void PDLDocument::getPDLLViewOutput(raw_ostream &os,
1306 lsp::PDLLViewOutputKind kind) {
1307 if (failed(result: astModule))
1308 return;
1309 if (kind == lsp::PDLLViewOutputKind::AST) {
1310 (*astModule)->print(os);
1311 return;
1312 }
1313
1314 // Generate the MLIR for the ast module. We also capture diagnostics here to
1315 // show to the user, which may be useful if PDLL isn't capturing constraints
1316 // expected by PDL.
1317 MLIRContext mlirContext;
1318 SourceMgrDiagnosticHandler diagHandler(sourceMgr, &mlirContext, os);
1319 OwningOpRef<ModuleOp> pdlModule =
1320 codegenPDLLToMLIR(&mlirContext, astContext, sourceMgr, **astModule);
1321 if (!pdlModule)
1322 return;
1323 if (kind == lsp::PDLLViewOutputKind::MLIR) {
1324 pdlModule->print(os, OpPrintingFlags().enableDebugInfo());
1325 return;
1326 }
1327
1328 // Otherwise, generate the output for C++.
1329 assert(kind == lsp::PDLLViewOutputKind::CPP &&
1330 "unexpected PDLLViewOutputKind");
1331 codegenPDLLToCPP(**astModule, *pdlModule, os);
1332}
1333
1334//===----------------------------------------------------------------------===//
1335// PDLTextFileChunk
1336//===----------------------------------------------------------------------===//
1337
1338namespace {
1339/// This class represents a single chunk of an PDL text file.
1340struct PDLTextFileChunk {
1341 PDLTextFileChunk(uint64_t lineOffset, const lsp::URIForFile &uri,
1342 StringRef contents,
1343 const std::vector<std::string> &extraDirs,
1344 std::vector<lsp::Diagnostic> &diagnostics)
1345 : lineOffset(lineOffset),
1346 document(uri, contents, extraDirs, diagnostics) {}
1347
1348 /// Adjust the line number of the given range to anchor at the beginning of
1349 /// the file, instead of the beginning of this chunk.
1350 void adjustLocForChunkOffset(lsp::Range &range) {
1351 adjustLocForChunkOffset(pos&: range.start);
1352 adjustLocForChunkOffset(pos&: range.end);
1353 }
1354 /// Adjust the line number of the given position to anchor at the beginning of
1355 /// the file, instead of the beginning of this chunk.
1356 void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
1357
1358 /// The line offset of this chunk from the beginning of the file.
1359 uint64_t lineOffset;
1360 /// The document referred to by this chunk.
1361 PDLDocument document;
1362};
1363} // namespace
1364
1365//===----------------------------------------------------------------------===//
1366// PDLTextFile
1367//===----------------------------------------------------------------------===//
1368
1369namespace {
1370/// This class represents a text file containing one or more PDL documents.
1371class PDLTextFile {
1372public:
1373 PDLTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1374 int64_t version, const std::vector<std::string> &extraDirs,
1375 std::vector<lsp::Diagnostic> &diagnostics);
1376
1377 /// Return the current version of this text file.
1378 int64_t getVersion() const { return version; }
1379
1380 /// Update the file to the new version using the provided set of content
1381 /// changes. Returns failure if the update was unsuccessful.
1382 LogicalResult update(const lsp::URIForFile &uri, int64_t newVersion,
1383 ArrayRef<lsp::TextDocumentContentChangeEvent> changes,
1384 std::vector<lsp::Diagnostic> &diagnostics);
1385
1386 //===--------------------------------------------------------------------===//
1387 // LSP Queries
1388 //===--------------------------------------------------------------------===//
1389
1390 void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
1391 std::vector<lsp::Location> &locations);
1392 void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
1393 std::vector<lsp::Location> &references);
1394 void getDocumentLinks(const lsp::URIForFile &uri,
1395 std::vector<lsp::DocumentLink> &links);
1396 std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
1397 lsp::Position hoverPos);
1398 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
1399 lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
1400 lsp::Position completePos);
1401 lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri,
1402 lsp::Position helpPos);
1403 void getInlayHints(const lsp::URIForFile &uri, lsp::Range range,
1404 std::vector<lsp::InlayHint> &inlayHints);
1405 lsp::PDLLViewOutputResult getPDLLViewOutput(lsp::PDLLViewOutputKind kind);
1406
1407private:
1408 using ChunkIterator = llvm::pointee_iterator<
1409 std::vector<std::unique_ptr<PDLTextFileChunk>>::iterator>;
1410
1411 /// Initialize the text file from the given file contents.
1412 void initialize(const lsp::URIForFile &uri, int64_t newVersion,
1413 std::vector<lsp::Diagnostic> &diagnostics);
1414
1415 /// Find the PDL document that contains the given position, and update the
1416 /// position to be anchored at the start of the found chunk instead of the
1417 /// beginning of the file.
1418 ChunkIterator getChunkItFor(lsp::Position &pos);
1419 PDLTextFileChunk &getChunkFor(lsp::Position &pos) {
1420 return *getChunkItFor(pos);
1421 }
1422
1423 /// The full string contents of the file.
1424 std::string contents;
1425
1426 /// The version of this file.
1427 int64_t version = 0;
1428
1429 /// The number of lines in the file.
1430 int64_t totalNumLines = 0;
1431
1432 /// The chunks of this file. The order of these chunks is the order in which
1433 /// they appear in the text file.
1434 std::vector<std::unique_ptr<PDLTextFileChunk>> chunks;
1435
1436 /// The extra set of include directories for this file.
1437 std::vector<std::string> extraIncludeDirs;
1438};
1439} // namespace
1440
1441PDLTextFile::PDLTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1442 int64_t version,
1443 const std::vector<std::string> &extraDirs,
1444 std::vector<lsp::Diagnostic> &diagnostics)
1445 : contents(fileContents.str()), extraIncludeDirs(extraDirs) {
1446 initialize(uri, newVersion: version, diagnostics);
1447}
1448
1449LogicalResult
1450PDLTextFile::update(const lsp::URIForFile &uri, int64_t newVersion,
1451 ArrayRef<lsp::TextDocumentContentChangeEvent> changes,
1452 std::vector<lsp::Diagnostic> &diagnostics) {
1453 if (failed(result: lsp::TextDocumentContentChangeEvent::applyTo(changes, contents))) {
1454 lsp::Logger::error(fmt: "Failed to update contents of {0}", vals: uri.file());
1455 return failure();
1456 }
1457
1458 // If the file contents were properly changed, reinitialize the text file.
1459 initialize(uri, newVersion, diagnostics);
1460 return success();
1461}
1462
1463void PDLTextFile::getLocationsOf(const lsp::URIForFile &uri,
1464 lsp::Position defPos,
1465 std::vector<lsp::Location> &locations) {
1466 PDLTextFileChunk &chunk = getChunkFor(pos&: defPos);
1467 chunk.document.getLocationsOf(uri, defPos, locations);
1468
1469 // Adjust any locations within this file for the offset of this chunk.
1470 if (chunk.lineOffset == 0)
1471 return;
1472 for (lsp::Location &loc : locations)
1473 if (loc.uri == uri)
1474 chunk.adjustLocForChunkOffset(range&: loc.range);
1475}
1476
1477void PDLTextFile::findReferencesOf(const lsp::URIForFile &uri,
1478 lsp::Position pos,
1479 std::vector<lsp::Location> &references) {
1480 PDLTextFileChunk &chunk = getChunkFor(pos);
1481 chunk.document.findReferencesOf(uri, pos, references);
1482
1483 // Adjust any locations within this file for the offset of this chunk.
1484 if (chunk.lineOffset == 0)
1485 return;
1486 for (lsp::Location &loc : references)
1487 if (loc.uri == uri)
1488 chunk.adjustLocForChunkOffset(range&: loc.range);
1489}
1490
1491void PDLTextFile::getDocumentLinks(const lsp::URIForFile &uri,
1492 std::vector<lsp::DocumentLink> &links) {
1493 chunks.front()->document.getDocumentLinks(uri, links);
1494 for (const auto &it : llvm::drop_begin(RangeOrContainer&: chunks)) {
1495 size_t currentNumLinks = links.size();
1496 it->document.getDocumentLinks(uri, links);
1497
1498 // Adjust any links within this file to account for the offset of this
1499 // chunk.
1500 for (auto &link : llvm::drop_begin(RangeOrContainer&: links, N: currentNumLinks))
1501 it->adjustLocForChunkOffset(range&: link.range);
1502 }
1503}
1504
1505std::optional<lsp::Hover> PDLTextFile::findHover(const lsp::URIForFile &uri,
1506 lsp::Position hoverPos) {
1507 PDLTextFileChunk &chunk = getChunkFor(pos&: hoverPos);
1508 std::optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
1509
1510 // Adjust any locations within this file for the offset of this chunk.
1511 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1512 chunk.adjustLocForChunkOffset(range&: *hoverInfo->range);
1513 return hoverInfo;
1514}
1515
1516void PDLTextFile::findDocumentSymbols(
1517 std::vector<lsp::DocumentSymbol> &symbols) {
1518 if (chunks.size() == 1)
1519 return chunks.front()->document.findDocumentSymbols(symbols);
1520
1521 // If there are multiple chunks in this file, we create top-level symbols for
1522 // each chunk.
1523 for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
1524 PDLTextFileChunk &chunk = *chunks[i];
1525 lsp::Position startPos(chunk.lineOffset);
1526 lsp::Position endPos((i == e - 1) ? totalNumLines - 1
1527 : chunks[i + 1]->lineOffset);
1528 lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
1529 lsp::SymbolKind::Namespace,
1530 /*range=*/lsp::Range(startPos, endPos),
1531 /*selectionRange=*/lsp::Range(startPos));
1532 chunk.document.findDocumentSymbols(symbols&: symbol.children);
1533
1534 // Fixup the locations of document symbols within this chunk.
1535 if (i != 0) {
1536 SmallVector<lsp::DocumentSymbol *> symbolsToFix;
1537 for (lsp::DocumentSymbol &childSymbol : symbol.children)
1538 symbolsToFix.push_back(Elt: &childSymbol);
1539
1540 while (!symbolsToFix.empty()) {
1541 lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
1542 chunk.adjustLocForChunkOffset(range&: symbol->range);
1543 chunk.adjustLocForChunkOffset(range&: symbol->selectionRange);
1544
1545 for (lsp::DocumentSymbol &childSymbol : symbol->children)
1546 symbolsToFix.push_back(Elt: &childSymbol);
1547 }
1548 }
1549
1550 // Push the symbol for this chunk.
1551 symbols.emplace_back(args: std::move(symbol));
1552 }
1553}
1554
1555lsp::CompletionList PDLTextFile::getCodeCompletion(const lsp::URIForFile &uri,
1556 lsp::Position completePos) {
1557 PDLTextFileChunk &chunk = getChunkFor(pos&: completePos);
1558 lsp::CompletionList completionList =
1559 chunk.document.getCodeCompletion(uri, completePos);
1560
1561 // Adjust any completion locations.
1562 for (lsp::CompletionItem &item : completionList.items) {
1563 if (item.textEdit)
1564 chunk.adjustLocForChunkOffset(range&: item.textEdit->range);
1565 for (lsp::TextEdit &edit : item.additionalTextEdits)
1566 chunk.adjustLocForChunkOffset(range&: edit.range);
1567 }
1568 return completionList;
1569}
1570
1571lsp::SignatureHelp PDLTextFile::getSignatureHelp(const lsp::URIForFile &uri,
1572 lsp::Position helpPos) {
1573 return getChunkFor(pos&: helpPos).document.getSignatureHelp(uri, helpPos);
1574}
1575
1576void PDLTextFile::getInlayHints(const lsp::URIForFile &uri, lsp::Range range,
1577 std::vector<lsp::InlayHint> &inlayHints) {
1578 auto startIt = getChunkItFor(pos&: range.start);
1579 auto endIt = getChunkItFor(pos&: range.end);
1580
1581 // Functor used to get the chunks for a given file, and fixup any locations
1582 auto getHintsForChunk = [&](ChunkIterator chunkIt, lsp::Range range) {
1583 size_t currentNumHints = inlayHints.size();
1584 chunkIt->document.getInlayHints(uri, range, inlayHints);
1585
1586 // If this isn't the first chunk, update any positions to account for line
1587 // number differences.
1588 if (&*chunkIt != &*chunks.front()) {
1589 for (auto &hint : llvm::drop_begin(RangeOrContainer&: inlayHints, N: currentNumHints))
1590 chunkIt->adjustLocForChunkOffset(pos&: hint.position);
1591 }
1592 };
1593 // Returns the number of lines held by a given chunk.
1594 auto getNumLines = [](ChunkIterator chunkIt) {
1595 return (chunkIt + 1)->lineOffset - chunkIt->lineOffset;
1596 };
1597
1598 // Check if the range is fully within a single chunk.
1599 if (startIt == endIt)
1600 return getHintsForChunk(startIt, range);
1601
1602 // Otherwise, the range is split between multiple chunks. The first chunk
1603 // has the correct range start, but covers the total document.
1604 getHintsForChunk(startIt, lsp::Range(range.start, getNumLines(startIt)));
1605
1606 // Every chunk in between uses the full document.
1607 for (++startIt; startIt != endIt; ++startIt)
1608 getHintsForChunk(startIt, lsp::Range(0, getNumLines(startIt)));
1609
1610 // The range for the last chunk starts at the beginning of the document, up
1611 // through the end of the input range.
1612 getHintsForChunk(startIt, lsp::Range(0, range.end));
1613}
1614
1615lsp::PDLLViewOutputResult
1616PDLTextFile::getPDLLViewOutput(lsp::PDLLViewOutputKind kind) {
1617 lsp::PDLLViewOutputResult result;
1618 {
1619 llvm::raw_string_ostream outputOS(result.output);
1620 llvm::interleave(
1621 c: llvm::make_pointee_range(Range&: chunks),
1622 each_fn: [&](PDLTextFileChunk &chunk) {
1623 chunk.document.getPDLLViewOutput(os&: outputOS, kind);
1624 },
1625 between_fn: [&] { outputOS << "\n"
1626 << kDefaultSplitMarker << "\n\n"; });
1627 }
1628 return result;
1629}
1630
1631void PDLTextFile::initialize(const lsp::URIForFile &uri, int64_t newVersion,
1632 std::vector<lsp::Diagnostic> &diagnostics) {
1633 version = newVersion;
1634 chunks.clear();
1635
1636 // Split the file into separate PDL documents.
1637 SmallVector<StringRef, 8> subContents;
1638 StringRef(contents).split(A&: subContents, Separator: kDefaultSplitMarker);
1639 chunks.emplace_back(args: std::make_unique<PDLTextFileChunk>(
1640 /*lineOffset=*/args: 0, args: uri, args&: subContents.front(), args&: extraIncludeDirs,
1641 args&: diagnostics));
1642
1643 uint64_t lineOffset = subContents.front().count(C: '\n');
1644 for (StringRef docContents : llvm::drop_begin(RangeOrContainer&: subContents)) {
1645 unsigned currentNumDiags = diagnostics.size();
1646 auto chunk = std::make_unique<PDLTextFileChunk>(
1647 args&: lineOffset, args: uri, args&: docContents, args&: extraIncludeDirs, args&: diagnostics);
1648 lineOffset += docContents.count(C: '\n');
1649
1650 // Adjust locations used in diagnostics to account for the offset from the
1651 // beginning of the file.
1652 for (lsp::Diagnostic &diag :
1653 llvm::drop_begin(RangeOrContainer&: diagnostics, N: currentNumDiags)) {
1654 chunk->adjustLocForChunkOffset(range&: diag.range);
1655
1656 if (!diag.relatedInformation)
1657 continue;
1658 for (auto &it : *diag.relatedInformation)
1659 if (it.location.uri == uri)
1660 chunk->adjustLocForChunkOffset(range&: it.location.range);
1661 }
1662 chunks.emplace_back(args: std::move(chunk));
1663 }
1664 totalNumLines = lineOffset;
1665}
1666
1667PDLTextFile::ChunkIterator PDLTextFile::getChunkItFor(lsp::Position &pos) {
1668 if (chunks.size() == 1)
1669 return chunks.begin();
1670
1671 // Search for the first chunk with a greater line offset, the previous chunk
1672 // is the one that contains `pos`.
1673 auto it = llvm::upper_bound(
1674 Range&: chunks, Value&: pos, C: [](const lsp::Position &pos, const auto &chunk) {
1675 return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
1676 });
1677 ChunkIterator chunkIt(it == chunks.end() ? (chunks.end() - 1) : --it);
1678 pos.line -= chunkIt->lineOffset;
1679 return chunkIt;
1680}
1681
1682//===----------------------------------------------------------------------===//
1683// PDLLServer::Impl
1684//===----------------------------------------------------------------------===//
1685
1686struct lsp::PDLLServer::Impl {
1687 explicit Impl(const Options &options)
1688 : options(options), compilationDatabase(options.compilationDatabases) {}
1689
1690 /// PDLL LSP options.
1691 const Options &options;
1692
1693 /// The compilation database containing additional information for files
1694 /// passed to the server.
1695 lsp::CompilationDatabase compilationDatabase;
1696
1697 /// The files held by the server, mapped by their URI file name.
1698 llvm::StringMap<std::unique_ptr<PDLTextFile>> files;
1699};
1700
1701//===----------------------------------------------------------------------===//
1702// PDLLServer
1703//===----------------------------------------------------------------------===//
1704
1705lsp::PDLLServer::PDLLServer(const Options &options)
1706 : impl(std::make_unique<Impl>(args: options)) {}
1707lsp::PDLLServer::~PDLLServer() = default;
1708
1709void lsp::PDLLServer::addDocument(const URIForFile &uri, StringRef contents,
1710 int64_t version,
1711 std::vector<Diagnostic> &diagnostics) {
1712 // Build the set of additional include directories.
1713 std::vector<std::string> additionalIncludeDirs = impl->options.extraDirs;
1714 const auto &fileInfo = impl->compilationDatabase.getFileInfo(filename: uri.file());
1715 llvm::append_range(C&: additionalIncludeDirs, R: fileInfo.includeDirs);
1716
1717 impl->files[uri.file()] = std::make_unique<PDLTextFile>(
1718 args: uri, args&: contents, args&: version, args&: additionalIncludeDirs, args&: diagnostics);
1719}
1720
1721void lsp::PDLLServer::updateDocument(
1722 const URIForFile &uri, ArrayRef<TextDocumentContentChangeEvent> changes,
1723 int64_t version, std::vector<Diagnostic> &diagnostics) {
1724 // Check that we actually have a document for this uri.
1725 auto it = impl->files.find(Key: uri.file());
1726 if (it == impl->files.end())
1727 return;
1728
1729 // Try to update the document. If we fail, erase the file from the server. A
1730 // failed updated generally means we've fallen out of sync somewhere.
1731 if (failed(result: it->second->update(uri, newVersion: version, changes, diagnostics)))
1732 impl->files.erase(I: it);
1733}
1734
1735std::optional<int64_t> lsp::PDLLServer::removeDocument(const URIForFile &uri) {
1736 auto it = impl->files.find(Key: uri.file());
1737 if (it == impl->files.end())
1738 return std::nullopt;
1739
1740 int64_t version = it->second->getVersion();
1741 impl->files.erase(I: it);
1742 return version;
1743}
1744
1745void lsp::PDLLServer::getLocationsOf(const URIForFile &uri,
1746 const Position &defPos,
1747 std::vector<Location> &locations) {
1748 auto fileIt = impl->files.find(Key: uri.file());
1749 if (fileIt != impl->files.end())
1750 fileIt->second->getLocationsOf(uri, defPos, locations);
1751}
1752
1753void lsp::PDLLServer::findReferencesOf(const URIForFile &uri,
1754 const Position &pos,
1755 std::vector<Location> &references) {
1756 auto fileIt = impl->files.find(Key: uri.file());
1757 if (fileIt != impl->files.end())
1758 fileIt->second->findReferencesOf(uri, pos, references);
1759}
1760
1761void lsp::PDLLServer::getDocumentLinks(
1762 const URIForFile &uri, std::vector<DocumentLink> &documentLinks) {
1763 auto fileIt = impl->files.find(Key: uri.file());
1764 if (fileIt != impl->files.end())
1765 return fileIt->second->getDocumentLinks(uri, links&: documentLinks);
1766}
1767
1768std::optional<lsp::Hover> lsp::PDLLServer::findHover(const URIForFile &uri,
1769 const Position &hoverPos) {
1770 auto fileIt = impl->files.find(Key: uri.file());
1771 if (fileIt != impl->files.end())
1772 return fileIt->second->findHover(uri, hoverPos);
1773 return std::nullopt;
1774}
1775
1776void lsp::PDLLServer::findDocumentSymbols(
1777 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1778 auto fileIt = impl->files.find(Key: uri.file());
1779 if (fileIt != impl->files.end())
1780 fileIt->second->findDocumentSymbols(symbols);
1781}
1782
1783lsp::CompletionList
1784lsp::PDLLServer::getCodeCompletion(const URIForFile &uri,
1785 const Position &completePos) {
1786 auto fileIt = impl->files.find(Key: uri.file());
1787 if (fileIt != impl->files.end())
1788 return fileIt->second->getCodeCompletion(uri, completePos);
1789 return CompletionList();
1790}
1791
1792lsp::SignatureHelp lsp::PDLLServer::getSignatureHelp(const URIForFile &uri,
1793 const Position &helpPos) {
1794 auto fileIt = impl->files.find(Key: uri.file());
1795 if (fileIt != impl->files.end())
1796 return fileIt->second->getSignatureHelp(uri, helpPos);
1797 return SignatureHelp();
1798}
1799
1800void lsp::PDLLServer::getInlayHints(const URIForFile &uri, const Range &range,
1801 std::vector<InlayHint> &inlayHints) {
1802 auto fileIt = impl->files.find(Key: uri.file());
1803 if (fileIt == impl->files.end())
1804 return;
1805 fileIt->second->getInlayHints(uri, range, inlayHints);
1806
1807 // Drop any duplicated hints that may have cropped up.
1808 llvm::sort(C&: inlayHints);
1809 inlayHints.erase(first: std::unique(first: inlayHints.begin(), last: inlayHints.end()),
1810 last: inlayHints.end());
1811}
1812
1813std::optional<lsp::PDLLViewOutputResult>
1814lsp::PDLLServer::getPDLLViewOutput(const URIForFile &uri,
1815 PDLLViewOutputKind kind) {
1816 auto fileIt = impl->files.find(Key: uri.file());
1817 if (fileIt != impl->files.end())
1818 return fileIt->second->getPDLLViewOutput(kind);
1819 return std::nullopt;
1820}
1821

source code of mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp