1//===- PDLLServer.cpp - PDLL Language Server ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "PDLLServer.h"
10
11#include "Protocol.h"
12#include "mlir/IR/BuiltinOps.h"
13#include "mlir/Support/ToolUtilities.h"
14#include "mlir/Tools/PDLL/AST/Context.h"
15#include "mlir/Tools/PDLL/AST/Nodes.h"
16#include "mlir/Tools/PDLL/AST/Types.h"
17#include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
18#include "mlir/Tools/PDLL/CodeGen/MLIRGen.h"
19#include "mlir/Tools/PDLL/ODS/Constraint.h"
20#include "mlir/Tools/PDLL/ODS/Context.h"
21#include "mlir/Tools/PDLL/ODS/Dialect.h"
22#include "mlir/Tools/PDLL/ODS/Operation.h"
23#include "mlir/Tools/PDLL/Parser/CodeComplete.h"
24#include "mlir/Tools/PDLL/Parser/Parser.h"
25#include "mlir/Tools/lsp-server-support/CompilationDatabase.h"
26#include "mlir/Tools/lsp-server-support/Logging.h"
27#include "mlir/Tools/lsp-server-support/SourceMgrUtils.h"
28#include "llvm/ADT/IntervalMap.h"
29#include "llvm/ADT/StringMap.h"
30#include "llvm/ADT/StringSet.h"
31#include "llvm/ADT/TypeSwitch.h"
32#include "llvm/Support/FileSystem.h"
33#include "llvm/Support/Path.h"
34#include <optional>
35
36using namespace mlir;
37using namespace mlir::pdll;
38
39/// Returns a language server uri for the given source location. `mainFileURI`
40/// corresponds to the uri for the main file of the source manager.
41static lsp::URIForFile getURIFromLoc(llvm::SourceMgr &mgr, SMRange loc,
42 const lsp::URIForFile &mainFileURI) {
43 int bufferId = mgr.FindBufferContainingLoc(Loc: loc.Start);
44 if (bufferId == 0 || bufferId == static_cast<int>(mgr.getMainFileID()))
45 return mainFileURI;
46 llvm::Expected<lsp::URIForFile> fileForLoc = lsp::URIForFile::fromFile(
47 absoluteFilepath: mgr.getBufferInfo(i: bufferId).Buffer->getBufferIdentifier());
48 if (fileForLoc)
49 return *fileForLoc;
50 lsp::Logger::error(fmt: "Failed to create URI for include file: {0}",
51 vals: llvm::toString(E: fileForLoc.takeError()));
52 return mainFileURI;
53}
54
55/// Returns true if the given location is in the main file of the source
56/// manager.
57static bool isMainFileLoc(llvm::SourceMgr &mgr, SMRange loc) {
58 return mgr.FindBufferContainingLoc(Loc: loc.Start) == mgr.getMainFileID();
59}
60
61/// Returns a language server location from the given source range.
62static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMRange range,
63 const lsp::URIForFile &uri) {
64 return lsp::Location(getURIFromLoc(mgr, loc: range, mainFileURI: uri), lsp::Range(mgr, range));
65}
66
67/// Convert the given MLIR diagnostic to the LSP form.
68static std::optional<lsp::Diagnostic>
69getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag,
70 const lsp::URIForFile &uri) {
71 lsp::Diagnostic lspDiag;
72 lspDiag.source = "pdll";
73
74 // FIXME: Right now all of the diagnostics are treated as parser issues, but
75 // some are parser and some are verifier.
76 lspDiag.category = "Parse Error";
77
78 // Try to grab a file location for this diagnostic.
79 lsp::Location loc = getLocationFromLoc(mgr&: sourceMgr, range: diag.getLocation(), uri);
80 lspDiag.range = loc.range;
81
82 // Skip diagnostics that weren't emitted within the main file.
83 if (loc.uri != uri)
84 return std::nullopt;
85
86 // Convert the severity for the diagnostic.
87 switch (diag.getSeverity()) {
88 case ast::Diagnostic::Severity::DK_Note:
89 llvm_unreachable("expected notes to be handled separately");
90 case ast::Diagnostic::Severity::DK_Warning:
91 lspDiag.severity = lsp::DiagnosticSeverity::Warning;
92 break;
93 case ast::Diagnostic::Severity::DK_Error:
94 lspDiag.severity = lsp::DiagnosticSeverity::Error;
95 break;
96 case ast::Diagnostic::Severity::DK_Remark:
97 lspDiag.severity = lsp::DiagnosticSeverity::Information;
98 break;
99 }
100 lspDiag.message = diag.getMessage().str();
101
102 // Attach any notes to the main diagnostic as related information.
103 std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
104 for (const ast::Diagnostic &note : diag.getNotes()) {
105 relatedDiags.emplace_back(
106 args: getLocationFromLoc(mgr&: sourceMgr, range: note.getLocation(), uri),
107 args: note.getMessage().str());
108 }
109 if (!relatedDiags.empty())
110 lspDiag.relatedInformation = std::move(relatedDiags);
111
112 return lspDiag;
113}
114
115/// Get or extract the documentation for the given decl.
116static std::optional<std::string>
117getDocumentationFor(llvm::SourceMgr &sourceMgr, const ast::Decl *decl) {
118 // If the decl already had documentation set, use it.
119 if (std::optional<StringRef> doc = decl->getDocComment())
120 return doc->str();
121
122 // If the decl doesn't yet have documentation, try to extract it from the
123 // source file.
124 return lsp::extractSourceDocComment(sourceMgr, loc: decl->getLoc().Start);
125}
126
127//===----------------------------------------------------------------------===//
128// PDLIndex
129//===----------------------------------------------------------------------===//
130
131namespace {
132struct PDLIndexSymbol {
133 explicit PDLIndexSymbol(const ast::Decl *definition)
134 : definition(definition) {}
135 explicit PDLIndexSymbol(const ods::Operation *definition)
136 : definition(definition) {}
137
138 /// Return the location of the definition of this symbol.
139 SMRange getDefLoc() const {
140 if (const ast::Decl *decl =
141 llvm::dyn_cast_if_present<const ast::Decl *>(Val: definition)) {
142 const ast::Name *declName = decl->getName();
143 return declName ? declName->getLoc() : decl->getLoc();
144 }
145 return cast<const ods::Operation *>(Val: definition)->getLoc();
146 }
147
148 /// The main definition of the symbol.
149 PointerUnion<const ast::Decl *, const ods::Operation *> definition;
150 /// The set of references to the symbol.
151 std::vector<SMRange> references;
152};
153
/// This class provides an index for definitions/uses within a PDL document.
/// It provides efficient lookup of a definition given an input source range.
class PDLIndex {
public:
  // `intervalMap` is constructed with `allocator`. `allocator` is declared
  // before `intervalMap` below, so it is initialized first.
  PDLIndex() : intervalMap(allocator) {}

  /// Initialize the index with the given ast::Module. Op names are resolved
  /// against `odsContext` so that references to ODS operations are indexed as
  /// well.
  void initialize(const ast::Module &module, const ods::Context &odsContext);

  /// Lookup a symbol for the given location. Returns nullptr if no symbol could
  /// be found. If provided, `overlappedRange` is set to the range that the
  /// provided `loc` overlapped with. Indexed ranges are half-open, so a `loc`
  /// pointing exactly at the end of a range does not match it.
  const PDLIndexSymbol *lookup(SMLoc loc,
                               SMRange *overlappedRange = nullptr) const;

private:
  /// The type of interval map used to store source references. SMRange is
  /// half-open, so we also need to use a half-open interval map. Keys are raw
  /// source buffer pointers.
  using MapT =
      llvm::IntervalMap<const char *, const PDLIndexSymbol *,
                        llvm::IntervalMapImpl::NodeSizer<
                            const char *, const PDLIndexSymbol *>::LeafSize,
                        llvm::IntervalMapHalfOpenInfo<const char *>>;

  /// An allocator for the interval map. It must stay declared before
  /// `intervalMap`, which is constructed from it.
  MapT::Allocator allocator;

  /// An interval map containing a corresponding definition mapped to a source
  /// interval.
  MapT intervalMap;

  /// A mapping between definitions and their corresponding symbol. The key is
  /// the address of the AST decl or ODS operation; the symbols are owned here.
  DenseMap<const void *, std::unique_ptr<PDLIndexSymbol>> defToSymbol;
};
188} // namespace
189
190void PDLIndex::initialize(const ast::Module &module,
191 const ods::Context &odsContext) {
192 auto getOrInsertDef = [&](const auto *def) -> PDLIndexSymbol * {
193 auto it = defToSymbol.try_emplace(def, nullptr);
194 if (it.second)
195 it.first->second = std::make_unique<PDLIndexSymbol>(def);
196 return &*it.first->second;
197 };
198 auto insertDeclRef = [&](PDLIndexSymbol *sym, SMRange refLoc,
199 bool isDef = false) {
200 const char *startLoc = refLoc.Start.getPointer();
201 const char *endLoc = refLoc.End.getPointer();
202 if (!intervalMap.overlaps(a: startLoc, b: endLoc)) {
203 intervalMap.insert(a: startLoc, b: endLoc, y: sym);
204 if (!isDef)
205 sym->references.push_back(x: refLoc);
206 }
207 };
208 auto insertODSOpRef = [&](StringRef opName, SMRange refLoc) {
209 const ods::Operation *odsOp = odsContext.lookupOperation(name: opName);
210 if (!odsOp)
211 return;
212
213 PDLIndexSymbol *symbol = getOrInsertDef(odsOp);
214 insertDeclRef(symbol, odsOp->getLoc(), /*isDef=*/true);
215 insertDeclRef(symbol, refLoc);
216 };
217
218 module.walk(walkFn: [&](const ast::Node *node) {
219 // Handle references to PDL decls.
220 if (const auto *decl = dyn_cast<ast::OpNameDecl>(Val: node)) {
221 if (std::optional<StringRef> name = decl->getName())
222 insertODSOpRef(*name, decl->getLoc());
223 } else if (const ast::Decl *decl = dyn_cast<ast::Decl>(Val: node)) {
224 const ast::Name *name = decl->getName();
225 if (!name)
226 return;
227 PDLIndexSymbol *declSym = getOrInsertDef(decl);
228 insertDeclRef(declSym, name->getLoc(), /*isDef=*/true);
229
230 if (const auto *varDecl = dyn_cast<ast::VariableDecl>(Val: decl)) {
231 // Record references to any constraints.
232 for (const auto &it : varDecl->getConstraints())
233 insertDeclRef(getOrInsertDef(it.constraint), it.referenceLoc);
234 }
235 } else if (const auto *expr = dyn_cast<ast::DeclRefExpr>(Val: node)) {
236 insertDeclRef(getOrInsertDef(expr->getDecl()), expr->getLoc());
237 }
238 });
239}
240
241const PDLIndexSymbol *PDLIndex::lookup(SMLoc loc,
242 SMRange *overlappedRange) const {
243 auto it = intervalMap.find(x: loc.getPointer());
244 if (!it.valid() || loc.getPointer() < it.start())
245 return nullptr;
246
247 if (overlappedRange) {
248 *overlappedRange = SMRange(SMLoc::getFromPointer(Ptr: it.start()),
249 SMLoc::getFromPointer(Ptr: it.stop()));
250 }
251 return it.value();
252}
253
254//===----------------------------------------------------------------------===//
255// PDLDocument
256//===----------------------------------------------------------------------===//
257
258namespace {
/// This class represents all of the information pertaining to a specific PDL
/// document: its source manager, the ODS and AST contexts, the parsed AST
/// module, the symbol index, and the includes gathered while parsing.
struct PDLDocument {
  /// Parse `contents` as the document identified by `uri`. Includes are
  /// searched for in the document's own directory followed by `extraDirs`.
  /// Parse diagnostics located in this document are appended to
  /// `diagnostics`.
  PDLDocument(const lsp::URIForFile &uri, StringRef contents,
              const std::vector<std::string> &extraDirs,
              std::vector<lsp::Diagnostic> &diagnostics);
  // Non-copyable: members depend on each other (e.g. `astContext` is
  // constructed from `odsContext`, and `index` refers into the parsed AST).
  PDLDocument(const PDLDocument &) = delete;
  PDLDocument &operator=(const PDLDocument &) = delete;

  //===--------------------------------------------------------------------===//
  // Definitions and References
  //===--------------------------------------------------------------------===//

  /// Append the definition location of the symbol at `defPos`, if any.
  void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
                      std::vector<lsp::Location> &locations);
  /// Append the definition location and every recorded reference of the
  /// symbol at `pos`, if any.
  void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
                        std::vector<lsp::Location> &references);

  //===--------------------------------------------------------------------===//
  // Document Links
  //===--------------------------------------------------------------------===//

  /// Append one link per include gathered while parsing the document.
  void getDocumentLinks(const lsp::URIForFile &uri,
                        std::vector<lsp::DocumentLink> &links);

  //===--------------------------------------------------------------------===//
  // Hover
  //===--------------------------------------------------------------------===//

  /// Return hover information for `hoverPos`. Include directives are checked
  /// first, then indexed symbols. Returns std::nullopt if nothing applies.
  std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
                                      const lsp::Position &hoverPos);
  /// Return hover information for an AST decl, or std::nullopt for decl kinds
  /// without hover support.
  std::optional<lsp::Hover> findHover(const ast::Decl *decl,
                                      const SMRange &hoverRange);
  // Builders for the hover contents of specific kinds of entities.
  lsp::Hover buildHoverForOpName(const ods::Operation *op,
                                 const SMRange &hoverRange);
  lsp::Hover buildHoverForVariable(const ast::VariableDecl *varDecl,
                                   const SMRange &hoverRange);
  lsp::Hover buildHoverForPattern(const ast::PatternDecl *decl,
                                  const SMRange &hoverRange);
  lsp::Hover buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl,
                                         const SMRange &hoverRange);
  // Used with ast::UserConstraintDecl and ast::UserRewriteDecl; `typeName` is
  // the heading shown in the hover.
  template <typename T>
  lsp::Hover buildHoverForUserConstraintOrRewrite(StringRef typeName,
                                                  const T *decl,
                                                  const SMRange &hoverRange);

  //===--------------------------------------------------------------------===//
  // Document Symbols
  //===--------------------------------------------------------------------===//

  /// Append symbols for the top-level patterns, user constraints, and user
  /// rewrites declared in the main file.
  void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);

  //===--------------------------------------------------------------------===//
  // Code Completion
  //===--------------------------------------------------------------------===//

  lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
                                        const lsp::Position &completePos);

  //===--------------------------------------------------------------------===//
  // Signature Help
  //===--------------------------------------------------------------------===//

  lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri,
                                      const lsp::Position &helpPos);

  //===--------------------------------------------------------------------===//
  // Inlay Hints
  //===--------------------------------------------------------------------===//

  void getInlayHints(const lsp::URIForFile &uri, const lsp::Range &range,
                     std::vector<lsp::InlayHint> &inlayHints);
  void getInlayHintsFor(const ast::VariableDecl *decl,
                        const lsp::URIForFile &uri,
                        std::vector<lsp::InlayHint> &inlayHints);
  void getInlayHintsFor(const ast::CallExpr *expr, const lsp::URIForFile &uri,
                        std::vector<lsp::InlayHint> &inlayHints);
  void getInlayHintsFor(const ast::OperationExpr *expr,
                        const lsp::URIForFile &uri,
                        std::vector<lsp::InlayHint> &inlayHints);

  /// Add a parameter hint for the given expression using `label`.
  void addParameterHintFor(std::vector<lsp::InlayHint> &inlayHints,
                           const ast::Expr *expr, StringRef label);

  //===--------------------------------------------------------------------===//
  // PDLL ViewOutput
  //===--------------------------------------------------------------------===//

  void getPDLLViewOutput(raw_ostream &os, lsp::PDLLViewOutputKind kind);

  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//

  /// The include directories for this file: the directory containing the
  /// file, followed by any extra directories passed to the constructor.
  std::vector<std::string> includeDirs;

  /// The source manager containing the contents of the input file.
  llvm::SourceMgr sourceMgr;

  /// The ODS and AST contexts. `odsContext` must stay declared before
  /// `astContext`, which is constructed from it.
  ods::Context odsContext;
  ast::Context astContext;

  /// The parsed AST module, or failure if the file wasn't valid.
  FailureOr<ast::Module *> astModule;

  /// The index of the parsed module. It is only initialized when parsing
  /// succeeds.
  PDLIndex index;

  /// The set of includes of the parsed module.
  SmallVector<lsp::SourceMgrInclude> parsedIncludes;
};
373} // namespace
374
375PDLDocument::PDLDocument(const lsp::URIForFile &uri, StringRef contents,
376 const std::vector<std::string> &extraDirs,
377 std::vector<lsp::Diagnostic> &diagnostics)
378 : astContext(odsContext) {
379 auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(InputData: contents, BufferName: uri.file());
380 if (!memBuffer) {
381 lsp::Logger::error(fmt: "Failed to create memory buffer for file", vals: uri.file());
382 return;
383 }
384
385 // Build the set of include directories for this file.
386 llvm::SmallString<32> uriDirectory(uri.file());
387 llvm::sys::path::remove_filename(path&: uriDirectory);
388 includeDirs.push_back(x: uriDirectory.str().str());
389 llvm::append_range(C&: includeDirs, R: extraDirs);
390
391 sourceMgr.setIncludeDirs(includeDirs);
392 sourceMgr.AddNewSourceBuffer(F: std::move(memBuffer), IncludeLoc: SMLoc());
393
394 astContext.getDiagEngine().setHandlerFn([&](const ast::Diagnostic &diag) {
395 if (auto lspDiag = getLspDiagnoticFromDiag(sourceMgr, diag, uri))
396 diagnostics.push_back(x: std::move(*lspDiag));
397 });
398 astModule = parsePDLLAST(ctx&: astContext, sourceMgr, /*enableDocumentation=*/true);
399
400 // Initialize the set of parsed includes.
401 lsp::gatherIncludeFiles(sourceMgr, includes&: parsedIncludes);
402
403 // If we failed to parse the module, there is nothing left to initialize.
404 if (failed(Result: astModule))
405 return;
406
407 // Prepare the AST index with the parsed module.
408 index.initialize(module: **astModule, odsContext);
409}
410
411//===----------------------------------------------------------------------===//
412// PDLDocument: Definitions and References
413//===----------------------------------------------------------------------===//
414
415void PDLDocument::getLocationsOf(const lsp::URIForFile &uri,
416 const lsp::Position &defPos,
417 std::vector<lsp::Location> &locations) {
418 SMLoc posLoc = defPos.getAsSMLoc(mgr&: sourceMgr);
419 const PDLIndexSymbol *symbol = index.lookup(loc: posLoc);
420 if (!symbol)
421 return;
422
423 locations.push_back(x: getLocationFromLoc(mgr&: sourceMgr, range: symbol->getDefLoc(), uri));
424}
425
426void PDLDocument::findReferencesOf(const lsp::URIForFile &uri,
427 const lsp::Position &pos,
428 std::vector<lsp::Location> &references) {
429 SMLoc posLoc = pos.getAsSMLoc(mgr&: sourceMgr);
430 const PDLIndexSymbol *symbol = index.lookup(loc: posLoc);
431 if (!symbol)
432 return;
433
434 references.push_back(x: getLocationFromLoc(mgr&: sourceMgr, range: symbol->getDefLoc(), uri));
435 for (SMRange refLoc : symbol->references)
436 references.push_back(x: getLocationFromLoc(mgr&: sourceMgr, range: refLoc, uri));
437}
438
439//===--------------------------------------------------------------------===//
440// PDLDocument: Document Links
441//===--------------------------------------------------------------------===//
442
443void PDLDocument::getDocumentLinks(const lsp::URIForFile &uri,
444 std::vector<lsp::DocumentLink> &links) {
445 for (const lsp::SourceMgrInclude &include : parsedIncludes)
446 links.emplace_back(args: include.range, args: include.uri);
447}
448
449//===----------------------------------------------------------------------===//
450// PDLDocument: Hover
451//===----------------------------------------------------------------------===//
452
453std::optional<lsp::Hover>
454PDLDocument::findHover(const lsp::URIForFile &uri,
455 const lsp::Position &hoverPos) {
456 SMLoc posLoc = hoverPos.getAsSMLoc(mgr&: sourceMgr);
457
458 // Check for a reference to an include.
459 for (const lsp::SourceMgrInclude &include : parsedIncludes)
460 if (include.range.contains(pos: hoverPos))
461 return include.buildHover();
462
463 // Find the symbol at the given location.
464 SMRange hoverRange;
465 const PDLIndexSymbol *symbol = index.lookup(loc: posLoc, overlappedRange: &hoverRange);
466 if (!symbol)
467 return std::nullopt;
468
469 // Add hover for operation names.
470 if (const auto *op =
471 llvm::dyn_cast_if_present<const ods::Operation *>(Val: symbol->definition))
472 return buildHoverForOpName(op, hoverRange);
473 const auto *decl = cast<const ast::Decl *>(Val: symbol->definition);
474 return findHover(decl, hoverRange);
475}
476
477std::optional<lsp::Hover> PDLDocument::findHover(const ast::Decl *decl,
478 const SMRange &hoverRange) {
479 // Add hover for variables.
480 if (const auto *varDecl = dyn_cast<ast::VariableDecl>(Val: decl))
481 return buildHoverForVariable(varDecl, hoverRange);
482
483 // Add hover for patterns.
484 if (const auto *patternDecl = dyn_cast<ast::PatternDecl>(Val: decl))
485 return buildHoverForPattern(decl: patternDecl, hoverRange);
486
487 // Add hover for core constraints.
488 if (const auto *cst = dyn_cast<ast::CoreConstraintDecl>(Val: decl))
489 return buildHoverForCoreConstraint(decl: cst, hoverRange);
490
491 // Add hover for user constraints.
492 if (const auto *cst = dyn_cast<ast::UserConstraintDecl>(Val: decl))
493 return buildHoverForUserConstraintOrRewrite(typeName: "Constraint", decl: cst, hoverRange);
494
495 // Add hover for user rewrites.
496 if (const auto *rewrite = dyn_cast<ast::UserRewriteDecl>(Val: decl))
497 return buildHoverForUserConstraintOrRewrite(typeName: "Rewrite", decl: rewrite, hoverRange);
498
499 return std::nullopt;
500}
501
502lsp::Hover PDLDocument::buildHoverForOpName(const ods::Operation *op,
503 const SMRange &hoverRange) {
504 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
505 {
506 llvm::raw_string_ostream hoverOS(hover.contents.value);
507 hoverOS << "**OpName**: `" << op->getName() << "`\n***\n"
508 << op->getSummary() << "\n***\n"
509 << op->getDescription();
510 }
511 return hover;
512}
513
514lsp::Hover PDLDocument::buildHoverForVariable(const ast::VariableDecl *varDecl,
515 const SMRange &hoverRange) {
516 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
517 {
518 llvm::raw_string_ostream hoverOS(hover.contents.value);
519 hoverOS << "**Variable**: `" << varDecl->getName().getName() << "`\n***\n"
520 << "Type: `" << varDecl->getType() << "`\n";
521 }
522 return hover;
523}
524
525lsp::Hover PDLDocument::buildHoverForPattern(const ast::PatternDecl *decl,
526 const SMRange &hoverRange) {
527 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
528 {
529 llvm::raw_string_ostream hoverOS(hover.contents.value);
530 hoverOS << "**Pattern**";
531 if (const ast::Name *name = decl->getName())
532 hoverOS << ": `" << name->getName() << "`";
533 hoverOS << "\n***\n";
534 if (std::optional<uint16_t> benefit = decl->getBenefit())
535 hoverOS << "Benefit: " << *benefit << "\n";
536 if (decl->hasBoundedRewriteRecursion())
537 hoverOS << "HasBoundedRewriteRecursion\n";
538 hoverOS << "RootOp: `"
539 << decl->getRootRewriteStmt()->getRootOpExpr()->getType() << "`\n";
540
541 // Format the documentation for the decl.
542 if (std::optional<std::string> doc = getDocumentationFor(sourceMgr, decl))
543 hoverOS << "\n" << *doc << "\n";
544 }
545 return hover;
546}
547
548lsp::Hover
549PDLDocument::buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl,
550 const SMRange &hoverRange) {
551 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
552 {
553 llvm::raw_string_ostream hoverOS(hover.contents.value);
554 hoverOS << "**Constraint**: `";
555 TypeSwitch<const ast::Decl *>(decl)
556 .Case(caseFn: [&](const ast::AttrConstraintDecl *) { hoverOS << "Attr"; })
557 .Case(caseFn: [&](const ast::OpConstraintDecl *opCst) {
558 hoverOS << "Op";
559 if (std::optional<StringRef> name = opCst->getName())
560 hoverOS << "<" << *name << ">";
561 })
562 .Case(caseFn: [&](const ast::TypeConstraintDecl *) { hoverOS << "Type"; })
563 .Case(caseFn: [&](const ast::TypeRangeConstraintDecl *) {
564 hoverOS << "TypeRange";
565 })
566 .Case(caseFn: [&](const ast::ValueConstraintDecl *) { hoverOS << "Value"; })
567 .Case(caseFn: [&](const ast::ValueRangeConstraintDecl *) {
568 hoverOS << "ValueRange";
569 });
570 hoverOS << "`\n";
571 }
572 return hover;
573}
574
575template <typename T>
576lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite(
577 StringRef typeName, const T *decl, const SMRange &hoverRange) {
578 lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
579 {
580 llvm::raw_string_ostream hoverOS(hover.contents.value);
581 hoverOS << "**" << typeName << "**: `" << decl->getName().getName()
582 << "`\n***\n";
583 ArrayRef<ast::VariableDecl *> inputs = decl->getInputs();
584 if (!inputs.empty()) {
585 hoverOS << "Parameters:\n";
586 for (const ast::VariableDecl *input : inputs)
587 hoverOS << "* " << input->getName().getName() << ": `"
588 << input->getType() << "`\n";
589 hoverOS << "***\n";
590 }
591 ast::Type resultType = decl->getResultType();
592 if (auto resultTupleTy = dyn_cast<ast::TupleType>(Val&: resultType)) {
593 if (!resultTupleTy.empty()) {
594 hoverOS << "Results:\n";
595 for (auto it : llvm::zip(t: resultTupleTy.getElementNames(),
596 u: resultTupleTy.getElementTypes())) {
597 StringRef name = std::get<0>(t&: it);
598 hoverOS << "* " << (name.empty() ? "" : (name + ": ")) << "`"
599 << std::get<1>(t&: it) << "`\n";
600 }
601 hoverOS << "***\n";
602 }
603 } else {
604 hoverOS << "Results:\n* `" << resultType << "`\n";
605 hoverOS << "***\n";
606 }
607
608 // Format the documentation for the decl.
609 if (std::optional<std::string> doc = getDocumentationFor(sourceMgr, decl))
610 hoverOS << "\n" << *doc << "\n";
611 }
612 return hover;
613}
614
615//===----------------------------------------------------------------------===//
616// PDLDocument: Document Symbols
617//===----------------------------------------------------------------------===//
618
619void PDLDocument::findDocumentSymbols(
620 std::vector<lsp::DocumentSymbol> &symbols) {
621 if (failed(Result: astModule))
622 return;
623
624 for (const ast::Decl *decl : (*astModule)->getChildren()) {
625 if (!isMainFileLoc(mgr&: sourceMgr, loc: decl->getLoc()))
626 continue;
627
628 if (const auto *patternDecl = dyn_cast<ast::PatternDecl>(Val: decl)) {
629 const ast::Name *name = patternDecl->getName();
630
631 SMRange nameLoc = name ? name->getLoc() : patternDecl->getLoc();
632 SMRange bodyLoc(nameLoc.Start, patternDecl->getBody()->getLoc().End);
633
634 symbols.emplace_back(
635 args: name ? name->getName() : "<pattern>", args: lsp::SymbolKind::Class,
636 args: lsp::Range(sourceMgr, bodyLoc), args: lsp::Range(sourceMgr, nameLoc));
637 } else if (const auto *cDecl = dyn_cast<ast::UserConstraintDecl>(Val: decl)) {
638 // TODO: Add source information for the code block body.
639 SMRange nameLoc = cDecl->getName().getLoc();
640 SMRange bodyLoc = nameLoc;
641
642 symbols.emplace_back(
643 args: cDecl->getName().getName(), args: lsp::SymbolKind::Function,
644 args: lsp::Range(sourceMgr, bodyLoc), args: lsp::Range(sourceMgr, nameLoc));
645 } else if (const auto *cDecl = dyn_cast<ast::UserRewriteDecl>(Val: decl)) {
646 // TODO: Add source information for the code block body.
647 SMRange nameLoc = cDecl->getName().getLoc();
648 SMRange bodyLoc = nameLoc;
649
650 symbols.emplace_back(
651 args: cDecl->getName().getName(), args: lsp::SymbolKind::Function,
652 args: lsp::Range(sourceMgr, bodyLoc), args: lsp::Range(sourceMgr, nameLoc));
653 }
654 }
655}
656
657//===----------------------------------------------------------------------===//
658// PDLDocument: Code Completion
659//===----------------------------------------------------------------------===//
660
661namespace {
662class LSPCodeCompleteContext : public CodeCompleteContext {
663public:
  /// Create a code completion context for the location `completeLoc`.
  /// Completion items produced by the callbacks of this class are appended to
  /// `completionList`. `sourceMgr`, `odsContext`, and `includeDirs` are
  /// stored in members for use by those callbacks.
  LSPCodeCompleteContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
                         lsp::CompletionList &completionList,
                         ods::Context &odsContext,
                         ArrayRef<std::string> includeDirs)
      : CodeCompleteContext(completeLoc), sourceMgr(sourceMgr),
        completionList(completionList), odsContext(odsContext),
        includeDirs(includeDirs) {}
671
672 void codeCompleteTupleMemberAccess(ast::TupleType tupleType) final {
673 ArrayRef<ast::Type> elementTypes = tupleType.getElementTypes();
674 ArrayRef<StringRef> elementNames = tupleType.getElementNames();
675 for (unsigned i = 0, e = tupleType.size(); i < e; ++i) {
676 // Push back a completion item that uses the result index.
677 lsp::CompletionItem item;
678 item.label = llvm::formatv(Fmt: "{0} (field #{0})", Vals&: i).str();
679 item.insertText = Twine(i).str();
680 item.filterText = item.sortText = item.insertText;
681 item.kind = lsp::CompletionItemKind::Field;
682 item.detail = llvm::formatv(Fmt: "{0}: {1}", Vals&: i, Vals: elementTypes[i]);
683 item.insertTextFormat = lsp::InsertTextFormat::PlainText;
684 completionList.items.emplace_back(args&: item);
685
686 // If the element has a name, push back a completion item with that name.
687 if (!elementNames[i].empty()) {
688 item.label =
689 llvm::formatv(Fmt: "{1} (field #{0})", Vals&: i, Vals: elementNames[i]).str();
690 item.filterText = item.label;
691 item.insertText = elementNames[i].str();
692 completionList.items.emplace_back(args&: item);
693 }
694 }
695 }
696
697 void codeCompleteOperationMemberAccess(ast::OperationType opType) final {
698 const ods::Operation *odsOp = opType.getODSOperation();
699 if (!odsOp)
700 return;
701
702 ArrayRef<ods::OperandOrResult> results = odsOp->getResults();
703 for (const auto &it : llvm::enumerate(First&: results)) {
704 const ods::OperandOrResult &result = it.value();
705 const ods::TypeConstraint &constraint = result.getConstraint();
706
707 // Push back a completion item that uses the result index.
708 lsp::CompletionItem item;
709 item.label = llvm::formatv(Fmt: "{0} (field #{0})", Vals: it.index()).str();
710 item.insertText = Twine(it.index()).str();
711 item.filterText = item.sortText = item.insertText;
712 item.kind = lsp::CompletionItemKind::Field;
713 switch (result.getVariableLengthKind()) {
714 case ods::VariableLengthKind::Single:
715 item.detail = llvm::formatv(Fmt: "{0}: Value", Vals: it.index()).str();
716 break;
717 case ods::VariableLengthKind::Optional:
718 item.detail = llvm::formatv(Fmt: "{0}: Value?", Vals: it.index()).str();
719 break;
720 case ods::VariableLengthKind::Variadic:
721 item.detail = llvm::formatv(Fmt: "{0}: ValueRange", Vals: it.index()).str();
722 break;
723 }
724 item.documentation = lsp::MarkupContent{
725 .kind: lsp::MarkupKind::Markdown,
726 .value: llvm::formatv(Fmt: "{0}\n\n```c++\n{1}\n```\n", Vals: constraint.getSummary(),
727 Vals: constraint.getCppClass())
728 .str()};
729 item.insertTextFormat = lsp::InsertTextFormat::PlainText;
730 completionList.items.emplace_back(args&: item);
731
732 // If the result has a name, push back a completion item with the result
733 // name.
734 if (!result.getName().empty()) {
735 item.label =
736 llvm::formatv(Fmt: "{1} (field #{0})", Vals: it.index(), Vals: result.getName())
737 .str();
738 item.filterText = item.label;
739 item.insertText = result.getName().str();
740 completionList.items.emplace_back(args&: item);
741 }
742 }
743 }
744
745 void codeCompleteOperationAttributeName(StringRef opName) final {
746 const ods::Operation *odsOp = odsContext.lookupOperation(name: opName);
747 if (!odsOp)
748 return;
749
750 for (const ods::Attribute &attr : odsOp->getAttributes()) {
751 const ods::AttributeConstraint &constraint = attr.getConstraint();
752
753 lsp::CompletionItem item;
754 item.label = attr.getName().str();
755 item.kind = lsp::CompletionItemKind::Field;
756 item.detail = attr.isOptional() ? "optional" : "";
757 item.documentation = lsp::MarkupContent{
758 .kind: lsp::MarkupKind::Markdown,
759 .value: llvm::formatv(Fmt: "{0}\n\n```c++\n{1}\n```\n", Vals: constraint.getSummary(),
760 Vals: constraint.getCppClass())
761 .str()};
762 item.insertTextFormat = lsp::InsertTextFormat::PlainText;
763 completionList.items.emplace_back(args&: item);
764 }
765 }
766
767 void codeCompleteConstraintName(ast::Type currentType,
768 bool allowInlineTypeConstraints,
769 const ast::DeclScope *scope) final {
770 auto addCoreConstraint = [&](StringRef constraint, StringRef mlirType,
771 StringRef snippetText = "") {
772 lsp::CompletionItem item;
773 item.label = constraint.str();
774 item.kind = lsp::CompletionItemKind::Class;
775 item.detail = (constraint + " constraint").str();
776 item.documentation = lsp::MarkupContent{
777 .kind: lsp::MarkupKind::Markdown,
778 .value: ("A single entity core constraint of type `" + mlirType + "`").str()};
779 item.sortText = "0";
780 item.insertText = snippetText.str();
781 item.insertTextFormat = snippetText.empty()
782 ? lsp::InsertTextFormat::PlainText
783 : lsp::InsertTextFormat::Snippet;
784 completionList.items.emplace_back(args&: item);
785 };
786
787 // Insert completions for the core constraints. Some core constraints have
788 // additional characteristics, so we may add then even if a type has been
789 // inferred.
790 if (!currentType) {
791 addCoreConstraint("Attr", "mlir::Attribute");
792 addCoreConstraint("Op", "mlir::Operation *");
793 addCoreConstraint("Value", "mlir::Value");
794 addCoreConstraint("ValueRange", "mlir::ValueRange");
795 addCoreConstraint("Type", "mlir::Type");
796 addCoreConstraint("TypeRange", "mlir::TypeRange");
797 }
798 if (allowInlineTypeConstraints) {
799 /// Attr<Type>.
800 if (!currentType || isa<ast::AttributeType>(Val: currentType))
801 addCoreConstraint("Attr<type>", "mlir::Attribute", "Attr<$1>");
802 /// Value<Type>.
803 if (!currentType || isa<ast::ValueType>(Val: currentType))
804 addCoreConstraint("Value<type>", "mlir::Value", "Value<$1>");
805 /// ValueRange<TypeRange>.
806 if (!currentType || isa<ast::ValueRangeType>(Val: currentType))
807 addCoreConstraint("ValueRange<type>", "mlir::ValueRange",
808 "ValueRange<$1>");
809 }
810
811 // If a scope was provided, check it for potential constraints.
812 while (scope) {
813 for (const ast::Decl *decl : scope->getDecls()) {
814 if (const auto *cst = dyn_cast<ast::UserConstraintDecl>(Val: decl)) {
815 lsp::CompletionItem item;
816 item.label = cst->getName().getName().str();
817 item.kind = lsp::CompletionItemKind::Interface;
818 item.sortText = "2_" + item.label;
819
820 // Skip constraints that are not single-arg. We currently only
821 // complete variable constraints.
822 if (cst->getInputs().size() != 1)
823 continue;
824
825 // Ensure the input type matched the given type.
826 ast::Type constraintType = cst->getInputs()[0]->getType();
827 if (currentType && !currentType.refineWith(other: constraintType))
828 continue;
829
830 // Format the constraint signature.
831 {
832 llvm::raw_string_ostream strOS(item.detail);
833 strOS << "(";
834 llvm::interleaveComma(
835 c: cst->getInputs(), os&: strOS, each_fn: [&](const ast::VariableDecl *var) {
836 strOS << var->getName().getName() << ": " << var->getType();
837 });
838 strOS << ") -> " << cst->getResultType();
839 }
840
841 // Format the documentation for the constraint.
842 if (std::optional<std::string> doc =
843 getDocumentationFor(sourceMgr, decl: cst)) {
844 item.documentation =
845 lsp::MarkupContent{.kind: lsp::MarkupKind::Markdown, .value: std::move(*doc)};
846 }
847
848 completionList.items.emplace_back(args&: item);
849 }
850 }
851
852 scope = scope->getParentScope();
853 }
854 }
855
856 void codeCompleteDialectName() final {
857 // Code complete known dialects.
858 for (const ods::Dialect &dialect : odsContext.getDialects()) {
859 lsp::CompletionItem item;
860 item.label = dialect.getName().str();
861 item.kind = lsp::CompletionItemKind::Class;
862 item.insertTextFormat = lsp::InsertTextFormat::PlainText;
863 completionList.items.emplace_back(args&: item);
864 }
865 }
866
867 void codeCompleteOperationName(StringRef dialectName) final {
868 const ods::Dialect *dialect = odsContext.lookupDialect(name: dialectName);
869 if (!dialect)
870 return;
871
872 for (const auto &it : dialect->getOperations()) {
873 const ods::Operation &op = *it.second;
874
875 lsp::CompletionItem item;
876 item.label = op.getName().drop_front(N: dialectName.size() + 1).str();
877 item.kind = lsp::CompletionItemKind::Field;
878 item.insertTextFormat = lsp::InsertTextFormat::PlainText;
879 completionList.items.emplace_back(args&: item);
880 }
881 }
882
883 void codeCompletePatternMetadata() final {
884 auto addSimpleConstraint = [&](StringRef constraint, StringRef desc,
885 StringRef snippetText = "") {
886 lsp::CompletionItem item;
887 item.label = constraint.str();
888 item.kind = lsp::CompletionItemKind::Class;
889 item.detail = "pattern metadata";
890 item.documentation =
891 lsp::MarkupContent{.kind: lsp::MarkupKind::Markdown, .value: desc.str()};
892 item.insertText = snippetText.str();
893 item.insertTextFormat = snippetText.empty()
894 ? lsp::InsertTextFormat::PlainText
895 : lsp::InsertTextFormat::Snippet;
896 completionList.items.emplace_back(args&: item);
897 };
898
899 addSimpleConstraint("benefit", "The `benefit` of matching the pattern.",
900 "benefit($1)");
901 addSimpleConstraint("recursion",
902 "The pattern properly handles recursive application.");
903 }
904
905 void codeCompleteIncludeFilename(StringRef curPath) final {
906 // Normalize the path to allow for interacting with the file system
907 // utilities.
908 SmallString<128> nativeRelDir(llvm::sys::path::convert_to_slash(path: curPath));
909 llvm::sys::path::native(path&: nativeRelDir);
910
911 // Set of already included completion paths.
912 StringSet<> seenResults;
913
914 // Functor used to add a single include completion item.
915 auto addIncludeCompletion = [&](StringRef path, bool isDirectory) {
916 lsp::CompletionItem item;
917 item.label = path.str();
918 item.kind = isDirectory ? lsp::CompletionItemKind::Folder
919 : lsp::CompletionItemKind::File;
920 if (seenResults.insert(key: item.label).second)
921 completionList.items.emplace_back(args&: item);
922 };
923
924 // Process the include directories for this file, adding any potential
925 // nested include files or directories.
926 for (StringRef includeDir : includeDirs) {
927 llvm::SmallString<128> dir = includeDir;
928 if (!nativeRelDir.empty())
929 llvm::sys::path::append(path&: dir, a: nativeRelDir);
930
931 std::error_code errorCode;
932 for (auto it = llvm::sys::fs::directory_iterator(dir, errorCode),
933 e = llvm::sys::fs::directory_iterator();
934 !errorCode && it != e; it.increment(ec&: errorCode)) {
935 StringRef filename = llvm::sys::path::filename(path: it->path());
936
937 // To know whether a symlink should be treated as file or a directory,
938 // we have to stat it. This should be cheap enough as there shouldn't be
939 // many symlinks.
940 llvm::sys::fs::file_type fileType = it->type();
941 if (fileType == llvm::sys::fs::file_type::symlink_file) {
942 if (auto fileStatus = it->status())
943 fileType = fileStatus->type();
944 }
945
946 switch (fileType) {
947 case llvm::sys::fs::file_type::directory_file:
948 addIncludeCompletion(filename, /*isDirectory=*/true);
949 break;
950 case llvm::sys::fs::file_type::regular_file: {
951 // Only consider concrete files that can actually be included by PDLL.
952 if (filename.ends_with(Suffix: ".pdll") || filename.ends_with(Suffix: ".td"))
953 addIncludeCompletion(filename, /*isDirectory=*/false);
954 break;
955 }
956 default:
957 break;
958 }
959 }
960 }
961
962 // Sort the completion results to make sure the output is deterministic in
963 // the face of different iteration schemes for different platforms.
964 llvm::sort(C&: completionList.items, Comp: [](const lsp::CompletionItem &lhs,
965 const lsp::CompletionItem &rhs) {
966 return lhs.label < rhs.label;
967 });
968 }
969
970private:
971 llvm::SourceMgr &sourceMgr;
972 lsp::CompletionList &completionList;
973 ods::Context &odsContext;
974 ArrayRef<std::string> includeDirs;
975};
976} // namespace
977
978lsp::CompletionList
979PDLDocument::getCodeCompletion(const lsp::URIForFile &uri,
980 const lsp::Position &completePos) {
981 SMLoc posLoc = completePos.getAsSMLoc(mgr&: sourceMgr);
982 if (!posLoc.isValid())
983 return lsp::CompletionList();
984
985 // To perform code completion, we run another parse of the module with the
986 // code completion context provided.
987 ods::Context tmpODSContext;
988 lsp::CompletionList completionList;
989 LSPCodeCompleteContext lspCompleteContext(posLoc, sourceMgr, completionList,
990 tmpODSContext,
991 sourceMgr.getIncludeDirs());
992
993 ast::Context tmpContext(tmpODSContext);
994 (void)parsePDLLAST(ctx&: tmpContext, sourceMgr, /*enableDocumentation=*/true,
995 codeCompleteContext: &lspCompleteContext);
996
997 return completionList;
998}
999
1000//===----------------------------------------------------------------------===//
1001// PDLDocument: Signature Help
1002//===----------------------------------------------------------------------===//
1003
1004namespace {
1005class LSPSignatureHelpContext : public CodeCompleteContext {
1006public:
1007 LSPSignatureHelpContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr,
1008 lsp::SignatureHelp &signatureHelp,
1009 ods::Context &odsContext)
1010 : CodeCompleteContext(completeLoc), sourceMgr(sourceMgr),
1011 signatureHelp(signatureHelp), odsContext(odsContext) {}
1012
1013 void codeCompleteCallSignature(const ast::CallableDecl *callable,
1014 unsigned currentNumArgs) final {
1015 signatureHelp.activeParameter = currentNumArgs;
1016
1017 lsp::SignatureInformation signatureInfo;
1018 {
1019 llvm::raw_string_ostream strOS(signatureInfo.label);
1020 strOS << callable->getName()->getName() << "(";
1021 auto formatParamFn = [&](const ast::VariableDecl *var) {
1022 unsigned paramStart = strOS.str().size();
1023 strOS << var->getName().getName() << ": " << var->getType();
1024 unsigned paramEnd = strOS.str().size();
1025 signatureInfo.parameters.emplace_back(args: lsp::ParameterInformation{
1026 .labelString: StringRef(strOS.str()).slice(Start: paramStart, End: paramEnd).str(),
1027 .labelOffsets: std::make_pair(x&: paramStart, y&: paramEnd), /*paramDoc*/ .documentation: std::string()});
1028 };
1029 llvm::interleaveComma(c: callable->getInputs(), os&: strOS, each_fn: formatParamFn);
1030 strOS << ") -> " << callable->getResultType();
1031 }
1032
1033 // Format the documentation for the callable.
1034 if (std::optional<std::string> doc =
1035 getDocumentationFor(sourceMgr, decl: callable))
1036 signatureInfo.documentation = std::move(*doc);
1037
1038 signatureHelp.signatures.emplace_back(args: std::move(signatureInfo));
1039 }
1040
1041 void
1042 codeCompleteOperationOperandsSignature(std::optional<StringRef> opName,
1043 unsigned currentNumOperands) final {
1044 const ods::Operation *odsOp =
1045 opName ? odsContext.lookupOperation(name: *opName) : nullptr;
1046 codeCompleteOperationOperandOrResultSignature(
1047 opName, odsOp, values: odsOp ? odsOp->getOperands() : std::nullopt,
1048 currentValue: currentNumOperands, label: "operand", dataType: "Value");
1049 }
1050
1051 void codeCompleteOperationResultsSignature(std::optional<StringRef> opName,
1052 unsigned currentNumResults) final {
1053 const ods::Operation *odsOp =
1054 opName ? odsContext.lookupOperation(name: *opName) : nullptr;
1055 codeCompleteOperationOperandOrResultSignature(
1056 opName, odsOp, values: odsOp ? odsOp->getResults() : std::nullopt,
1057 currentValue: currentNumResults, label: "result", dataType: "Type");
1058 }
1059
1060 void codeCompleteOperationOperandOrResultSignature(
1061 std::optional<StringRef> opName, const ods::Operation *odsOp,
1062 ArrayRef<ods::OperandOrResult> values, unsigned currentValue,
1063 StringRef label, StringRef dataType) {
1064 signatureHelp.activeParameter = currentValue;
1065
1066 // If we have ODS information for the operation, add in the ODS signature
1067 // for the operation. We also verify that the current number of values is
1068 // not more than what is defined in ODS, as this will result in an error
1069 // anyways.
1070 if (odsOp && currentValue < values.size()) {
1071 lsp::SignatureInformation signatureInfo;
1072
1073 // Build the signature label.
1074 {
1075 llvm::raw_string_ostream strOS(signatureInfo.label);
1076 strOS << "(";
1077 auto formatFn = [&](const ods::OperandOrResult &value) {
1078 unsigned paramStart = strOS.str().size();
1079
1080 strOS << value.getName() << ": ";
1081
1082 StringRef constraintDoc = value.getConstraint().getSummary();
1083 std::string paramDoc;
1084 switch (value.getVariableLengthKind()) {
1085 case ods::VariableLengthKind::Single:
1086 strOS << dataType;
1087 paramDoc = constraintDoc.str();
1088 break;
1089 case ods::VariableLengthKind::Optional:
1090 strOS << dataType << "?";
1091 paramDoc = ("optional: " + constraintDoc).str();
1092 break;
1093 case ods::VariableLengthKind::Variadic:
1094 strOS << dataType << "Range";
1095 paramDoc = ("variadic: " + constraintDoc).str();
1096 break;
1097 }
1098
1099 unsigned paramEnd = strOS.str().size();
1100 signatureInfo.parameters.emplace_back(args: lsp::ParameterInformation{
1101 .labelString: StringRef(strOS.str()).slice(Start: paramStart, End: paramEnd).str(),
1102 .labelOffsets: std::make_pair(x&: paramStart, y&: paramEnd), .documentation: paramDoc});
1103 };
1104 llvm::interleaveComma(c: values, os&: strOS, each_fn: formatFn);
1105 strOS << ")";
1106 }
1107 signatureInfo.documentation =
1108 llvm::formatv(Fmt: "`op<{0}>` ODS {1} specification", Vals&: *opName, Vals&: label)
1109 .str();
1110 signatureHelp.signatures.emplace_back(args: std::move(signatureInfo));
1111 }
1112
1113 // If there aren't any arguments yet, we also add the generic signature.
1114 if (currentValue == 0 && (!odsOp || !values.empty())) {
1115 lsp::SignatureInformation signatureInfo;
1116 signatureInfo.label =
1117 llvm::formatv(Fmt: "(<{0}s>: {1}Range)", Vals&: label, Vals&: dataType).str();
1118 signatureInfo.documentation =
1119 ("Generic operation " + label + " specification").str();
1120 signatureInfo.parameters.emplace_back(args: lsp::ParameterInformation{
1121 .labelString: StringRef(signatureInfo.label).drop_front().drop_back().str(),
1122 .labelOffsets: std::pair<unsigned, unsigned>(1, signatureInfo.label.size() - 1),
1123 .documentation: ("All of the " + label + "s of the operation.").str()});
1124 signatureHelp.signatures.emplace_back(args: std::move(signatureInfo));
1125 }
1126 }
1127
1128private:
1129 llvm::SourceMgr &sourceMgr;
1130 lsp::SignatureHelp &signatureHelp;
1131 ods::Context &odsContext;
1132};
1133} // namespace
1134
1135lsp::SignatureHelp PDLDocument::getSignatureHelp(const lsp::URIForFile &uri,
1136 const lsp::Position &helpPos) {
1137 SMLoc posLoc = helpPos.getAsSMLoc(mgr&: sourceMgr);
1138 if (!posLoc.isValid())
1139 return lsp::SignatureHelp();
1140
1141 // To perform code completion, we run another parse of the module with the
1142 // code completion context provided.
1143 ods::Context tmpODSContext;
1144 lsp::SignatureHelp signatureHelp;
1145 LSPSignatureHelpContext completeContext(posLoc, sourceMgr, signatureHelp,
1146 tmpODSContext);
1147
1148 ast::Context tmpContext(tmpODSContext);
1149 (void)parsePDLLAST(ctx&: tmpContext, sourceMgr, /*enableDocumentation=*/true,
1150 codeCompleteContext: &completeContext);
1151
1152 return signatureHelp;
1153}
1154
1155//===----------------------------------------------------------------------===//
1156// PDLDocument: Inlay Hints
1157//===----------------------------------------------------------------------===//
1158
1159/// Returns true if the given name should be added as a hint for `expr`.
1160static bool shouldAddHintFor(const ast::Expr *expr, StringRef name) {
1161 if (name.empty())
1162 return false;
1163
1164 // If the argument is a reference of the same name, don't add it as a hint.
1165 if (auto *ref = dyn_cast<ast::DeclRefExpr>(Val: expr)) {
1166 const ast::Name *declName = ref->getDecl()->getName();
1167 if (declName && declName->getName() == name)
1168 return false;
1169 }
1170
1171 return true;
1172}
1173
1174void PDLDocument::getInlayHints(const lsp::URIForFile &uri,
1175 const lsp::Range &range,
1176 std::vector<lsp::InlayHint> &inlayHints) {
1177 if (failed(Result: astModule))
1178 return;
1179 SMRange rangeLoc = range.getAsSMRange(mgr&: sourceMgr);
1180 if (!rangeLoc.isValid())
1181 return;
1182 (*astModule)->walk(walkFn: [&](const ast::Node *node) {
1183 SMRange loc = node->getLoc();
1184
1185 // Check that the location of this node is within the input range.
1186 if (!lsp::contains(range: rangeLoc, loc: loc.Start) &&
1187 !lsp::contains(range: rangeLoc, loc: loc.End))
1188 return;
1189
1190 // Handle hints for various types of nodes.
1191 llvm::TypeSwitch<const ast::Node *>(node)
1192 .Case<ast::VariableDecl, ast::CallExpr, ast::OperationExpr>(
1193 caseFn: [&](const auto *node) {
1194 this->getInlayHintsFor(node, uri, inlayHints);
1195 });
1196 });
1197}
1198
1199void PDLDocument::getInlayHintsFor(const ast::VariableDecl *decl,
1200 const lsp::URIForFile &uri,
1201 std::vector<lsp::InlayHint> &inlayHints) {
1202 // Check to see if the variable has a constraint list, if it does we don't
1203 // provide initializer hints.
1204 if (!decl->getConstraints().empty())
1205 return;
1206
1207 // Check to see if the variable has an initializer.
1208 if (const ast::Expr *expr = decl->getInitExpr()) {
1209 // Don't add hints for operation expression initialized variables given that
1210 // the type of the variable is easily inferred by the expression operation
1211 // name.
1212 if (isa<ast::OperationExpr>(Val: expr))
1213 return;
1214 }
1215
1216 lsp::InlayHint hint(lsp::InlayHintKind::Type,
1217 lsp::Position(sourceMgr, decl->getLoc().End));
1218 {
1219 llvm::raw_string_ostream labelOS(hint.label);
1220 labelOS << ": " << decl->getType();
1221 }
1222
1223 inlayHints.emplace_back(args: std::move(hint));
1224}
1225
1226void PDLDocument::getInlayHintsFor(const ast::CallExpr *expr,
1227 const lsp::URIForFile &uri,
1228 std::vector<lsp::InlayHint> &inlayHints) {
1229 // Try to extract the callable of this call.
1230 const auto *callableRef = dyn_cast<ast::DeclRefExpr>(Val: expr->getCallableExpr());
1231 const auto *callable =
1232 callableRef ? dyn_cast<ast::CallableDecl>(Val: callableRef->getDecl())
1233 : nullptr;
1234 if (!callable)
1235 return;
1236
1237 // Add hints for the arguments to the call.
1238 for (const auto &it : llvm::zip(t: expr->getArguments(), u: callable->getInputs()))
1239 addParameterHintFor(inlayHints, expr: std::get<0>(t: it),
1240 label: std::get<1>(t: it)->getName().getName());
1241}
1242
1243void PDLDocument::getInlayHintsFor(const ast::OperationExpr *expr,
1244 const lsp::URIForFile &uri,
1245 std::vector<lsp::InlayHint> &inlayHints) {
1246 // Check for ODS information.
1247 ast::OperationType opType = dyn_cast<ast::OperationType>(Val: expr->getType());
1248 const auto *odsOp = opType ? opType.getODSOperation() : nullptr;
1249
1250 auto addOpHint = [&](const ast::Expr *valueExpr, StringRef label) {
1251 // If the value expression used the same location as the operation, don't
1252 // add a hint. This expression was materialized during parsing.
1253 if (expr->getLoc().Start == valueExpr->getLoc().Start)
1254 return;
1255 addParameterHintFor(inlayHints, expr: valueExpr, label);
1256 };
1257
1258 // Functor used to process hints for the operands and results of the
1259 // operation. They effectively have the same format, and thus can be processed
1260 // using the same logic.
1261 auto addOperandOrResultHints = [&](ArrayRef<ast::Expr *> values,
1262 ArrayRef<ods::OperandOrResult> odsValues,
1263 StringRef allValuesName) {
1264 if (values.empty())
1265 return;
1266
1267 // The values should either map to a single range, or be equivalent to the
1268 // ODS values.
1269 if (values.size() != odsValues.size()) {
1270 // Handle the case of a single element that covers the full range.
1271 if (values.size() == 1)
1272 return addOpHint(values.front(), allValuesName);
1273 return;
1274 }
1275
1276 for (const auto &it : llvm::zip(t&: values, u&: odsValues))
1277 addOpHint(std::get<0>(t: it), std::get<1>(t: it).getName());
1278 };
1279
1280 // Add hints for the operands and results of the operation.
1281 addOperandOrResultHints(expr->getOperands(),
1282 odsOp ? odsOp->getOperands()
1283 : ArrayRef<ods::OperandOrResult>(),
1284 "operands");
1285 addOperandOrResultHints(expr->getResultTypes(),
1286 odsOp ? odsOp->getResults()
1287 : ArrayRef<ods::OperandOrResult>(),
1288 "results");
1289}
1290
1291void PDLDocument::addParameterHintFor(std::vector<lsp::InlayHint> &inlayHints,
1292 const ast::Expr *expr, StringRef label) {
1293 if (!shouldAddHintFor(expr, name: label))
1294 return;
1295
1296 lsp::InlayHint hint(lsp::InlayHintKind::Parameter,
1297 lsp::Position(sourceMgr, expr->getLoc().Start));
1298 hint.label = (label + ":").str();
1299 hint.paddingRight = true;
1300 inlayHints.emplace_back(args: std::move(hint));
1301}
1302
1303//===----------------------------------------------------------------------===//
1304// PDLL ViewOutput
1305//===----------------------------------------------------------------------===//
1306
1307void PDLDocument::getPDLLViewOutput(raw_ostream &os,
1308 lsp::PDLLViewOutputKind kind) {
1309 if (failed(Result: astModule))
1310 return;
1311 if (kind == lsp::PDLLViewOutputKind::AST) {
1312 (*astModule)->print(os);
1313 return;
1314 }
1315
1316 // Generate the MLIR for the ast module. We also capture diagnostics here to
1317 // show to the user, which may be useful if PDLL isn't capturing constraints
1318 // expected by PDL.
1319 MLIRContext mlirContext;
1320 SourceMgrDiagnosticHandler diagHandler(sourceMgr, &mlirContext, os);
1321 OwningOpRef<ModuleOp> pdlModule =
1322 codegenPDLLToMLIR(&mlirContext, astContext, sourceMgr, **astModule);
1323 if (!pdlModule)
1324 return;
1325 if (kind == lsp::PDLLViewOutputKind::MLIR) {
1326 pdlModule->print(os, OpPrintingFlags().enableDebugInfo());
1327 return;
1328 }
1329
1330 // Otherwise, generate the output for C++.
1331 assert(kind == lsp::PDLLViewOutputKind::CPP &&
1332 "unexpected PDLLViewOutputKind");
1333 codegenPDLLToCPP(**astModule, *pdlModule, os);
1334}
1335
1336//===----------------------------------------------------------------------===//
1337// PDLTextFileChunk
1338//===----------------------------------------------------------------------===//
1339
1340namespace {
1341/// This class represents a single chunk of an PDL text file.
1342struct PDLTextFileChunk {
1343 PDLTextFileChunk(uint64_t lineOffset, const lsp::URIForFile &uri,
1344 StringRef contents,
1345 const std::vector<std::string> &extraDirs,
1346 std::vector<lsp::Diagnostic> &diagnostics)
1347 : lineOffset(lineOffset),
1348 document(uri, contents, extraDirs, diagnostics) {}
1349
1350 /// Adjust the line number of the given range to anchor at the beginning of
1351 /// the file, instead of the beginning of this chunk.
1352 void adjustLocForChunkOffset(lsp::Range &range) {
1353 adjustLocForChunkOffset(pos&: range.start);
1354 adjustLocForChunkOffset(pos&: range.end);
1355 }
1356 /// Adjust the line number of the given position to anchor at the beginning of
1357 /// the file, instead of the beginning of this chunk.
1358 void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; }
1359
1360 /// The line offset of this chunk from the beginning of the file.
1361 uint64_t lineOffset;
1362 /// The document referred to by this chunk.
1363 PDLDocument document;
1364};
1365} // namespace
1366
1367//===----------------------------------------------------------------------===//
1368// PDLTextFile
1369//===----------------------------------------------------------------------===//
1370
1371namespace {
1372/// This class represents a text file containing one or more PDL documents.
1373class PDLTextFile {
1374public:
1375 PDLTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1376 int64_t version, const std::vector<std::string> &extraDirs,
1377 std::vector<lsp::Diagnostic> &diagnostics);
1378
1379 /// Return the current version of this text file.
1380 int64_t getVersion() const { return version; }
1381
1382 /// Update the file to the new version using the provided set of content
1383 /// changes. Returns failure if the update was unsuccessful.
1384 LogicalResult update(const lsp::URIForFile &uri, int64_t newVersion,
1385 ArrayRef<lsp::TextDocumentContentChangeEvent> changes,
1386 std::vector<lsp::Diagnostic> &diagnostics);
1387
1388 //===--------------------------------------------------------------------===//
1389 // LSP Queries
1390 //===--------------------------------------------------------------------===//
1391
1392 void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos,
1393 std::vector<lsp::Location> &locations);
1394 void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos,
1395 std::vector<lsp::Location> &references);
1396 void getDocumentLinks(const lsp::URIForFile &uri,
1397 std::vector<lsp::DocumentLink> &links);
1398 std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
1399 lsp::Position hoverPos);
1400 void findDocumentSymbols(std::vector<lsp::DocumentSymbol> &symbols);
1401 lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri,
1402 lsp::Position completePos);
1403 lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri,
1404 lsp::Position helpPos);
1405 void getInlayHints(const lsp::URIForFile &uri, lsp::Range range,
1406 std::vector<lsp::InlayHint> &inlayHints);
1407 lsp::PDLLViewOutputResult getPDLLViewOutput(lsp::PDLLViewOutputKind kind);
1408
1409private:
1410 using ChunkIterator = llvm::pointee_iterator<
1411 std::vector<std::unique_ptr<PDLTextFileChunk>>::iterator>;
1412
1413 /// Initialize the text file from the given file contents.
1414 void initialize(const lsp::URIForFile &uri, int64_t newVersion,
1415 std::vector<lsp::Diagnostic> &diagnostics);
1416
1417 /// Find the PDL document that contains the given position, and update the
1418 /// position to be anchored at the start of the found chunk instead of the
1419 /// beginning of the file.
1420 ChunkIterator getChunkItFor(lsp::Position &pos);
1421 PDLTextFileChunk &getChunkFor(lsp::Position &pos) {
1422 return *getChunkItFor(pos);
1423 }
1424
1425 /// The full string contents of the file.
1426 std::string contents;
1427
1428 /// The version of this file.
1429 int64_t version = 0;
1430
1431 /// The number of lines in the file.
1432 int64_t totalNumLines = 0;
1433
1434 /// The chunks of this file. The order of these chunks is the order in which
1435 /// they appear in the text file.
1436 std::vector<std::unique_ptr<PDLTextFileChunk>> chunks;
1437
1438 /// The extra set of include directories for this file.
1439 std::vector<std::string> extraIncludeDirs;
1440};
1441} // namespace
1442
1443PDLTextFile::PDLTextFile(const lsp::URIForFile &uri, StringRef fileContents,
1444 int64_t version,
1445 const std::vector<std::string> &extraDirs,
1446 std::vector<lsp::Diagnostic> &diagnostics)
1447 : contents(fileContents.str()), extraIncludeDirs(extraDirs) {
1448 initialize(uri, newVersion: version, diagnostics);
1449}
1450
1451LogicalResult
1452PDLTextFile::update(const lsp::URIForFile &uri, int64_t newVersion,
1453 ArrayRef<lsp::TextDocumentContentChangeEvent> changes,
1454 std::vector<lsp::Diagnostic> &diagnostics) {
1455 if (failed(Result: lsp::TextDocumentContentChangeEvent::applyTo(changes, contents))) {
1456 lsp::Logger::error(fmt: "Failed to update contents of {0}", vals: uri.file());
1457 return failure();
1458 }
1459
1460 // If the file contents were properly changed, reinitialize the text file.
1461 initialize(uri, newVersion, diagnostics);
1462 return success();
1463}
1464
1465void PDLTextFile::getLocationsOf(const lsp::URIForFile &uri,
1466 lsp::Position defPos,
1467 std::vector<lsp::Location> &locations) {
1468 PDLTextFileChunk &chunk = getChunkFor(pos&: defPos);
1469 chunk.document.getLocationsOf(uri, defPos, locations);
1470
1471 // Adjust any locations within this file for the offset of this chunk.
1472 if (chunk.lineOffset == 0)
1473 return;
1474 for (lsp::Location &loc : locations)
1475 if (loc.uri == uri)
1476 chunk.adjustLocForChunkOffset(range&: loc.range);
1477}
1478
1479void PDLTextFile::findReferencesOf(const lsp::URIForFile &uri,
1480 lsp::Position pos,
1481 std::vector<lsp::Location> &references) {
1482 PDLTextFileChunk &chunk = getChunkFor(pos);
1483 chunk.document.findReferencesOf(uri, pos, references);
1484
1485 // Adjust any locations within this file for the offset of this chunk.
1486 if (chunk.lineOffset == 0)
1487 return;
1488 for (lsp::Location &loc : references)
1489 if (loc.uri == uri)
1490 chunk.adjustLocForChunkOffset(range&: loc.range);
1491}
1492
1493void PDLTextFile::getDocumentLinks(const lsp::URIForFile &uri,
1494 std::vector<lsp::DocumentLink> &links) {
1495 chunks.front()->document.getDocumentLinks(uri, links);
1496 for (const auto &it : llvm::drop_begin(RangeOrContainer&: chunks)) {
1497 size_t currentNumLinks = links.size();
1498 it->document.getDocumentLinks(uri, links);
1499
1500 // Adjust any links within this file to account for the offset of this
1501 // chunk.
1502 for (auto &link : llvm::drop_begin(RangeOrContainer&: links, N: currentNumLinks))
1503 it->adjustLocForChunkOffset(range&: link.range);
1504 }
1505}
1506
1507std::optional<lsp::Hover> PDLTextFile::findHover(const lsp::URIForFile &uri,
1508 lsp::Position hoverPos) {
1509 PDLTextFileChunk &chunk = getChunkFor(pos&: hoverPos);
1510 std::optional<lsp::Hover> hoverInfo = chunk.document.findHover(uri, hoverPos);
1511
1512 // Adjust any locations within this file for the offset of this chunk.
1513 if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range)
1514 chunk.adjustLocForChunkOffset(range&: *hoverInfo->range);
1515 return hoverInfo;
1516}
1517
1518void PDLTextFile::findDocumentSymbols(
1519 std::vector<lsp::DocumentSymbol> &symbols) {
1520 if (chunks.size() == 1)
1521 return chunks.front()->document.findDocumentSymbols(symbols);
1522
1523 // If there are multiple chunks in this file, we create top-level symbols for
1524 // each chunk.
1525 for (unsigned i = 0, e = chunks.size(); i < e; ++i) {
1526 PDLTextFileChunk &chunk = *chunks[i];
1527 lsp::Position startPos(chunk.lineOffset);
1528 lsp::Position endPos((i == e - 1) ? totalNumLines - 1
1529 : chunks[i + 1]->lineOffset);
1530 lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
1531 lsp::SymbolKind::Namespace,
1532 /*range=*/lsp::Range(startPos, endPos),
1533 /*selectionRange=*/lsp::Range(startPos));
1534 chunk.document.findDocumentSymbols(symbols&: symbol.children);
1535
1536 // Fixup the locations of document symbols within this chunk.
1537 if (i != 0) {
1538 SmallVector<lsp::DocumentSymbol *> symbolsToFix;
1539 for (lsp::DocumentSymbol &childSymbol : symbol.children)
1540 symbolsToFix.push_back(Elt: &childSymbol);
1541
1542 while (!symbolsToFix.empty()) {
1543 lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val();
1544 chunk.adjustLocForChunkOffset(range&: symbol->range);
1545 chunk.adjustLocForChunkOffset(range&: symbol->selectionRange);
1546
1547 for (lsp::DocumentSymbol &childSymbol : symbol->children)
1548 symbolsToFix.push_back(Elt: &childSymbol);
1549 }
1550 }
1551
1552 // Push the symbol for this chunk.
1553 symbols.emplace_back(args: std::move(symbol));
1554 }
1555}
1556
1557lsp::CompletionList PDLTextFile::getCodeCompletion(const lsp::URIForFile &uri,
1558 lsp::Position completePos) {
1559 PDLTextFileChunk &chunk = getChunkFor(pos&: completePos);
1560 lsp::CompletionList completionList =
1561 chunk.document.getCodeCompletion(uri, completePos);
1562
1563 // Adjust any completion locations.
1564 for (lsp::CompletionItem &item : completionList.items) {
1565 if (item.textEdit)
1566 chunk.adjustLocForChunkOffset(range&: item.textEdit->range);
1567 for (lsp::TextEdit &edit : item.additionalTextEdits)
1568 chunk.adjustLocForChunkOffset(range&: edit.range);
1569 }
1570 return completionList;
1571}
1572
1573lsp::SignatureHelp PDLTextFile::getSignatureHelp(const lsp::URIForFile &uri,
1574 lsp::Position helpPos) {
1575 return getChunkFor(pos&: helpPos).document.getSignatureHelp(uri, helpPos);
1576}
1577
1578void PDLTextFile::getInlayHints(const lsp::URIForFile &uri, lsp::Range range,
1579 std::vector<lsp::InlayHint> &inlayHints) {
1580 auto startIt = getChunkItFor(pos&: range.start);
1581 auto endIt = getChunkItFor(pos&: range.end);
1582
1583 // Functor used to get the chunks for a given file, and fixup any locations
1584 auto getHintsForChunk = [&](ChunkIterator chunkIt, lsp::Range range) {
1585 size_t currentNumHints = inlayHints.size();
1586 chunkIt->document.getInlayHints(uri, range, inlayHints);
1587
1588 // If this isn't the first chunk, update any positions to account for line
1589 // number differences.
1590 if (&*chunkIt != &*chunks.front()) {
1591 for (auto &hint : llvm::drop_begin(RangeOrContainer&: inlayHints, N: currentNumHints))
1592 chunkIt->adjustLocForChunkOffset(pos&: hint.position);
1593 }
1594 };
1595 // Returns the number of lines held by a given chunk.
1596 auto getNumLines = [](ChunkIterator chunkIt) {
1597 return (chunkIt + 1)->lineOffset - chunkIt->lineOffset;
1598 };
1599
1600 // Check if the range is fully within a single chunk.
1601 if (startIt == endIt)
1602 return getHintsForChunk(startIt, range);
1603
1604 // Otherwise, the range is split between multiple chunks. The first chunk
1605 // has the correct range start, but covers the total document.
1606 getHintsForChunk(startIt, lsp::Range(range.start, getNumLines(startIt)));
1607
1608 // Every chunk in between uses the full document.
1609 for (++startIt; startIt != endIt; ++startIt)
1610 getHintsForChunk(startIt, lsp::Range(0, getNumLines(startIt)));
1611
1612 // The range for the last chunk starts at the beginning of the document, up
1613 // through the end of the input range.
1614 getHintsForChunk(startIt, lsp::Range(0, range.end));
1615}
1616
1617lsp::PDLLViewOutputResult
1618PDLTextFile::getPDLLViewOutput(lsp::PDLLViewOutputKind kind) {
1619 lsp::PDLLViewOutputResult result;
1620 {
1621 llvm::raw_string_ostream outputOS(result.output);
1622 llvm::interleave(
1623 c: llvm::make_pointee_range(Range&: chunks),
1624 each_fn: [&](PDLTextFileChunk &chunk) {
1625 chunk.document.getPDLLViewOutput(os&: outputOS, kind);
1626 },
1627 between_fn: [&] { outputOS << "\n"
1628 << kDefaultSplitMarker << "\n\n"; });
1629 }
1630 return result;
1631}
1632
1633void PDLTextFile::initialize(const lsp::URIForFile &uri, int64_t newVersion,
1634 std::vector<lsp::Diagnostic> &diagnostics) {
1635 version = newVersion;
1636 chunks.clear();
1637
1638 // Split the file into separate PDL documents.
1639 SmallVector<StringRef, 8> subContents;
1640 StringRef(contents).split(A&: subContents, Separator: kDefaultSplitMarker);
1641 chunks.emplace_back(args: std::make_unique<PDLTextFileChunk>(
1642 /*lineOffset=*/args: 0, args: uri, args&: subContents.front(), args&: extraIncludeDirs,
1643 args&: diagnostics));
1644
1645 uint64_t lineOffset = subContents.front().count(C: '\n');
1646 for (StringRef docContents : llvm::drop_begin(RangeOrContainer&: subContents)) {
1647 unsigned currentNumDiags = diagnostics.size();
1648 auto chunk = std::make_unique<PDLTextFileChunk>(
1649 args&: lineOffset, args: uri, args&: docContents, args&: extraIncludeDirs, args&: diagnostics);
1650 lineOffset += docContents.count(C: '\n');
1651
1652 // Adjust locations used in diagnostics to account for the offset from the
1653 // beginning of the file.
1654 for (lsp::Diagnostic &diag :
1655 llvm::drop_begin(RangeOrContainer&: diagnostics, N: currentNumDiags)) {
1656 chunk->adjustLocForChunkOffset(range&: diag.range);
1657
1658 if (!diag.relatedInformation)
1659 continue;
1660 for (auto &it : *diag.relatedInformation)
1661 if (it.location.uri == uri)
1662 chunk->adjustLocForChunkOffset(range&: it.location.range);
1663 }
1664 chunks.emplace_back(args: std::move(chunk));
1665 }
1666 totalNumLines = lineOffset;
1667}
1668
1669PDLTextFile::ChunkIterator PDLTextFile::getChunkItFor(lsp::Position &pos) {
1670 if (chunks.size() == 1)
1671 return chunks.begin();
1672
1673 // Search for the first chunk with a greater line offset, the previous chunk
1674 // is the one that contains `pos`.
1675 auto it = llvm::upper_bound(
1676 Range&: chunks, Value&: pos, C: [](const lsp::Position &pos, const auto &chunk) {
1677 return static_cast<uint64_t>(pos.line) < chunk->lineOffset;
1678 });
1679 ChunkIterator chunkIt(it == chunks.end() ? (chunks.end() - 1) : --it);
1680 pos.line -= chunkIt->lineOffset;
1681 return chunkIt;
1682}
1683
1684//===----------------------------------------------------------------------===//
1685// PDLLServer::Impl
1686//===----------------------------------------------------------------------===//
1687
1688struct lsp::PDLLServer::Impl {
1689 explicit Impl(const Options &options)
1690 : options(options), compilationDatabase(options.compilationDatabases) {}
1691
1692 /// PDLL LSP options.
1693 const Options &options;
1694
1695 /// The compilation database containing additional information for files
1696 /// passed to the server.
1697 lsp::CompilationDatabase compilationDatabase;
1698
1699 /// The files held by the server, mapped by their URI file name.
1700 llvm::StringMap<std::unique_ptr<PDLTextFile>> files;
1701};
1702
1703//===----------------------------------------------------------------------===//
1704// PDLLServer
1705//===----------------------------------------------------------------------===//
1706
1707lsp::PDLLServer::PDLLServer(const Options &options)
1708 : impl(std::make_unique<Impl>(args: options)) {}
1709lsp::PDLLServer::~PDLLServer() = default;
1710
1711void lsp::PDLLServer::addDocument(const URIForFile &uri, StringRef contents,
1712 int64_t version,
1713 std::vector<Diagnostic> &diagnostics) {
1714 // Build the set of additional include directories.
1715 std::vector<std::string> additionalIncludeDirs = impl->options.extraDirs;
1716 const auto &fileInfo = impl->compilationDatabase.getFileInfo(filename: uri.file());
1717 llvm::append_range(C&: additionalIncludeDirs, R: fileInfo.includeDirs);
1718
1719 impl->files[uri.file()] = std::make_unique<PDLTextFile>(
1720 args: uri, args&: contents, args&: version, args&: additionalIncludeDirs, args&: diagnostics);
1721}
1722
1723void lsp::PDLLServer::updateDocument(
1724 const URIForFile &uri, ArrayRef<TextDocumentContentChangeEvent> changes,
1725 int64_t version, std::vector<Diagnostic> &diagnostics) {
1726 // Check that we actually have a document for this uri.
1727 auto it = impl->files.find(Key: uri.file());
1728 if (it == impl->files.end())
1729 return;
1730
1731 // Try to update the document. If we fail, erase the file from the server. A
1732 // failed updated generally means we've fallen out of sync somewhere.
1733 if (failed(Result: it->second->update(uri, newVersion: version, changes, diagnostics)))
1734 impl->files.erase(I: it);
1735}
1736
1737std::optional<int64_t> lsp::PDLLServer::removeDocument(const URIForFile &uri) {
1738 auto it = impl->files.find(Key: uri.file());
1739 if (it == impl->files.end())
1740 return std::nullopt;
1741
1742 int64_t version = it->second->getVersion();
1743 impl->files.erase(I: it);
1744 return version;
1745}
1746
1747void lsp::PDLLServer::getLocationsOf(const URIForFile &uri,
1748 const Position &defPos,
1749 std::vector<Location> &locations) {
1750 auto fileIt = impl->files.find(Key: uri.file());
1751 if (fileIt != impl->files.end())
1752 fileIt->second->getLocationsOf(uri, defPos, locations);
1753}
1754
1755void lsp::PDLLServer::findReferencesOf(const URIForFile &uri,
1756 const Position &pos,
1757 std::vector<Location> &references) {
1758 auto fileIt = impl->files.find(Key: uri.file());
1759 if (fileIt != impl->files.end())
1760 fileIt->second->findReferencesOf(uri, pos, references);
1761}
1762
1763void lsp::PDLLServer::getDocumentLinks(
1764 const URIForFile &uri, std::vector<DocumentLink> &documentLinks) {
1765 auto fileIt = impl->files.find(Key: uri.file());
1766 if (fileIt != impl->files.end())
1767 return fileIt->second->getDocumentLinks(uri, links&: documentLinks);
1768}
1769
1770std::optional<lsp::Hover> lsp::PDLLServer::findHover(const URIForFile &uri,
1771 const Position &hoverPos) {
1772 auto fileIt = impl->files.find(Key: uri.file());
1773 if (fileIt != impl->files.end())
1774 return fileIt->second->findHover(uri, hoverPos);
1775 return std::nullopt;
1776}
1777
1778void lsp::PDLLServer::findDocumentSymbols(
1779 const URIForFile &uri, std::vector<DocumentSymbol> &symbols) {
1780 auto fileIt = impl->files.find(Key: uri.file());
1781 if (fileIt != impl->files.end())
1782 fileIt->second->findDocumentSymbols(symbols);
1783}
1784
1785lsp::CompletionList
1786lsp::PDLLServer::getCodeCompletion(const URIForFile &uri,
1787 const Position &completePos) {
1788 auto fileIt = impl->files.find(Key: uri.file());
1789 if (fileIt != impl->files.end())
1790 return fileIt->second->getCodeCompletion(uri, completePos);
1791 return CompletionList();
1792}
1793
1794lsp::SignatureHelp lsp::PDLLServer::getSignatureHelp(const URIForFile &uri,
1795 const Position &helpPos) {
1796 auto fileIt = impl->files.find(Key: uri.file());
1797 if (fileIt != impl->files.end())
1798 return fileIt->second->getSignatureHelp(uri, helpPos);
1799 return SignatureHelp();
1800}
1801
1802void lsp::PDLLServer::getInlayHints(const URIForFile &uri, const Range &range,
1803 std::vector<InlayHint> &inlayHints) {
1804 auto fileIt = impl->files.find(Key: uri.file());
1805 if (fileIt == impl->files.end())
1806 return;
1807 fileIt->second->getInlayHints(uri, range, inlayHints);
1808
1809 // Drop any duplicated hints that may have cropped up.
1810 llvm::sort(C&: inlayHints);
1811 inlayHints.erase(first: llvm::unique(R&: inlayHints), last: inlayHints.end());
1812}
1813
1814std::optional<lsp::PDLLViewOutputResult>
1815lsp::PDLLServer::getPDLLViewOutput(const URIForFile &uri,
1816 PDLLViewOutputKind kind) {
1817 auto fileIt = impl->files.find(Key: uri.file());
1818 if (fileIt != impl->files.end())
1819 return fileIt->second->getPDLLViewOutput(kind);
1820 return std::nullopt;
1821}
1822

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp