1//===- PDLLServer.cpp - PDLL Language Server ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "PDLLServer.h"
10
11#include "Protocol.h"
12#include "mlir/IR/BuiltinOps.h"
13#include "mlir/Support/ToolUtilities.h"
14#include "mlir/Tools/PDLL/AST/Context.h"
15#include "mlir/Tools/PDLL/AST/Nodes.h"
16#include "mlir/Tools/PDLL/AST/Types.h"
17#include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
18#include "mlir/Tools/PDLL/CodeGen/MLIRGen.h"
19#include "mlir/Tools/PDLL/ODS/Constraint.h"
20#include "mlir/Tools/PDLL/ODS/Context.h"
21#include "mlir/Tools/PDLL/ODS/Dialect.h"
22#include "mlir/Tools/PDLL/ODS/Operation.h"
23#include "mlir/Tools/PDLL/Parser/CodeComplete.h"
24#include "mlir/Tools/PDLL/Parser/Parser.h"
25#include "mlir/Tools/lsp-server-support/CompilationDatabase.h"
26#include "mlir/Tools/lsp-server-support/Logging.h"
27#include "mlir/Tools/lsp-server-support/SourceMgrUtils.h"
28#include "llvm/ADT/IntervalMap.h"
29#include "llvm/ADT/StringMap.h"
30#include "llvm/ADT/StringSet.h"
31#include "llvm/ADT/TypeSwitch.h"
32#include "llvm/Support/FileSystem.h"
33#include "llvm/Support/Path.h"
34#include <optional>
35
36using namespace mlir;
37using namespace mlir::pdll;
38
39/// Returns a language server uri for the given source location. `mainFileURI`
40/// corresponds to the uri for the main file of the source manager.
41static lsp::URIForFile getURIFromLoc(llvm::SourceMgr &mgr, SMRange loc,
42 const lsp::URIForFile &mainFileURI) {
43 int bufferId = mgr.FindBufferContainingLoc(Loc: loc.Start);
44 if (bufferId == 0 || bufferId == static_cast<int>(mgr.getMainFileID()))
45 return mainFileURI;
46 llvm::Expected<lsp::URIForFile> fileForLoc = lsp::URIForFile::fromFile(
47 absoluteFilepath: mgr.getBufferInfo(i: bufferId).Buffer->getBufferIdentifier());
48 if (fileForLoc)
49 return *fileForLoc;
50 lsp::Logger::error(fmt: "Failed to create URI for include file: {0}",
51 vals: llvm::toString(E: fileForLoc.takeError()));
52 return mainFileURI;
53}
54
55/// Returns true if the given location is in the main file of the source
56/// manager.
57static bool isMainFileLoc(llvm::SourceMgr &mgr, SMRange loc) {
58 return mgr.FindBufferContainingLoc(Loc: loc.Start) == mgr.getMainFileID();
59}
60
61/// Returns a language server location from the given source range.
62static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMRange range,
63 const lsp::URIForFile &uri) {
64 return lsp::Location(getURIFromLoc(mgr, loc: range, mainFileURI: uri), lsp::Range(mgr, range));
65}
66
67/// Convert the given MLIR diagnostic to the LSP form.
68static std::optional<lsp::Diagnostic>
69getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag,
70 const lsp::URIForFile &uri) {
71 lsp::Diagnostic lspDiag;
72 lspDiag.source = "pdll";
73
74 // FIXME: Right now all of the diagnostics are treated as parser issues, but
75 // some are parser and some are verifier.
76 lspDiag.category = "Parse Error";
77
78 // Try to grab a file location for this diagnostic.
79 lsp::Location loc = getLocationFromLoc(mgr&: sourceMgr, range: diag.getLocation(), uri);
80 lspDiag.range = loc.range;
81
82 // Skip diagnostics that weren't emitted within the main file.
83 if (loc.uri != uri)
84 return std::nullopt;
85
86 // Convert the severity for the diagnostic.
87 switch (diag.getSeverity()) {
88 case ast::Diagnostic::Severity::DK_Note:
89 llvm_unreachable("expected notes to be handled separately");
90 case ast::Diagnostic::Severity::DK_Warning:
91 lspDiag.severity = lsp::DiagnosticSeverity::Warning;
92 break;
93 case ast::Diagnostic::Severity::DK_Error:
94 lspDiag.severity = lsp::DiagnosticSeverity::Error;
95 break;
96 case ast::Diagnostic::Severity::DK_Remark:
97 lspDiag.severity = lsp::DiagnosticSeverity::Information;
98 break;
99 }
100 lspDiag.message = diag.getMessage().str();
101
102 // Attach any notes to the main diagnostic as related information.
103 std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
104 for (const ast::Diagnostic &note : diag.getNotes()) {
105 relatedDiags.emplace_back(
106 args: getLocationFromLoc(mgr&: sourceMgr, range: note.getLocation(), uri),
107 args: note.getMessage().str());
108 }
109 if (!relatedDiags.empty())
110 lspDiag.relatedInformation = std::move(relatedDiags);
111
112 return lspDiag;
113}
114
115/// Get or extract the documentation for the given decl.
116static std::optional<std::string>
117getDocumentationFor(llvm::SourceMgr &sourceMgr, const ast::Decl *decl) {
118 // If the decl already had documentation set, use it.
119 if (std::optional<StringRef> doc = decl->getDocComment())
120 return doc->str();
121
122 // If the decl doesn't yet have documentation, try to extract it from the
123 // source file.
124 return lsp::extractSourceDocComment(sourceMgr, loc: decl->getLoc().Start);
125}
126
127//===----------------------------------------------------------------------===//
128// PDLIndex
129//===----------------------------------------------------------------------===//
130
131namespace {
132struct PDLIndexSymbol {
133 explicit PDLIndexSymbol(const ast::Decl *definition)
134 : definition(definition) {}
135 explicit PDLIndexSymbol(const ods::Operation *definition)
136 : definition(definition) {}
137
138 /// Return the location of the definition of this symbol.
139 SMRange getDefLoc() const {
140 if (const ast::Decl *decl =
141 llvm::dyn_cast_if_present<const ast::Decl *>(Val: definition)) {
142 const ast::Name *declName = decl->getName();
143 return declName ? declName->getLoc() : decl->getLoc();
144 }
145 return cast<const ods::Operation *>(Val: definition)->getLoc();
146 }
147
148 /// The main definition of the symbol.
149 PointerUnion<const ast::Decl *, const ods::Operation *> definition;
150 /// The set of references to the symbol.
151 std::vector<SMRange> references;
152};
153
154/// This class provides an index for definitions/uses within a PDL document.
155/// It provides efficient lookup of a definition given an input source range.
156class PDLIndex {
157public:
158 PDLIndex() : intervalMap(allocator) {}
159
160 /// Initialize the index with the given ast::Module.
161 void initialize(const ast::Module &module, const ods::Context &odsContext);
162
163 /// Lookup a symbol for the given location. Returns nullptr if no symbol could
164 /// be found. If provided, `overlappedRange` is set to the range that the
165 /// provided `loc` overlapped with.
166 const PDLIndexSymbol *lookup(SMLoc loc,
167 SMRange *overlappedRange = nullptr) const;
168
169private:
170 /// The type of interval map used to store source references. SMRange is
171 /// half-open, so we also need to use a half-open interval map.
172 using MapT =
173 llvm::IntervalMap<const char *, const PDLIndexSymbol *,
174 llvm::IntervalMapImpl::NodeSizer<
175 const char *, const PDLIndexSymbol *>::LeafSize,
176 llvm::IntervalMapHalfOpenInfo<const char *>>;
177
178 /// An allocator for the interval map.
179 MapT::Allocator allocator;
180
181 /// An interval map containing a corresponding definition mapped to a source
182 /// interval.
183 MapT intervalMap;
184
185 /// A mapping between definitions and their corresponding symbol.
186 DenseMap<const void *, std::unique_ptr<PDLIndexSymbol>> defToSymbol;
187};
188} // namespace
189
190void PDLIndex::initialize(const ast::Module &module,
191 const ods::Context &odsContext) {
192 auto getOrInsertDef = [&](const auto *def) -> PDLIndexSymbol * {
193 auto it = defToSymbol.try_emplace(def, nullptr);
194 if (it.second)
195 it.first->second = std::make_unique<PDLIndexSymbol>(def);
196 return &*it.first->second;
197 };
198 auto insertDeclRef = [&](PDLIndexSymbol *sym, SMRange refLoc,
199 bool isDef = false) {
200 const char *startLoc = refLoc.Start.getPointer();
201 const char *endLoc = refLoc.End.getPointer();
202 if (!intervalMap.overlaps(a: startLoc, b: endLoc)) {
203 intervalMap.insert(a: startLoc, b: endLoc, y: sym);
204 if (!isDef)
205 sym->references.push_back(x: refLoc);
206 }
207 };
208 auto insertODSOpRef = [&](StringRef opName, SMRange refLoc) {
209 const ods::Operation *odsOp = odsContext.lookupOperation(name: opName);
210 if (!odsOp)
211 return;
212
213 PDLIndexSymbol *symbol = getOrInsertDef(odsOp);
214 insertDeclRef(symbol, odsOp->getLoc(), /*isDef=*/true);
215 insertDeclRef(symbol, refLoc);
216 };
217
218 module.walk(walkFn: [&](const ast::Node *node) {
219 // Handle references to PDL decls.
220 if (const auto *decl = dyn_cast<ast::OpNameDecl>(Val: node)) {
221 if (std::optional<StringRef> name = decl->getName())
222 insertODSOpRef(*name, decl->getLoc());
223 } else if (const ast::Decl *decl = dyn_cast<ast::Decl>(Val: node)) {
224 const ast::Name *name = decl->getName();
225 if (!name)
226 return;
227 PDLIndexSymbol *declSym = getOrInsertDef(decl);
228 insertDeclRef(declSym, name->getLoc(), /*isDef=*/true);
229
230 if (const auto *varDecl = dyn_cast<ast::VariableDecl>(Val: decl)) {
231 // Record references to any constraints.
232 for (const auto &it : varDecl->getConstraints())
233 insertDeclRef(getOrInsertDef(it.constraint), it.referenceLoc);
234 }
235 } else if (const auto *expr = dyn_cast<ast::DeclRefExpr>(Val: node)) {
236 insertDeclRef(getOrInsertDef(expr->getDecl()), expr->getLoc());
237 }
238 });
239}
240
241const PDLIndexSymbol *PDLIndex::lookup(SMLoc loc,
242 SMRange *overlappedRange) const {
243 auto it = intervalMap.find(x: loc.getPointer());
244 if (!it.valid() || loc.getPointer() < it.start())
245 return nullptr;
246
247 if (overlappedRange) {
248 *overlappedRange = SMRange(SMLoc::getFromPointer(Ptr: it.start()),
249 SMLoc::getFromPointer(Ptr: it.stop()));
250 }
251 return it.value();
252}
253
254//===----------------------------------------------------------------------===//
255// PDLDocument
256//===----------------------------------------------------------------------===//
257
258namespace {
259/// This class represents all of the information pertaining to a specific PDL
260/// document.
261struct PDLDocument {
262 PDLDocument(const lsp::URIForFile &uri, StringRef contents,
263 const std::vector<std::string> &extraDirs,
264 std::vector<lsp::Diagnostic> &diagnostics);
265 PDLDocument(const PDLDocument &) = delete;
266 PDLDocument &operator=(const PDLDocument &) = delete;
267
268 //===--------------------------------------------------------------------===//
269 // Definitions and References
270 //===--------------------------------------------------------------------===//
271
272 void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
273 std::vector<lsp::Location> &locations);
274 void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
275 std::vector<lsp::Location> &references);
276
277 //===--------------------------------------------------------------------===//
278 // Document Links
279 //===--------------------------------------------------------------------===//
280
281 void getDocumentLinks(const lsp::URIForFile &uri,
282 std::vector<lsp::DocumentLink> &links);
283
284 //===--------------------------------------------------------------------===//
285 // Hover
286 //===--------------------------------------------------------------------===//
287
288 std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
289 const lsp::Position &hoverPos);
290 std::optional<lsp::Hover> findHover(const ast::Decl *decl,
291 const SMRange &hoverRange);
292 lsp::Hover buildHoverForOpName(const ods::Operation *op,
293 const SMRange &hoverRange);
294 lsp::Hover buildHoverForVariable(const ast::VariableDecl *varDecl,
295 const SMRange &hoverRange);
296 lsp::Hover buildHoverForPattern(const ast::PatternDecl *decl,
297 const SMRange &hoverRange);
298 lsp::Hover buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl,
299 const SMRange &hoverRange);
300 template <typename T>
301 lsp::Hover buildHoverForUserConstraintOrRewrite(StringRef typeName,
302 const T *decl,
303 const SMRange &hoverRange);
304
305 //===--------------------------------------------------------------------===//
306 // Document Symbols
307 //===--------------------------------------------------------------------===//
308
309 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
310
311 //===--------------------------------------------------------------------===//
312 // Code Completion
313 //===--------------------------------------------------------------------===//
314
315 lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
316 const lsp::Position &completePos);
317
318 //===--------------------------------------------------------------------===//
319 // Signature Help
320 //===--------------------------------------------------------------------===//
321
322 lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri,
323 const lsp::Position &helpPos);
324
325 //===--------------------------------------------------------------------===//
326 // Inlay Hints
327 //===--------------------------------------------------------------------===//
328
329 void getInlayHints(const lsp::URIForFile &uri, const lsp::Range &range,
330 std::vector<lsp::InlayHint> &inlayHints);
331 void getInlayHintsFor(const ast::VariableDecl *decl,
332 const lsp::URIForFile &uri,
333 std::vector<lsp::InlayHint> &inlayHints);
334 void getInlayHintsFor(const ast::CallExpr *expr, const lsp::URIForFile &uri,
335 std::vector<lsp::InlayHint> &inlayHints);
336 void getInlayHintsFor(const ast::OperationExpr *expr,
337 const lsp::URIForFile &uri,
338 std::vector<lsp::InlayHint> &inlayHints);
339
340 /// Add a parameter hint for the given expression using `label`.
341 void addParameterHintFor(std::vector<lsp::InlayHint> &inlayHints,
342 const ast::Expr *expr, StringRef label);
343
344 //===--------------------------------------------------------------------===//
345 // PDLL ViewOutput
346 //===--------------------------------------------------------------------===//
347
348 void getPDLLViewOutput(raw_ostream &os, lsp::PDLLViewOutputKind kind);
349
350 //===--------------------------------------------------------------------===//
351 // Fields
352 //===--------------------------------------------------------------------===//
353
354 /// The include directories for this file.
355 std::vector<std::string> includeDirs;
356
357 /// The source manager containing the contents of the input file.
358 llvm::SourceMgr sourceMgr;
359
360 /// The ODS and AST contexts.
361 ods::Context odsContext;
362 ast::Context astContext;
363
364 /// The parsed AST module, or failure if the file wasn't valid.
365 FailureOr<ast::Module *> astModule;
366
367 /// The index of the parsed module.
368 PDLIndex index;
369
370 /// The set of includes of the parsed module.
371 SmallVector<lsp::SourceMgrInclude> parsedIncludes;
372};
373} // namespace
374
375PDLDocument::PDLDocument(const lsp::URIForFile &uri, StringRef contents,
376 const std::vector<std::string> &extraDirs,
377 std::vector<lsp::Diagnostic> &diagnostics)
378 : astContext(odsContext) {
379 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(InputData: contents, BufferName: uri.file());
380 if (!memBuffer) {
381 lsp::Logger::error(fmt: "Failed to create memory buffer for file", vals: uri.file());
382 return;
383 }
384
385 // Build the set of include directories for this file.
386 llvm::SmallString<32> uriDirectory(uri.file());
387 llvm::sys::path::remove_filename(path&: uriDirectory);
388 includeDirs.push_back(x: uriDirectory.str().str());
389 llvm::append_range(C&: includeDirs, R: extraDirs);
390
391 sourceMgr.setIncludeDirs(includeDirs);
392 sourceMgr.AddNewSourceBuffer(F: std::move(memBuffer), IncludeLoc: SMLoc());
393
394 astContext.getDiagEngine().setHandlerFn([&](const ast::Diagnostic &diag) {
395 if (auto lspDiag = getLspDiagnoticFromDiag(sourceMgr, diag, uri))
396 diagnostics.push_back(x: std::move(*lspDiag));
397 });
398 astModule = parsePDLLAST(ctx&: astContext, sourceMgr, /*enableDocumentation=*/true);
399
400 // Initialize the set of parsed includes.
401 lsp::gatherIncludeFiles(sourceMgr, includes&: parsedIncludes);
402
403 // If we failed to parse the module, there is nothing left to initialize.
404 if (failed(Result: astModule))
405 return;
406
407 // Prepare the AST index with the parsed module.
408 index.initialize(module: **astModule, odsContext);
409}
410
411//===----------------------------------------------------------------------===//
412// PDLDocument: Definitions and References
413//===----------------------------------------------------------------------===//
414
415void PDLDocument::getLocationsOf(const lsp::URIForFile &uri,
416 const lsp::Position &defPos,
417 std::vector<lsp::Location> &locations) {
418 SMLoc posLoc = defPos.getAsSMLoc(mgr&: sourceMgr);
419 const PDLIndexSymbol *symbol = index.lookup(loc: posLoc);
420 if (!symbol)
421 return;
422
423 locations.push_back(x: getLocationFromLoc(mgr&: sourceMgr, range: symbol->getDefLoc(), uri));
424}
425
426void PDLDocument::findReferencesOf(const lsp::URIForFile &uri,
427 const lsp::Position &pos,
428 std::vector<lsp::Location> &references) {
429 SMLoc posLoc = pos.getAsSMLoc(mgr&: sourceMgr);
430 const PDLIndexSymbol *symbol = index.lookup(loc: posLoc);
431 if (!symbol)
432 return;
433
434 references.push_back(x: getLocationFromLoc(mgr&: sourceMgr, range: symbol->getDefLoc(), uri));
435 for (SMRange refLoc : symbol->references)
436 references.push_back(x: getLocationFromLoc(mgr&: sourceMgr, range: refLoc, uri));
437}
438
439//===--------------------------------------------------------------------===//
440// PDLDocument: Document Links
441//===--------------------------------------------------------------------===//
442
443void PDLDocument::getDocumentLinks(const lsp::URIForFile &uri,
444 std::vector<lsp::DocumentLink> &links) {
445 for (const lsp::SourceMgrInclude &include : parsedIncludes)
446 links.emplace_back(args: include.range, args: include.uri);
447}
448
449//===----------------------------------------------------------------------===//
450// PDLDocument: Hover
451//===----------------------------------------------------------------------===//
452
453std::optional<lsp::Hover>
454PDLDocument::findHover(const lsp::URIForFile &uri,
455 const lsp::Position &hoverPos) {
456 SMLoc posLoc = hoverPos.getAsSMLoc(mgr&: sourceMgr);
457
458 // Check for a reference to an include.
459 for (const lsp::SourceMgrInclude &include : parsedIncludes)
460 if (include.range.contains(pos: hoverPos))
461 return include.buildHover();
462
463 // Find the symbol at the given location.
464 SMRange hoverRange;
465 const PDLIndexSymbol *symbol = index.lookup(loc: posLoc, overlappedRange: &hoverRange);
466 if (!symbol)
467 return std::nullopt;
468
469 // Add hover for operation names.
470 if (const auto *op =
471 llvm::dyn_cast_if_present<const ods::Operation *>(Val: symbol->definition))
472 return buildHoverForOpName(op, hoverRange);
473 const auto *decl = cast<const ast::Decl *>(Val: symbol->definition);
474 return findHover(decl, hoverRange);
475}
476
477std::optional<lsp::Hover> PDLDocument::findHover(const ast::Decl *decl,
478 const SMRange &hoverRange) {
479 // Add hover for variables.
480 if (const auto *varDecl = dyn_cast<ast::VariableDecl>(Val: decl))
481 return buildHoverForVariable(varDecl, hoverRange);
482
483 // Add hover for patterns.
484 if (const auto *patternDecl = dyn_cast<ast::PatternDecl>(Val: decl))
485 return buildHoverForPattern(decl: patternDecl, hoverRange);
486
487 // Add hover for core constraints.
488 if (const auto *cst = dyn_cast<ast::CoreConstraintDecl>(Val: decl))
489 return buildHoverForCoreConstraint(decl: cst, hoverRange);
490
491 // Add hover for user constraints.
492 if (const auto *cst = dyn_cast<ast::UserConstraintDecl>(Val: decl))
493 return buildHoverForUserConstraintOrRewrite(typeName: "Constraint", decl: cst, hoverRange);
494
495 // Add hover for user rewrites.
496 if (const auto *rewrite = dyn_cast<ast::UserRewriteDecl>(Val: decl))
497 return buildHoverForUserConstraintOrRewrite(typeName: "Rewrite", decl: rewrite, hoverRange);
498
499 return std::nullopt;
500}
501
502lsp::Hover PDLDocument::buildHoverForOpName(const ods::Operation *op,
503 const SMRange &hoverRange) {
504 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
505 {
506 llvm::raw_string_ostream hoverOS(hover.contents.value);
507 hoverOS << "**OpName**: `" << op->getName() << "`\n***\n"
508 << op->getSummary() << "\n***\n"
509 << op->getDescription();
510 }
511 return hover;
512}
513
514lsp::Hover PDLDocument::buildHoverForVariable(const ast::VariableDecl *varDecl,
515 const SMRange &hoverRange) {
516 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
517 {
518 llvm::raw_string_ostream hoverOS(hover.contents.value);
519 hoverOS << "**Variable**: `" << varDecl->getName().getName() << "`\n***\n"
520 << "Type: `" << varDecl->getType() << "`\n";
521 }
522 return hover;
523}
524
525lsp::Hover PDLDocument::buildHoverForPattern(const ast::PatternDecl *decl,
526 const SMRange &hoverRange) {
527 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
528 {
529 llvm::raw_string_ostream hoverOS(hover.contents.value);
530 hoverOS << "**Pattern**";
531 if (const ast::Name *name = decl->getName())
532 hoverOS << ": `" << name->getName() << "`";
533 hoverOS << "\n***\n";
534 if (std::optional<uint16_t> benefit = decl->getBenefit())
535 hoverOS << "Benefit: " << *benefit << "\n";
536 if (decl->hasBoundedRewriteRecursion())
537 hoverOS << "HasBoundedRewriteRecursion\n";
538 hoverOS << "RootOp: `"
539 << decl->getRootRewriteStmt()->getRootOpExpr()->getType() << "`\n";
540
541 // Format the documentation for the decl.
542 if (std::optional<std::string> doc = getDocumentationFor(sourceMgr, decl))
543 hoverOS << "\n" << *doc << "\n";
544 }
545 return hover;
546}
547
548lsp::Hover
549PDLDocument::buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl,
550 const SMRange &hoverRange) {
551 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
552 {
553 llvm::raw_string_ostream hoverOS(hover.contents.value);
554 hoverOS << "**Constraint**: `";
555 TypeSwitch<const ast::Decl *>(decl)
556 .Case(caseFn: [&](const ast::AttrConstraintDecl *) { hoverOS << "Attr"; })
557 .Case(caseFn: [&](const ast::OpConstraintDecl *opCst) {
558 hoverOS << "Op";
559 if (std::optional<StringRef> name = opCst->getName())
560 hoverOS << "<" << *name << ">";
561 })
562 .Case(caseFn: [&](const ast::TypeConstraintDecl *) { hoverOS << "Type"; })
563 .Case(caseFn: [&](const ast::TypeRangeConstraintDecl *) {
564 hoverOS << "TypeRange";
565 })
566 .Case(caseFn: [&](const ast::ValueConstraintDecl *) { hoverOS << "Value"; })
567 .Case(caseFn: [&](const ast::ValueRangeConstraintDecl *) {
568 hoverOS << "ValueRange";
569 });
570 hoverOS << "`\n";
571 }
572 return hover;
573}
574
575template <typename T>
576lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite(
577 StringRef typeName, const T *decl, const SMRange &hoverRange) {
578 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
579 {
580 llvm::raw_string_ostream hoverOS(hover.contents.value);
581 hoverOS << "**" << typeName << "**: `" << decl->getName().getName()
582 << "`\n***\n";
583 ArrayRef<ast::VariableDecl *> inputs = decl->getInputs();
584 if (!inputs.empty()) {
585 hoverOS << "Parameters:\n";
586 for (const ast::VariableDecl *input : inputs)
587 hoverOS << "* " << input->getName().getName() << ": `"
588 << input->getType() << "`\n";
589 hoverOS << "***\n";
590 }
591 ast::Type resultType = decl->getResultType();
592 if (auto resultTupleTy = dyn_cast<ast::TupleType>(Val&: resultType)) {
593 if (!resultTupleTy.empty()) {
594 hoverOS << "Results:\n";
595 for (auto it : llvm::zip(t: resultTupleTy.getElementNames(),
596 u: resultTupleTy.getElementTypes())) {
597 StringRef name = std::get<0>(t&: it);
598 hoverOS << "* " << (name.empty() ? "" : (name + ": ")) << "`"
599 << std::get<1>(t&: it) << "`\n";
600 }
601 hoverOS << "***\n";
602 }
603 } else {
604 hoverOS << "Results:\n* `" << resultType << "`\n";
605 hoverOS << "***\n";
606 }
607
608 // Format the documentation for the decl.
609 if (std::optional<std::string> doc = getDocumentationFor(sourceMgr, decl))
610 hoverOS << "\n" << *doc << "\n";
611 }
612 return hover;
613}
614
615//===----------------------------------------------------------------------===//
616// PDLDocument: Document Symbols
617//===----------------------------------------------------------------------===//
618
619void PDLDocument::findDocumentSymbols(
620 std::vector<lsp::DocumentSymbol> &symbols) {
621 if (failed(Result: astModule))
622 return;
623
624 for (const ast::Decl *decl : (*astModule)->getChildren()) {
625 if (!isMainFileLoc(mgr&: sourceMgr, loc: decl->getLoc()))
626 continue;
627
628 if (const auto *patternDecl = dyn_cast<ast::PatternDecl>(Val: decl)) {
629 const ast::Name *name = patternDecl->getName();
630
631 SMRange nameLoc = name ? name->getLoc() : patternDecl->getLoc();
632 SMRange bodyLoc(nameLoc.Start, patternDecl->getBody()->getLoc().End);
633
634 symbols.emplace_back(
635 args: name ? name->getName() : "<pattern>", args: lsp::SymbolKind::Class,
636 args: lsp::Range(sourceMgr, bodyLoc), args: lsp::Range(sourceMgr, nameLoc));
637 } else if (const auto *cDecl = dyn_cast<ast::UserConstraintDecl>(Val: decl)) {
638 // TODO: Add source information for the code block body.
639 SMRange nameLoc = cDecl->getName().getLoc();
640 SMRange bodyLoc = nameLoc;
641
642 symbols.emplace_back(
643 args: cDecl->getName().getName(), args: lsp::SymbolKind::Function,
644 args: lsp::Range(sourceMgr, bodyLoc), args: lsp::Range(sourceMgr, nameLoc));
645 } else if (const auto *cDecl = dyn_cast<ast::UserRewriteDecl>(Val: decl)) {
646 // TODO: Add source information for the code block body.
647 SMRange nameLoc = cDecl->getName().getLoc();
648 SMRange bodyLoc = nameLoc;
649
650 symbols.emplace_back(
651 args: cDecl->getName().getName(), args: lsp::SymbolKind::Function,
652 args: lsp::Range(sourceMgr, bodyLoc), args: lsp::Range(sourceMgr, nameLoc));
653 }
654 }
655}
656
657//===----------------------------------------------------------------------===//
658// PDLDocument: Code Completion
659//===----------------------------------------------------------------------===//
660
661namespace {
662class LSPCodeCompleteContext : public CodeCompleteContext {
663public:
664 LSPCodeCompleteContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
665 lsp::CompletionList &completionList,
666 ods::Context &odsContext,
667 ArrayRef<std::string> includeDirs)
668 : CodeCompleteContext(completeLoc), sourceMgr(sourceMgr),
669 completionList(completionList), odsContext(odsContext),
670 includeDirs(includeDirs) {}
671
672 void codeCompleteTupleMemberAccess(ast::TupleType tupleType) final {
673 ArrayRef<ast::Type> elementTypes = tupleType.getElementTypes();
674 ArrayRef<StringRef> elementNames = tupleType.getElementNames();
675 for (unsigned i = 0, e = tupleType.size(); i < e; ++i) {
676 // Push back a completion item that uses the result index.
677 lsp::CompletionItem item;
678 item.label = llvm::formatv(Fmt: "{0} (field #{0})", Vals&: i).str();
679 item.insertText = Twine(i).str();
680 item.filterText = item.sortText = item.insertText;
681 item.kind = lsp::CompletionItemKind::Field;
682 item.detail = llvm::formatv(Fmt: "{0}: {1}", Vals&: i, Vals: elementTypes[i]);
683 item.insertTextFormat = lsp::InsertTextFormat::PlainText;
684 completionList.items.emplace_back(args&: item);
685
686 // If the element has a name, push back a completion item with that name.
687 if (!elementNames[i].empty()) {
688 item.label =
689 llvm::formatv(Fmt: "{1} (field #{0})", Vals&: i, Vals: elementNames[i]).str();
690 item.filterText = item.label;
691 item.insertText = elementNames[i].str();
692 completionList.items.emplace_back(args&: item);
693 }
694 }
695 }
696
697 void codeCompleteOperationMemberAccess(ast::OperationType opType) final {
698 const ods::Operation *odsOp = opType.getODSOperation();
699 if (!odsOp)
700 return;
701
702 ArrayRef<ods::OperandOrResult> results = odsOp->getResults();
703 for (const auto &it : llvm::enumerate(First&: results)) {
704 const ods::OperandOrResult &result = it.value();
705 const ods::TypeConstraint &constraint = result.getConstraint();
706
707 // Push back a completion item that uses the result index.
708 lsp::CompletionItem item;
709 item.label = llvm::formatv(Fmt: "{0} (field #{0})", Vals: it.index()).str();
710 item.insertText = Twine(it.index()).str();
711 item.filterText = item.sortText = item.insertText;
712 item.kind = lsp::CompletionItemKind::Field;
713 switch (result.getVariableLengthKind()) {
714 case ods::VariableLengthKind::Single:
715 item.detail = llvm::formatv(Fmt: "{0}: Value", Vals: it.index()).str();
716 break;
717 case ods::VariableLengthKind::Optional:
718 item.detail = llvm::formatv(Fmt: "{0}: Value?", Vals: it.index()).str();
719 break;
720 case ods::VariableLengthKind::Variadic:
721 item.detail = llvm::formatv(Fmt: "{0}: ValueRange", Vals: it.index()).str();
722 break;
723 }
724 item.documentation = lsp::MarkupContent{
725 .kind: lsp::MarkupKind::Markdown,
726 .value: llvm::formatv(Fmt: "{0}\n\n```c++\n{1}\n```\n", Vals: constraint.getSummary(),
727 Vals: constraint.getCppClass())
728 .str()};
729 item.insertTextFormat = lsp::InsertTextFormat::PlainText;
730 completionList.items.emplace_back(args&: item);
731
732 // If the result has a name, push back a completion item with the result
733 // name.
734 if (!result.getName().empty()) {
735 item.label =
736 llvm::formatv(Fmt: "{1} (field #{0})", Vals: it.index(), Vals: result.getName())
737 .str();
738 item.filterText = item.label;
739 item.insertText = result.getName().str();
740 completionList.items.emplace_back(args&: item);
741 }
742 }
743 }
744
745 void codeCompleteOperationAttributeName(StringRef opName) final {
746 const ods::Operation *odsOp = odsContext.lookupOperation(name: opName);
747 if (!odsOp)
748 return;
749
750 for (const ods::Attribute &attr : odsOp->getAttributes()) {
751 const ods::AttributeConstraint &constraint = attr.getConstraint();
752
753 lsp::CompletionItem item;
754 item.label = attr.getName().str();
755 item.kind = lsp::CompletionItemKind::Field;
756 item.detail = attr.isOptional() ? "optional" : "";
757 item.documentation = lsp::MarkupContent{
758 .kind: lsp::MarkupKind::Markdown,
759 .value: llvm::formatv(Fmt: "{0}\n\n```c++\n{1}\n```\n", Vals: constraint.getSummary(),
760 Vals: constraint.getCppClass())
761 .str()};
762 item.insertTextFormat = lsp::InsertTextFormat::PlainText;
763 completionList.items.emplace_back(args&: item);
764 }
765 }
766
767 void codeCompleteConstraintName(ast::Type currentType,
768 bool allowInlineTypeConstraints,
769 const ast::DeclScope *scope) final {
770 auto addCoreConstraint = [&](StringRef constraint, StringRef mlirType,
771 StringRef snippetText = "") {
772 lsp::CompletionItem item;
773 item.label = constraint.str();
774 item.kind = lsp::CompletionItemKind::Class;
775 item.detail = (constraint + " constraint").str();
776 item.documentation = lsp::MarkupContent{
777 .kind: lsp::MarkupKind::Markdown,
778 .value: ("A single entity core constraint of type `" + mlirType + "`").str()};
779 item.sortText = "0";
780 item.insertText = snippetText.str();
781 item.insertTextFormat = snippetText.empty()
782 ? lsp::InsertTextFormat::PlainText
783 : lsp::InsertTextFormat::Snippet;
784 completionList.items.emplace_back(args&: item);
785 };
786
787 // Insert completions for the core constraints. Some core constraints have
788 // additional characteristics, so we may add then even if a type has been
789 // inferred.
790 if (!currentType) {
791 addCoreConstraint("Attr", "mlir::Attribute");
792 addCoreConstraint("Op", "mlir::Operation *");
793 addCoreConstraint("Value", "mlir::Value");
794 addCoreConstraint("ValueRange", "mlir::ValueRange");
795 addCoreConstraint("Type", "mlir::Type");
796 addCoreConstraint("TypeRange", "mlir::TypeRange");
797 }
798 if (allowInlineTypeConstraints) {
799 /// Attr<Type>.
800 if (!currentType || isa<ast::AttributeType>(Val: currentType))
801 addCoreConstraint("Attr<type>", "mlir::Attribute", "Attr<$1>");
802 /// Value<Type>.
803 if (!currentType || isa<ast::ValueType>(Val: currentType))
804 addCoreConstraint("Value<type>", "mlir::Value", "Value<$1>");
805 /// ValueRange<TypeRange>.
806 if (!currentType || isa<ast::ValueRangeType>(Val: currentType))
807 addCoreConstraint("ValueRange<type>", "mlir::ValueRange",
808 "ValueRange<$1>");
809 }
810
811 // If a scope was provided, check it for potential constraints.
812 while (scope) {
813 for (const ast::Decl *decl : scope->getDecls()) {
814 if (const auto *cst = dyn_cast<ast::UserConstraintDecl>(Val: decl)) {
815 lsp::CompletionItem item;
816 item.label = cst->getName().getName().str();
817 item.kind = lsp::CompletionItemKind::Interface;
818 item.sortText = "2_" + item.label;
819
820 // Skip constraints that are not single-arg. We currently only
821 // complete variable constraints.
822 if (cst->getInputs().size() != 1)
823 continue;
824
825 // Ensure the input type matched the given type.
826 ast::Type constraintType = cst->getInputs()[0]->getType();
827 if (currentType && !currentType.refineWith(other: constraintType))
828 continue;
829
830 // Format the constraint signature.
831 {
832 llvm::raw_string_ostream strOS(item.detail);
833 strOS << "(";
834 llvm::interleaveComma(
835 c: cst->getInputs(), os&: strOS, each_fn: [&](const ast::VariableDecl *var) {
836 strOS << var->getName().getName() << ": " << var->getType();
837 });
838 strOS << ") -> " << cst->getResultType();
839 }
840
841 // Format the documentation for the constraint.
842 if (std::optional<std::string> doc =
843 getDocumentationFor(sourceMgr, decl: cst)) {
844 item.documentation =
845 lsp::MarkupContent{.kind: lsp::MarkupKind::Markdown, .value: std::move(*doc)};
846 }
847
848 completionList.items.emplace_back(args&: item);
849 }
850 }
851
852 scope = scope->getParentScope();
853 }
854 }
855
856 void codeCompleteDialectName() final {
857 // Code complete known dialects.
858 for (const ods::Dialect &dialect : odsContext.getDialects()) {
859 lsp::CompletionItem item;
860 item.label = dialect.getName().str();
861 item.kind = lsp::CompletionItemKind::Class;
862 item.insertTextFormat = lsp::InsertTextFormat::PlainText;
863 completionList.items.emplace_back(args&: item);
864 }
865 }
866
867 void codeCompleteOperationName(StringRef dialectName) final {
868 const ods::Dialect *dialect = odsContext.lookupDialect(name: dialectName);
869 if (!dialect)
870 return;
871
872 for (const auto &it : dialect->getOperations()) {
873 const ods::Operation &op = *it.second;
874
875 lsp::CompletionItem item;
876 item.label = op.getName().drop_front(N: dialectName.size() + 1).str();
877 item.kind = lsp::CompletionItemKind::Field;
878 item.insertTextFormat = lsp::InsertTextFormat::PlainText;
879 completionList.items.emplace_back(args&: item);
880 }
881 }
882
883 void codeCompletePatternMetadata() final {
884 auto addSimpleConstraint = [&](StringRef constraint, StringRef desc,
885 StringRef snippetText = "") {
886 lsp::CompletionItem item;
887 item.label = constraint.str();
888 item.kind = lsp::CompletionItemKind::Class;
889 item.detail = "pattern metadata";
890 item.documentation =
891 lsp::MarkupContent{.kind: lsp::MarkupKind::Markdown, .value: desc.str()};
892 item.insertText = snippetText.str();
893 item.insertTextFormat = snippetText.empty()
894 ? lsp::InsertTextFormat::PlainText
895 : lsp::InsertTextFormat::Snippet;
896 completionList.items.emplace_back(args&: item);
897 };
898
899 addSimpleConstraint("benefit", "The `benefit` of matching the pattern.",
900 "benefit($1)");
901 addSimpleConstraint("recursion",
902 "The pattern properly handles recursive application.");
903 }
904
905 void codeCompleteIncludeFilename(StringRef curPath) final {
906 // Normalize the path to allow for interacting with the file system
907 // utilities.
908 SmallString<128> nativeRelDir(llvm::sys::path::convert_to_slash(path: curPath));
909 llvm::sys::path::native(path&: nativeRelDir);
910
911 // Set of already included completion paths.
912 StringSet<> seenResults;
913
914 // Functor used to add a single include completion item.
915 auto addIncludeCompletion = [&](StringRef path, bool isDirectory) {
916 lsp::CompletionItem item;
917 item.label = path.str();
918 item.kind = isDirectory ? lsp::CompletionItemKind::Folder
919 : lsp::CompletionItemKind::File;
920 if (seenResults.insert(key: item.label).second)
921 completionList.items.emplace_back(args&: item);
922 };
923
924 // Process the include directories for this file, adding any potential
925 // nested include files or directories.
926 for (StringRef includeDir : includeDirs) {
927 llvm::SmallString<128> dir = includeDir;
928 if (!nativeRelDir.empty())
929 llvm::sys::path::append(path&: dir, a: nativeRelDir);
930
931 std::error_code errorCode;
932 for (auto it = llvm::sys::fs::directory_iterator(dir, errorCode),
933 e = llvm::sys::fs::directory_iterator();
934 !errorCode && it != e; it.increment(ec&: errorCode)) {
935 StringRef filename = llvm::sys::path::filename(path: it->path());
936
937 // To know whether a symlink should be treated as file or a directory,
938 // we have to stat it. This should be cheap enough as there shouldn't be
939 // many symlinks.
940 llvm::sys::fs::file_type fileType = it->type();
941 if (fileType == llvm::sys::fs::file_type::symlink_file) {
942 if (auto fileStatus = it->status())
943 fileType = fileStatus->type();
944 }
945
946 switch (fileType) {
947 case llvm::sys::fs::file_type::directory_file:
948 addIncludeCompletion(filename, /*isDirectory=*/true);
949 break;
950 case llvm::sys::fs::file_type::regular_file: {
951 // Only consider concrete files that can actually be included by PDLL.
952 if (filename.ends_with(Suffix: ".pdll") || filename.ends_with(Suffix: ".td"))
953 addIncludeCompletion(filename, /*isDirectory=*/false);
954 break;
955 }
956 default:
957 break;
958 }
959 }
960 }
961
962 // Sort the completion results to make sure the output is deterministic in
963 // the face of different iteration schemes for different platforms.
964 llvm::sort(C&: completionList.items, Comp: [](const lsp::CompletionItem &lhs,
965 const lsp::CompletionItem &rhs) {
966 return lhs.label < rhs.label;
967 });
968 }
969
970private:
971 llvm::SourceMgr &sourceMgr;
972 lsp::CompletionList &completionList;
973 ods::Context &odsContext;
974 ArrayRef<std::string> includeDirs;
975};
976} // namespace
977
978lsp::CompletionList
979PDLDocument::getCodeCompletion(const lsp::URIForFile &uri,
980 const lsp::Position &completePos) {
981 SMLoc posLoc = completePos.getAsSMLoc(mgr&: sourceMgr);
982 if (!posLoc.isValid())
983 return lsp::CompletionList();
984
985 // To perform code completion, we run another parse of the module with the
986 // code completion context provided.
987 ods::Context tmpODSContext;
988 lsp::CompletionList completionList;
989 LSPCodeCompleteContext lspCompleteContext(posLoc, sourceMgr, completionList,
990 tmpODSContext,
991 sourceMgr.getIncludeDirs());
992
993 ast::Context tmpContext(tmpODSContext);
994 (void)parsePDLLAST(ctx&: tmpContext, sourceMgr, /*enableDocumentation=*/true,
995 codeCompleteContext: &lspCompleteContext);
996
997 return completionList;
998}
999
1000//===----------------------------------------------------------------------===//
1001// PDLDocument: Signature Help
1002//===----------------------------------------------------------------------===//
1003
1004namespace {
1005class LSPSignatureHelpContext : public CodeCompleteContext {
1006public:
1007 LSPSignatureHelpContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
1008 lsp::SignatureHelp &signatureHelp,
1009 ods::Context &odsContext)
1010 : CodeCompleteContext(completeLoc), sourceMgr(sourceMgr),
1011 signatureHelp(signatureHelp), odsContext(odsContext) {}
1012
1013 void codeCompleteCallSignature(const ast::CallableDecl *callable,
1014 unsigned currentNumArgs) final {
1015 signatureHelp.activeParameter = currentNumArgs;
1016
1017 lsp::SignatureInformation signatureInfo;
1018 {
1019 llvm::raw_string_ostream strOS(signatureInfo.label);
1020 strOS << callable->getName()->getName() << "(";
1021 auto formatParamFn = [&](const ast::VariableDecl *var) {
1022 unsigned paramStart = strOS.str().size();
1023 strOS << var->getName().getName() << ": " << var->getType();
1024 unsigned paramEnd = strOS.str().size();
1025 signatureInfo.parameters.emplace_back(args: lsp::ParameterInformation{
1026 .labelString: StringRef(strOS.str()).slice(Start: paramStart, End: paramEnd).str(),
1027 .labelOffsets: std::make_pair(x&: paramStart, y&: paramEnd), /*paramDoc*/ .documentation: std::string()});
1028 };
1029 llvm::interleaveComma(c: callable->getInputs(), os&: strOS, each_fn: formatParamFn);
1030 strOS << ") -> " << callable->getResultType();
1031 }
1032
1033 // Format the documentation for the callable.
1034 if (std::optional<std::string> doc =
1035 getDocumentationFor(sourceMgr, decl: callable))
1036 signatureInfo.documentation = std::move(*doc);
1037
1038 signatureHelp.signatures.emplace_back(args: std::move(signatureInfo));
1039 }
1040
1041 void
1042 codeCompleteOperationOperandsSignature(std::optional<StringRef> opName,
1043 unsigned currentNumOperands) final {
1044 const ods::Operation *odsOp =
1045 opName ? odsContext.lookupOperation(name: *opName) : nullptr;
1046 codeCompleteOperationOperandOrResultSignature(
1047 opName, odsOp,
1048 values: odsOp ? odsOp->getOperands() : ArrayRef<ods::OperandOrResult>(),
1049 currentValue: currentNumOperands, label: "operand", dataType: "Value");
1050 }
1051
1052 void codeCompleteOperationResultsSignature(std::optional<StringRef> opName,
1053 unsigned currentNumResults) final {
1054 const ods::Operation *odsOp =
1055 opName ? odsContext.lookupOperation(name: *opName) : nullptr;
1056 codeCompleteOperationOperandOrResultSignature(
1057 opName, odsOp,
1058 values: odsOp ? odsOp->getResults() : ArrayRef<ods::OperandOrResult>(),
1059 currentValue: currentNumResults, label: "result", dataType: "Type");
1060 }
1061
1062 void codeCompleteOperationOperandOrResultSignature(
1063 std::optional<StringRef> opName, const ods::Operation *odsOp,
1064 ArrayRef<ods::OperandOrResult> values, unsigned currentValue,
1065 StringRef label, StringRef dataType) {
1066 signatureHelp.activeParameter = currentValue;
1067
1068 // If we have ODS information for the operation, add in the ODS signature
1069 // for the operation. We also verify that the current number of values is
1070 // not more than what is defined in ODS, as this will result in an error
1071 // anyways.
1072 if (odsOp && currentValue < values.size()) {
1073 lsp::SignatureInformation signatureInfo;
1074
1075 // Build the signature label.
1076 {
1077 llvm::raw_string_ostream strOS(signatureInfo.label);
1078 strOS << "(";
1079 auto formatFn = [&](const ods::OperandOrResult &value) {
1080 unsigned paramStart = strOS.str().size();
1081
1082 strOS << value.getName() << ": ";
1083
1084 StringRef constraintDoc = value.getConstraint().getSummary();
1085 std::string paramDoc;
1086 switch (value.getVariableLengthKind()) {
1087 case ods::VariableLengthKind::Single:
1088 strOS << dataType;
1089 paramDoc = constraintDoc.str();
1090 break;
1091 case ods::VariableLengthKind::Optional:
1092 strOS << dataType << "?";
1093 paramDoc = ("optional: " + constraintDoc).str();
1094 break;
1095 case ods::VariableLengthKind::Variadic:
1096 strOS << dataType << "Range";
1097 paramDoc = ("variadic: " + constraintDoc).str();
1098 break;
1099 }
1100
1101 unsigned paramEnd = strOS.str().size();
1102 signatureInfo.parameters.emplace_back(args: lsp::ParameterInformation{
1103 .labelString: StringRef(strOS.str()).slice(Start: paramStart, End: paramEnd).str(),
1104 .labelOffsets: std::make_pair(x&: paramStart, y&: paramEnd), .documentation: paramDoc});
1105 };
1106 llvm::interleaveComma(c: values, os&: strOS, each_fn: formatFn);
1107 strOS << ")";
1108 }
1109 signatureInfo.documentation =
1110 llvm::formatv(Fmt: "`op<{0}>` ODS {1} specification", Vals&: *opName, Vals&: label)
1111 .str();
1112 signatureHelp.signatures.emplace_back(args: std::move(signatureInfo));
1113 }
1114
1115 // If there aren't any arguments yet, we also add the generic signature.
1116 if (currentValue == 0 && (!odsOp || !values.empty())) {
1117 lsp::SignatureInformation signatureInfo;
1118 signatureInfo.label =
1119 llvm::formatv(Fmt: "(<{0}s>: {1}Range)", Vals&: label, Vals&: dataType).str();
1120 signatureInfo.documentation =
1121 ("Generic operation " + label + " specification").str();
1122 signatureInfo.parameters.emplace_back(args: lsp::ParameterInformation{
1123 .labelString: StringRef(signatureInfo.label).drop_front().drop_back().str(),
1124 .labelOffsets: std::pair<unsigned, unsigned>(1, signatureInfo.label.size() - 1),
1125 .documentation: ("All of the " + label + "s of the operation.").str()});
1126 signatureHelp.signatures.emplace_back(args: std::move(signatureInfo));
1127 }
1128 }
1129
1130private:
1131 llvm::SourceMgr &sourceMgr;
1132 lsp::SignatureHelp &signatureHelp;
1133 ods::Context &odsContext;
1134};
1135} // namespace
1136
1137lsp::SignatureHelp PDLDocument::getSignatureHelp(const lsp::URIForFile &uri,
1138 const lsp::Position &helpPos) {
1139 SMLoc posLoc = helpPos.getAsSMLoc(mgr&: sourceMgr);
1140 if (!posLoc.isValid())
1141 return lsp::SignatureHelp();
1142
1143 // To perform code completion, we run another parse of the module with the
1144 // code completion context provided.
1145 ods::Context tmpODSContext;
1146 lsp::SignatureHelp signatureHelp;
1147 LSPSignatureHelpContext completeContext(posLoc, sourceMgr, signatureHelp,
1148 tmpODSContext);
1149
1150 ast::Context tmpContext(tmpODSContext);
1151 (void)parsePDLLAST(ctx&: tmpContext, sourceMgr, /*enableDocumentation=*/true,
1152 codeCompleteContext: &completeContext);
1153
1154 return signatureHelp;
1155}
1156
1157//===----------------------------------------------------------------------===//
1158// PDLDocument: Inlay Hints
1159//===----------------------------------------------------------------------===//
1160
1161/// Returns true if the given name should be added as a hint for `expr`.
1162static bool shouldAddHintFor(const ast::Expr *expr, StringRef name) {
1163 if (name.empty())
1164 return false;
1165
1166 // If the argument is a reference of the same name, don't add it as a hint.
1167 if (auto *ref = dyn_cast<ast::DeclRefExpr>(Val: expr)) {
1168 const ast::Name *declName = ref->getDecl()->getName();
1169 if (declName && declName->getName() == name)
1170 return false;
1171 }
1172
1173 return true;
1174}
1175
1176void PDLDocument::getInlayHints(const lsp::URIForFile &uri,
1177 const lsp::Range &range,
1178 std::vector<lsp::InlayHint> &inlayHints) {
1179 if (failed(Result: astModule))
1180 return;
1181 SMRange rangeLoc = range.getAsSMRange(mgr&: sourceMgr);
1182 if (!rangeLoc.isValid())
1183 return;
1184 (*astModule)->walk(walkFn: [&](const ast::Node *node) {
1185 SMRange loc = node->getLoc();
1186
1187 // Check that the location of this node is within the input range.
1188 if (!lsp::contains(range: rangeLoc, loc: loc.Start) &&
1189 !lsp::contains(range: rangeLoc, loc: loc.End))
1190 return;
1191
1192 // Handle hints for various types of nodes.
1193 llvm::TypeSwitch<const ast::Node *>(node)
1194 .Case<ast::VariableDecl, ast::CallExpr, ast::OperationExpr>(
1195 caseFn: [&](const auto *node) {
1196 this->getInlayHintsFor(node, uri, inlayHints);
1197 });
1198 });
1199}
1200
1201void PDLDocument::getInlayHintsFor(const ast::VariableDecl *decl,
1202 const lsp::URIForFile &uri,
1203 std::vector<lsp::InlayHint> &inlayHints) {
1204 // Check to see if the variable has a constraint list, if it does we don't
1205 // provide initializer hints.
1206 if (!decl->getConstraints().empty())
1207 return;
1208
1209 // Check to see if the variable has an initializer.
1210 if (const ast::Expr *expr = decl->getInitExpr()) {
1211 // Don't add hints for operation expression initialized variables given that
1212 // the type of the variable is easily inferred by the expression operation
1213 // name.
1214 if (isa<ast::OperationExpr>(Val: expr))
1215 return;
1216 }
1217
1218 lsp::InlayHint hint(lsp::InlayHintKind::Type,
1219 lsp::Position(sourceMgr, decl->getLoc().End));
1220 {
1221 llvm::raw_string_ostream labelOS(hint.label);
1222 labelOS << ": " << decl->getType();
1223 }
1224
1225 inlayHints.emplace_back(args: std::move(hint));
1226}
1227
1228void PDLDocument::getInlayHintsFor(const ast::CallExpr *expr,
1229 const lsp::URIForFile &uri,
1230 std::vector<lsp::InlayHint> &inlayHints) {
1231 // Try to extract the callable of this call.
1232 const auto *callableRef = dyn_cast<ast::DeclRefExpr>(Val: expr->getCallableExpr());
1233 const auto *callable =
1234 callableRef ? dyn_cast<ast::CallableDecl>(Val: callableRef->getDecl())
1235 : nullptr;
1236 if (!callable)
1237 return;
1238
1239 // Add hints for the arguments to the call.
1240 for (const auto &it : llvm::zip(t: expr->getArguments(), u: callable->getInputs()))
1241 addParameterHintFor(inlayHints, expr: std::get<0>(t: it),
1242 label: std::get<1>(t: it)->getName().getName());
1243}
1244
1245void PDLDocument::getInlayHintsFor(const ast::OperationExpr *expr,
1246 const lsp::URIForFile &uri,
1247 std::vector<lsp::InlayHint> &inlayHints) {
1248 // Check for ODS information.
1249 ast::OperationType opType = dyn_cast<ast::OperationType>(Val: expr->getType());
1250 const auto *odsOp = opType ? opType.getODSOperation() : nullptr;
1251
1252 auto addOpHint = [&](const ast::Expr *valueExpr, StringRef label) {
1253 // If the value expression used the same location as the operation, don't
1254 // add a hint. This expression was materialized during parsing.
1255 if (expr->getLoc().Start == valueExpr->getLoc().Start)
1256 return;
1257 addParameterHintFor(inlayHints, expr: valueExpr, label);
1258 };
1259
1260 // Functor used to process hints for the operands and results of the
1261 // operation. They effectively have the same format, and thus can be processed
1262 // using the same logic.
1263 auto addOperandOrResultHints = [&](ArrayRef<ast::Expr *> values,
1264 ArrayRef<ods::OperandOrResult> odsValues,
1265 StringRef allValuesName) {
1266 if (values.empty())
1267 return;
1268
1269 // The values should either map to a single range, or be equivalent to the
1270 // ODS values.
1271 if (values.size() != odsValues.size()) {
1272 // Handle the case of a single element that covers the full range.
1273 if (values.size() == 1)
1274 return addOpHint(values.front(), allValuesName);
1275 return;
1276 }
1277
1278 for (const auto &it : llvm::zip(t&: values, u&: odsValues))
1279 addOpHint(std::get<0>(t: it), std::get<1>(t: it).getName());
1280 };
1281
1282 // Add hints for the operands and results of the operation.
1283 addOperandOrResultHints(expr->getOperands(),
1284 odsOp ? odsOp->getOperands()
1285 : ArrayRef<ods::OperandOrResult>(),
1286 "operands");
1287 addOperandOrResultHints(expr->getResultTypes(),
1288 odsOp ? odsOp->getResults()
1289 : ArrayRef<ods::OperandOrResult>(),
1290 "results");
1291}
1292
1293void PDLDocument::addParameterHintFor(std::vector<lsp::InlayHint> &inlayHints,
1294 const ast::Expr *expr, StringRef label) {
1295 if (!shouldAddHintFor(expr, name: label))
1296 return;
1297
1298 lsp::InlayHint hint(lsp::InlayHintKind::Parameter,
1299 lsp::Position(sourceMgr, expr->getLoc().Start));
1300 hint.label = (label + ":").str();
1301 hint.paddingRight = true;
1302 inlayHints.emplace_back(args: std::move(hint));
1303}
1304
1305//===----------------------------------------------------------------------===//
1306// PDLL ViewOutput
1307//===----------------------------------------------------------------------===//
1308
1309void PDLDocument::getPDLLViewOutput(raw_ostream &os,
1310 lsp::PDLLViewOutputKind kind) {
1311 if (failed(Result: astModule))
1312 return;
1313 if (kind == lsp::PDLLViewOutputKind::AST) {
1314 (*astModule)->print(os);
1315 return;
1316 }
1317
1318 // Generate the MLIR for the ast module. We also capture diagnostics here to
1319 // show to the user, which may be useful if PDLL isn't capturing constraints
1320 // expected by PDL.
1321 MLIRContext mlirContext;
1322 SourceMgrDiagnosticHandler diagHandler(sourceMgr, &mlirContext, os);
1323 OwningOpRef<ModuleOp> pdlModule =
1324 codegenPDLLToMLIR(mlirContext: &mlirContext, context: astContext, sourceMgr, module: **astModule);
1325 if (!pdlModule)
1326 return;
1327 if (kind == lsp::PDLLViewOutputKind::MLIR) {
1328 pdlModule->print(os, flags: OpPrintingFlags().enableDebugInfo());
1329 return;
1330 }
1331
1332 // Otherwise, generate the output for C++.
1333 assert(kind == lsp::PDLLViewOutputKind::CPP &&
1334 "unexpected PDLLViewOutputKind");
1335 codegenPDLLToCPP(astModule: **astModule, module: *pdlModule, os);
1336}
1337
1338//===----------------------------------------------------------------------===//
1339// PDLTextFileChunk
1340//===----------------------------------------------------------------------===//
1341
1342namespace {
1343/// This class represents a single chunk of an PDL text file.
1344struct PDLTextFileChunk {
1345 PDLTextFileChunk(uint64_t lineOffset, const lsp::URIForFile &uri,
1346 StringRef contents,
1347 const std::vector<std::string> &extraDirs,
1348 std::vector<lsp::Diagnostic> &diagnostics)
1349 : lineOffset(lineOffset),
1350 document(uri, contents, extraDirs, diagnostics) {}
1351
1352 /// Adjust the line number of the given range to anchor at the beginning of
1353 /// the file, instead of the beginning of this chunk.
1354 void adjustLocForChunkOffset(lsp::Range &range) {
1355 adjustLocForChunkOffset(pos&: range.start);
1356 adjustLocForChunkOffset(pos&: range.end);
1357 }
1358 /// Adjust the line number of the given position to anchor at the beginning of
1359 /// the file, instead of the beginning of this chunk.
1360 void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
1361
1362 /// The line offset of this chunk from the beginning of the file.
1363 uint64_t lineOffset;
1364 /// The document referred to by this chunk.
1365 PDLDocument document;
1366};
1367} // namespace
1368
1369//===----------------------------------------------------------------------===//
1370// PDLTextFile
1371//===----------------------------------------------------------------------===//
1372
1373namespace {
1374/// This class represents a text file containing one or more PDL documents.
1375class PDLTextFile {
1376public:
1377 PDLTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1378 int64_t version, const std::vector<std::string> &extraDirs,
1379 std::vector<lsp::Diagnostic> &diagnostics);
1380
1381 /// Return the current version of this text file.
1382 int64_t getVersion() const { return version; }
1383
1384 /// Update the file to the new version using the provided set of content
1385 /// changes. Returns failure if the update was unsuccessful.
1386 LogicalResult update(const lsp::URIForFile &uri, int64_t newVersion,
1387 ArrayRef<lsp::TextDocumentContentChangeEvent> changes,
1388 std::vector<lsp::Diagnostic> &diagnostics);
1389
1390 //===--------------------------------------------------------------------===//
1391 // LSP Queries
1392 //===--------------------------------------------------------------------===//
1393
1394 void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
1395 std::vector<lsp::Location> &locations);
1396 void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
1397 std::vector<lsp::Location> &references);
1398 void getDocumentLinks(const lsp::URIForFile &uri,
1399 std::vector<lsp::DocumentLink> &links);
1400 std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
1401 lsp::Position hoverPos);
1402 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
1403 lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
1404 lsp::Position completePos);
1405 lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri,
1406 lsp::Position helpPos);
1407 void getInlayHints(const lsp::URIForFile &uri, lsp::Range range,
1408 std::vector<lsp::InlayHint> &inlayHints);
1409 lsp::PDLLViewOutputResult getPDLLViewOutput(lsp::PDLLViewOutputKind kind);
1410
1411private:
1412 using ChunkIterator = llvm::pointee_iterator<
1413 std::vector<std::unique_ptr<PDLTextFileChunk>>::iterator>;
1414
1415 /// Initialize the text file from the given file contents.
1416 void initialize(const lsp::URIForFile &uri, int64_t newVersion,
1417 std::vector<lsp::Diagnostic> &diagnostics);
1418
1419 /// Find the PDL document that contains the given position, and update the
1420 /// position to be anchored at the start of the found chunk instead of the
1421 /// beginning of the file.
1422 ChunkIterator getChunkItFor(lsp::Position &pos);
1423 PDLTextFileChunk &getChunkFor(lsp::Position &pos) {
1424 return *getChunkItFor(pos);
1425 }
1426
1427 /// The full string contents of the file.
1428 std::string contents;
1429
1430 /// The version of this file.
1431 int64_t version = 0;
1432
1433 /// The number of lines in the file.
1434 int64_t totalNumLines = 0;
1435
1436 /// The chunks of this file. The order of these chunks is the order in which
1437 /// they appear in the text file.
1438 std::vector<std::unique_ptr<PDLTextFileChunk>> chunks;
1439
1440 /// The extra set of include directories for this file.
1441 std::vector<std::string> extraIncludeDirs;
1442};
1443} // namespace
1444
1445PDLTextFile::PDLTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1446 int64_t version,
1447 const std::vector<std::string> &extraDirs,
1448 std::vector<lsp::Diagnostic> &diagnostics)
1449 : contents(fileContents.str()), extraIncludeDirs(extraDirs) {
1450 initialize(uri, newVersion: version, diagnostics);
1451}
1452
1453LogicalResult
1454PDLTextFile::update(const lsp::URIForFile &uri, int64_t newVersion,
1455 ArrayRef<lsp::TextDocumentContentChangeEvent> changes,
1456 std::vector<lsp::Diagnostic> &diagnostics) {
1457 if (failed(Result: lsp::TextDocumentContentChangeEvent::applyTo(changes, contents))) {
1458 lsp::Logger::error(fmt: "Failed to update contents of {0}", vals: uri.file());
1459 return failure();
1460 }
1461
1462 // If the file contents were properly changed, reinitialize the text file.
1463 initialize(uri, newVersion, diagnostics);
1464 return success();
1465}
1466
1467void PDLTextFile::getLocationsOf(const lsp::URIForFile &uri,
1468 lsp::Position defPos,
1469 std::vector<lsp::Location> &locations) {
1470 PDLTextFileChunk &chunk = getChunkFor(pos&: defPos);
1471 chunk.document.getLocationsOf(uri, defPos, locations);
1472
1473 // Adjust any locations within this file for the offset of this chunk.
1474 if (chunk.lineOffset == 0)
1475 return;
1476 for (lsp::Location &loc : locations)
1477 if (loc.uri == uri)
1478 chunk.adjustLocForChunkOffset(range&: loc.range);
1479}
1480
1481void PDLTextFile::findReferencesOf(const lsp::URIForFile &uri,
1482 lsp::Position pos,
1483 std::vector<lsp::Location> &references) {
1484 PDLTextFileChunk &chunk = getChunkFor(pos);
1485 chunk.document.findReferencesOf(uri, pos, references);
1486
1487 // Adjust any locations within this file for the offset of this chunk.
1488 if (chunk.lineOffset == 0)
1489 return;
1490 for (lsp::Location &loc : references)
1491 if (loc.uri == uri)
1492 chunk.adjustLocForChunkOffset(range&: loc.range);
1493}
1494
1495void PDLTextFile::getDocumentLinks(const lsp::URIForFile &uri,
1496 std::vector<lsp::DocumentLink> &links) {
1497 chunks.front()->document.getDocumentLinks(uri, links);
1498 for (const auto &it : llvm::drop_begin(RangeOrContainer&: chunks)) {
1499 size_t currentNumLinks = links.size();
1500 it->document.getDocumentLinks(uri, links);
1501
1502 // Adjust any links within this file to account for the offset of this
1503 // chunk.
1504 for (auto &link : llvm::drop_begin(RangeOrContainer&: links, N: currentNumLinks))
1505 it->adjustLocForChunkOffset(range&: link.range);
1506 }
1507}
1508
1509std::optional<lsp::Hover> PDLTextFile::findHover(const lsp::URIForFile &uri,
1510 lsp::Position hoverPos) {
1511 PDLTextFileChunk &chunk = getChunkFor(pos&: hoverPos);
1512 std::optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
1513
1514 // Adjust any locations within this file for the offset of this chunk.
1515 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1516 chunk.adjustLocForChunkOffset(range&: *hoverInfo->range);
1517 return hoverInfo;
1518}
1519
1520void PDLTextFile::findDocumentSymbols(
1521 std::vector<lsp::DocumentSymbol> &symbols) {
1522 if (chunks.size() == 1)
1523 return chunks.front()->document.findDocumentSymbols(symbols);
1524
1525 // If there are multiple chunks in this file, we create top-level symbols for
1526 // each chunk.
1527 for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
1528 PDLTextFileChunk &chunk = *chunks[i];
1529 lsp::Position startPos(chunk.lineOffset);
1530 lsp::Position endPos((i == e - 1) ? totalNumLines - 1
1531 : chunks[i + 1]->lineOffset);
1532 lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
1533 lsp::SymbolKind::Namespace,
1534 /*range=*/lsp::Range(startPos, endPos),
1535 /*selectionRange=*/lsp::Range(startPos));
1536 chunk.document.findDocumentSymbols(symbols&: symbol.children);
1537
1538 // Fixup the locations of document symbols within this chunk.
1539 if (i != 0) {
1540 SmallVector<lsp::DocumentSymbol *> symbolsToFix;
1541 for (lsp::DocumentSymbol &childSymbol : symbol.children)
1542 symbolsToFix.push_back(Elt: &childSymbol);
1543
1544 while (!symbolsToFix.empty()) {
1545 lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
1546 chunk.adjustLocForChunkOffset(range&: symbol->range);
1547 chunk.adjustLocForChunkOffset(range&: symbol->selectionRange);
1548
1549 for (lsp::DocumentSymbol &childSymbol : symbol->children)
1550 symbolsToFix.push_back(Elt: &childSymbol);
1551 }
1552 }
1553
1554 // Push the symbol for this chunk.
1555 symbols.emplace_back(args: std::move(symbol));
1556 }
1557}
1558
1559lsp::CompletionList PDLTextFile::getCodeCompletion(const lsp::URIForFile &uri,
1560 lsp::Position completePos) {
1561 PDLTextFileChunk &chunk = getChunkFor(pos&: completePos);
1562 lsp::CompletionList completionList =
1563 chunk.document.getCodeCompletion(uri, completePos);
1564
1565 // Adjust any completion locations.
1566 for (lsp::CompletionItem &item : completionList.items) {
1567 if (item.textEdit)
1568 chunk.adjustLocForChunkOffset(range&: item.textEdit->range);
1569 for (lsp::TextEdit &edit : item.additionalTextEdits)
1570 chunk.adjustLocForChunkOffset(range&: edit.range);
1571 }
1572 return completionList;
1573}
1574
1575lsp::SignatureHelp PDLTextFile::getSignatureHelp(const lsp::URIForFile &uri,
1576 lsp::Position helpPos) {
1577 return getChunkFor(pos&: helpPos).document.getSignatureHelp(uri, helpPos);
1578}
1579
1580void PDLTextFile::getInlayHints(const lsp::URIForFile &uri, lsp::Range range,
1581 std::vector<lsp::InlayHint> &inlayHints) {
1582 auto startIt = getChunkItFor(pos&: range.start);
1583 auto endIt = getChunkItFor(pos&: range.end);
1584
1585 // Functor used to get the chunks for a given file, and fixup any locations
1586 auto getHintsForChunk = [&](ChunkIterator chunkIt, lsp::Range range) {
1587 size_t currentNumHints = inlayHints.size();
1588 chunkIt->document.getInlayHints(uri, range, inlayHints);
1589
1590 // If this isn't the first chunk, update any positions to account for line
1591 // number differences.
1592 if (&*chunkIt != &*chunks.front()) {
1593 for (auto &hint : llvm::drop_begin(RangeOrContainer&: inlayHints, N: currentNumHints))
1594 chunkIt->adjustLocForChunkOffset(pos&: hint.position);
1595 }
1596 };
1597 // Returns the number of lines held by a given chunk.
1598 auto getNumLines = [](ChunkIterator chunkIt) {
1599 return (chunkIt + 1)->lineOffset - chunkIt->lineOffset;
1600 };
1601
1602 // Check if the range is fully within a single chunk.
1603 if (startIt == endIt)
1604 return getHintsForChunk(startIt, range);
1605
1606 // Otherwise, the range is split between multiple chunks. The first chunk
1607 // has the correct range start, but covers the total document.
1608 getHintsForChunk(startIt, lsp::Range(range.start, getNumLines(startIt)));
1609
1610 // Every chunk in between uses the full document.
1611 for (++startIt; startIt != endIt; ++startIt)
1612 getHintsForChunk(startIt, lsp::Range(0, getNumLines(startIt)));
1613
1614 // The range for the last chunk starts at the beginning of the document, up
1615 // through the end of the input range.
1616 getHintsForChunk(startIt, lsp::Range(0, range.end));
1617}
1618
1619lsp::PDLLViewOutputResult
1620PDLTextFile::getPDLLViewOutput(lsp::PDLLViewOutputKind kind) {
1621 lsp::PDLLViewOutputResult result;
1622 {
1623 llvm::raw_string_ostream outputOS(result.output);
1624 llvm::interleave(
1625 c: llvm::make_pointee_range(Range&: chunks),
1626 each_fn: [&](PDLTextFileChunk &chunk) {
1627 chunk.document.getPDLLViewOutput(os&: outputOS, kind);
1628 },
1629 between_fn: [&] { outputOS << "\n"
1630 << kDefaultSplitMarker << "\n\n"; });
1631 }
1632 return result;
1633}
1634
1635void PDLTextFile::initialize(const lsp::URIForFile &uri, int64_t newVersion,
1636 std::vector<lsp::Diagnostic> &diagnostics) {
1637 version = newVersion;
1638 chunks.clear();
1639
1640 // Split the file into separate PDL documents.
1641 SmallVector<StringRef, 8> subContents;
1642 StringRef(contents).split(A&: subContents, Separator: kDefaultSplitMarker);
1643 chunks.emplace_back(args: std::make_unique<PDLTextFileChunk>(
1644 /*lineOffset=*/args: 0, args: uri, args&: subContents.front(), args&: extraIncludeDirs,
1645 args&: diagnostics));
1646
1647 uint64_t lineOffset = subContents.front().count(C: '\n');
1648 for (StringRef docContents : llvm::drop_begin(RangeOrContainer&: subContents)) {
1649 unsigned currentNumDiags = diagnostics.size();
1650 auto chunk = std::make_unique<PDLTextFileChunk>(
1651 args&: lineOffset, args: uri, args&: docContents, args&: extraIncludeDirs, args&: diagnostics);
1652 lineOffset += docContents.count(C: '\n');
1653
1654 // Adjust locations used in diagnostics to account for the offset from the
1655 // beginning of the file.
1656 for (lsp::Diagnostic &diag :
1657 llvm::drop_begin(RangeOrContainer&: diagnostics, N: currentNumDiags)) {
1658 chunk->adjustLocForChunkOffset(range&: diag.range);
1659
1660 if (!diag.relatedInformation)
1661 continue;
1662 for (auto &it : *diag.relatedInformation)
1663 if (it.location.uri == uri)
1664 chunk->adjustLocForChunkOffset(range&: it.location.range);
1665 }
1666 chunks.emplace_back(args: std::move(chunk));
1667 }
1668 totalNumLines = lineOffset;
1669}
1670
1671PDLTextFile::ChunkIterator PDLTextFile::getChunkItFor(lsp::Position &pos) {
1672 if (chunks.size() == 1)
1673 return chunks.begin();
1674
1675 // Search for the first chunk with a greater line offset, the previous chunk
1676 // is the one that contains `pos`.
1677 auto it = llvm::upper_bound(
1678 Range&: chunks, Value&: pos, C: [](const lsp::Position &pos, const auto &chunk) {
1679 return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
1680 });
1681 ChunkIterator chunkIt(it == chunks.end() ? (chunks.end() - 1) : --it);
1682 pos.line -= chunkIt->lineOffset;
1683 return chunkIt;
1684}
1685
1686//===----------------------------------------------------------------------===//
1687// PDLLServer::Impl
1688//===----------------------------------------------------------------------===//
1689
1690struct lsp::PDLLServer::Impl {
1691 explicit Impl(const Options &options)
1692 : options(options), compilationDatabase(options.compilationDatabases) {}
1693
1694 /// PDLL LSP options.
1695 const Options &options;
1696
1697 /// The compilation database containing additional information for files
1698 /// passed to the server.
1699 lsp::CompilationDatabase compilationDatabase;
1700
1701 /// The files held by the server, mapped by their URI file name.
1702 llvm::StringMap<std::unique_ptr<PDLTextFile>> files;
1703};
1704
1705//===----------------------------------------------------------------------===//
1706// PDLLServer
1707//===----------------------------------------------------------------------===//
1708
1709lsp::PDLLServer::PDLLServer(const Options &options)
1710 : impl(std::make_unique<Impl>(args: options)) {}
1711lsp::PDLLServer::~PDLLServer() = default;
1712
1713void lsp::PDLLServer::addDocument(const URIForFile &uri, StringRef contents,
1714 int64_t version,
1715 std::vector<Diagnostic> &diagnostics) {
1716 // Build the set of additional include directories.
1717 std::vector<std::string> additionalIncludeDirs = impl->options.extraDirs;
1718 const auto &fileInfo = impl->compilationDatabase.getFileInfo(filename: uri.file());
1719 llvm::append_range(C&: additionalIncludeDirs, R: fileInfo.includeDirs);
1720
1721 impl->files[uri.file()] = std::make_unique<PDLTextFile>(
1722 args: uri, args&: contents, args&: version, args&: additionalIncludeDirs, args&: diagnostics);
1723}
1724
1725void lsp::PDLLServer::updateDocument(
1726 const URIForFile &uri, ArrayRef<TextDocumentContentChangeEvent> changes,
1727 int64_t version, std::vector<Diagnostic> &diagnostics) {
1728 // Check that we actually have a document for this uri.
1729 auto it = impl->files.find(Key: uri.file());
1730 if (it == impl->files.end())
1731 return;
1732
1733 // Try to update the document. If we fail, erase the file from the server. A
1734 // failed updated generally means we've fallen out of sync somewhere.
1735 if (failed(Result: it->second->update(uri, newVersion: version, changes, diagnostics)))
1736 impl->files.erase(I: it);
1737}
1738
1739std::optional<int64_t> lsp::PDLLServer::removeDocument(const URIForFile &uri) {
1740 auto it = impl->files.find(Key: uri.file());
1741 if (it == impl->files.end())
1742 return std::nullopt;
1743
1744 int64_t version = it->second->getVersion();
1745 impl->files.erase(I: it);
1746 return version;
1747}
1748
1749void lsp::PDLLServer::getLocationsOf(const URIForFile &uri,
1750 const Position &defPos,
1751 std::vector<Location> &locations) {
1752 auto fileIt = impl->files.find(Key: uri.file());
1753 if (fileIt != impl->files.end())
1754 fileIt->second->getLocationsOf(uri, defPos, locations);
1755}
1756
1757void lsp::PDLLServer::findReferencesOf(const URIForFile &uri,
1758 const Position &pos,
1759 std::vector<Location> &references) {
1760 auto fileIt = impl->files.find(Key: uri.file());
1761 if (fileIt != impl->files.end())
1762 fileIt->second->findReferencesOf(uri, pos, references);
1763}
1764
1765void lsp::PDLLServer::getDocumentLinks(
1766 const URIForFile &uri, std::vector<DocumentLink> &documentLinks) {
1767 auto fileIt = impl->files.find(Key: uri.file());
1768 if (fileIt != impl->files.end())
1769 return fileIt->second->getDocumentLinks(uri, links&: documentLinks);
1770}
1771
1772std::optional<lsp::Hover> lsp::PDLLServer::findHover(const URIForFile &uri,
1773 const Position &hoverPos) {
1774 auto fileIt = impl->files.find(Key: uri.file());
1775 if (fileIt != impl->files.end())
1776 return fileIt->second->findHover(uri, hoverPos);
1777 return std::nullopt;
1778}
1779
1780void lsp::PDLLServer::findDocumentSymbols(
1781 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1782 auto fileIt = impl->files.find(Key: uri.file());
1783 if (fileIt != impl->files.end())
1784 fileIt->second->findDocumentSymbols(symbols);
1785}
1786
1787lsp::CompletionList
1788lsp::PDLLServer::getCodeCompletion(const URIForFile &uri,
1789 const Position &completePos) {
1790 auto fileIt = impl->files.find(Key: uri.file());
1791 if (fileIt != impl->files.end())
1792 return fileIt->second->getCodeCompletion(uri, completePos);
1793 return CompletionList();
1794}
1795
1796lsp::SignatureHelp lsp::PDLLServer::getSignatureHelp(const URIForFile &uri,
1797 const Position &helpPos) {
1798 auto fileIt = impl->files.find(Key: uri.file());
1799 if (fileIt != impl->files.end())
1800 return fileIt->second->getSignatureHelp(uri, helpPos);
1801 return SignatureHelp();
1802}
1803
1804void lsp::PDLLServer::getInlayHints(const URIForFile &uri, const Range &range,
1805 std::vector<InlayHint> &inlayHints) {
1806 auto fileIt = impl->files.find(Key: uri.file());
1807 if (fileIt == impl->files.end())
1808 return;
1809 fileIt->second->getInlayHints(uri, range, inlayHints);
1810
1811 // Drop any duplicated hints that may have cropped up.
1812 llvm::sort(C&: inlayHints);
1813 inlayHints.erase(first: llvm::unique(R&: inlayHints), last: inlayHints.end());
1814}
1815
1816std::optional<lsp::PDLLViewOutputResult>
1817lsp::PDLLServer::getPDLLViewOutput(const URIForFile &uri,
1818 PDLLViewOutputKind kind) {
1819 auto fileIt = impl->files.find(Key: uri.file());
1820 if (fileIt != impl->files.end())
1821 return fileIt->second->getPDLLViewOutput(kind);
1822 return std::nullopt;
1823}
1824

source code of mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp