1//===- DialectGen.cpp - MLIR dialect definitions generator ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// DialectGen uses the description of dialects to generate C++ definitions.
10//
11//===----------------------------------------------------------------------===//
12
13#include "DialectGenUtilities.h"
14#include "mlir/TableGen/Class.h"
15#include "mlir/TableGen/CodeGenHelpers.h"
16#include "mlir/TableGen/Format.h"
17#include "mlir/TableGen/GenInfo.h"
18#include "mlir/TableGen/Interfaces.h"
19#include "mlir/TableGen/Operator.h"
20#include "mlir/TableGen/Trait.h"
21#include "llvm/ADT/Sequence.h"
22#include "llvm/ADT/StringExtras.h"
23#include "llvm/Support/CommandLine.h"
24#include "llvm/Support/Signals.h"
25#include "llvm/TableGen/Error.h"
26#include "llvm/TableGen/Record.h"
27#include "llvm/TableGen/TableGenBackend.h"
28
29#define DEBUG_TYPE "mlir-tblgen-opdefgen"
30
31using namespace mlir;
32using namespace mlir::tblgen;
33
34static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*");
35llvm::cl::opt<std::string>
36 selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
37 llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
38
39/// Utility iterator used for filtering records for a specific dialect.
40namespace {
41using DialectFilterIterator =
42 llvm::filter_iterator<ArrayRef<llvm::Record *>::iterator,
43 std::function<bool(const llvm::Record *)>>;
44} // namespace
45
46static void populateDiscardableAttributes(
47 Dialect &dialect, llvm::DagInit *discardableAttrDag,
48 SmallVector<std::pair<std::string, std::string>> &discardableAttributes) {
49 for (int i : llvm::seq<int>(Begin: 0, End: discardableAttrDag->getNumArgs())) {
50 llvm::Init *arg = discardableAttrDag->getArg(Num: i);
51
52 StringRef givenName = discardableAttrDag->getArgNameStr(Num: i);
53 if (givenName.empty())
54 PrintFatalError(ErrorLoc: dialect.getDef()->getLoc(),
55 Msg: "discardable attributes must be named");
56 discardableAttributes.push_back(
57 Elt: {givenName.str(), arg->getAsUnquotedString()});
58 }
59}
60
61/// Given a set of records for a T, filter the ones that correspond to
62/// the given dialect.
63template <typename T>
64static iterator_range<DialectFilterIterator>
65filterForDialect(ArrayRef<llvm::Record *> records, Dialect &dialect) {
66 auto filterFn = [&](const llvm::Record *record) {
67 return T(record).getDialect() == dialect;
68 };
69 return {DialectFilterIterator(records.begin(), records.end(), filterFn),
70 DialectFilterIterator(records.end(), records.end(), filterFn)};
71}
72
73std::optional<Dialect>
74tblgen::findDialectToGenerate(ArrayRef<Dialect> dialects) {
75 if (dialects.empty()) {
76 llvm::errs() << "no dialect was found\n";
77 return std::nullopt;
78 }
79
80 // Select the dialect to gen for.
81 if (dialects.size() == 1 && selectedDialect.getNumOccurrences() == 0)
82 return dialects.front();
83
84 if (selectedDialect.getNumOccurrences() == 0) {
85 llvm::errs() << "when more than 1 dialect is present, one must be selected "
86 "via '-dialect'\n";
87 return std::nullopt;
88 }
89
90 const auto *dialectIt = llvm::find_if(Range&: dialects, P: [](const Dialect &dialect) {
91 return dialect.getName() == selectedDialect;
92 });
93 if (dialectIt == dialects.end()) {
94 llvm::errs() << "selected dialect with '-dialect' does not exist\n";
95 return std::nullopt;
96 }
97 return *dialectIt;
98}
99
100//===----------------------------------------------------------------------===//
101// GEN: Dialect declarations
102//===----------------------------------------------------------------------===//
103
/// The code block for the start of a dialect class declaration, expanded with
/// llvm::formatv.
///
/// {0}: The name of the dialect class.
/// {1}: The dialect namespace.
/// {2}: The dialect parent class.
///
/// Literal opening braces are escaped as `{{`. Otherwise the template would
/// rely on formatv's fallback of treating a `{` as plain text whenever another
/// `{` appears before the next `}`.
static const char *const dialectDeclBeginStr = R"(
class {0} : public ::mlir::{2} {{
  explicit {0}(::mlir::MLIRContext *context);

  void initialize();
  friend class ::mlir::MLIRContext;
public:
  ~{0}() override;
  static constexpr ::llvm::StringLiteral getDialectNamespace() {{
    return ::llvm::StringLiteral("{1}");
  }
)";
121
/// Statement that loads one dependent dialect. One copy per dependent dialect
/// is spliced into the body of the generated dialect constructor (placeholder
/// `{1}` of `dialectConstructorStr`, defined below), ahead of the call to
/// `initialize()`.
///
/// {0}: The C++ class of the dependent dialect.
static const char *const dialectRegistrationTemplate =
    "getContext()->loadDialect<{0}>();";
126
/// Declarations of the attribute parse/print hooks. Emitted when the dialect
/// asks for the default attribute printer and parser.
static constexpr const char *attrParserDecl = R"(
  /// Parse an attribute registered to this dialect.
  ::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser,
                                   ::mlir::Type type) const override;

  /// Print an attribute registered to this dialect.
  void printAttribute(::mlir::Attribute attr,
                      ::mlir::DialectAsmPrinter &os) const override;
)";

/// Declarations of the type parse/print hooks. Emitted when the dialect asks
/// for the default type printer and parser.
static constexpr const char *typeParserDecl = R"(
  /// Parse a type registered to this dialect.
  ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;

  /// Print a type registered to this dialect.
  void printType(::mlir::Type type,
                 ::mlir::DialectAsmPrinter &os) const override;
)";
147
/// Declaration of the canonicalization pattern registration hook.
static constexpr const char *canonicalizerDecl = R"(
  /// Register canonicalization patterns.
  void getCanonicalizationPatterns(
      ::mlir::RewritePatternSet &results) const override;
)";

/// Declaration of the constant materialization hook.
static constexpr const char *constantMaterializerDecl = R"(
  /// Materialize a single constant operation from a given attribute value with
  /// the desired resultant type.
  ::mlir::Operation *materializeConstant(::mlir::OpBuilder &builder,
                                         ::mlir::Attribute value,
                                         ::mlir::Type type,
                                         ::mlir::Location loc) override;
)";

/// Declaration of the hook verifying dialect attributes on an operation.
static constexpr const char *opAttrVerifierDecl = R"(
  /// Provides a hook for verifying dialect attributes attached to the given
  /// op.
  ::mlir::LogicalResult verifyOperationAttribute(
      ::mlir::Operation *op, ::mlir::NamedAttribute attribute) override;
)";

/// Declaration of the hook verifying dialect attributes on region arguments.
static constexpr const char *regionArgAttrVerifierDecl = R"(
  /// Provides a hook for verifying dialect attributes attached to the given
  /// op's region argument.
  ::mlir::LogicalResult verifyRegionArgAttribute(
      ::mlir::Operation *op, unsigned regionIndex, unsigned argIndex,
      ::mlir::NamedAttribute attribute) override;
)";

/// Declaration of the hook verifying dialect attributes on region results.
static constexpr const char *regionResultAttrVerifierDecl = R"(
  /// Provides a hook for verifying dialect attributes attached to the given
  /// op's region result.
  ::mlir::LogicalResult verifyRegionResultAttribute(
      ::mlir::Operation *op, unsigned regionIndex, unsigned resultIndex,
      ::mlir::NamedAttribute attribute) override;
)";
190
/// The code block for the op interface fallback hook.
///
/// Types are spelled fully qualified (`::mlir::`), consistent with the other
/// hook declarations in this file. The declaration is emitted inside the
/// dialect's own C++ namespaces, and a nested namespace named `mlir` there
/// would otherwise capture an unqualified `mlir::` lookup.
static const char *const operationInterfaceFallbackDecl = R"(
  /// Provides a hook for op interface.
  void *getRegisteredInterfaceForOp(::mlir::TypeID interfaceID,
                                    ::mlir::OperationName opName) override;
)";
197
/// The code block for the discardable attribute helper, expanded with
/// llvm::formatv.
///
/// {0}: The attribute name in UpperCamelCase (helper class / accessor name).
/// {1}: The attribute name as written in the dialect definition.
/// {2}: The C++ attribute type used with getAttrOfType/hasAttrOfType.
/// {3}: The attribute name in lowerCamelCase (member variable name).
/// {4}: The dialect namespace, prefixed to the attribute name.
///
/// Every literal opening brace is escaped as `{{`. This avoids relying on
/// formatv treating a bare `{` as plain text when another `{` precedes the
/// next `}`.
static const char *const discardableAttrHelperDecl = R"(
  /// Helper to manage the discardable attribute `{1}`.
  class {0}AttrHelper {{
    ::mlir::StringAttr name;
  public:
    static constexpr ::llvm::StringLiteral getNameStr() {{
      return "{4}.{1}";
    }
    constexpr ::mlir::StringAttr getName() {{
      return name;
    }

    {0}AttrHelper(::mlir::MLIRContext *ctx)
      : name(::mlir::StringAttr::get(ctx, getNameStr())) {{}

    {2} getAttr(::mlir::Operation *op) {{
      return op->getAttrOfType<{2}>(name);
    }
    void setAttr(::mlir::Operation *op, {2} val) {{
      op->setAttr(name, val);
    }
    bool isAttrPresent(::mlir::Operation *op) {{
      return op->hasAttrOfType<{2}>(name);
    }
    void removeAttr(::mlir::Operation *op) {{
      assert(op->hasAttrOfType<{2}>(name));
      op->removeAttr(name);
    }
  };
  {0}AttrHelper get{0}AttrHelper() {{
    return {3}AttrName;
  }
  private:
  {0}AttrHelper {3}AttrName;
  public:
)";
235
236/// Generate the declaration for the given dialect class.
237static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
238 // Emit all nested namespaces.
239 {
240 NamespaceEmitter nsEmitter(os, dialect);
241
242 // Emit the start of the decl.
243 std::string cppName = dialect.getCppClassName();
244 StringRef superClassName =
245 dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
246 os << llvm::formatv(Fmt: dialectDeclBeginStr, Vals&: cppName, Vals: dialect.getName(),
247 Vals&: superClassName);
248
249 // If the dialect requested the default attribute printer and parser, emit
250 // the declarations for the hooks.
251 if (dialect.useDefaultAttributePrinterParser())
252 os << attrParserDecl;
253 // If the dialect requested the default type printer and parser, emit the
254 // delcarations for the hooks.
255 if (dialect.useDefaultTypePrinterParser())
256 os << typeParserDecl;
257
258 // Add the decls for the various features of the dialect.
259 if (dialect.hasCanonicalizer())
260 os << canonicalizerDecl;
261 if (dialect.hasConstantMaterializer())
262 os << constantMaterializerDecl;
263 if (dialect.hasOperationAttrVerify())
264 os << opAttrVerifierDecl;
265 if (dialect.hasRegionArgAttrVerify())
266 os << regionArgAttrVerifierDecl;
267 if (dialect.hasRegionResultAttrVerify())
268 os << regionResultAttrVerifierDecl;
269 if (dialect.hasOperationInterfaceFallback())
270 os << operationInterfaceFallbackDecl;
271
272 llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
273 SmallVector<std::pair<std::string, std::string>> discardableAttributes;
274 populateDiscardableAttributes(dialect, discardableAttrDag,
275 discardableAttributes);
276
277 for (const auto &attrPair : discardableAttributes) {
278 std::string camelNameUpper = llvm::convertToCamelFromSnakeCase(
279 input: attrPair.first, /*capitalizeFirst=*/true);
280 std::string camelName = llvm::convertToCamelFromSnakeCase(
281 input: attrPair.first, /*capitalizeFirst=*/false);
282 os << llvm::formatv(Fmt: discardableAttrHelperDecl, Vals&: camelNameUpper,
283 Vals: attrPair.first, Vals: attrPair.second, Vals&: camelName,
284 Vals: dialect.getName());
285 }
286
287 if (std::optional<StringRef> extraDecl = dialect.getExtraClassDeclaration())
288 os << *extraDecl;
289
290 // End the dialect decl.
291 os << "};\n";
292 }
293 if (!dialect.getCppNamespace().empty())
294 os << "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace()
295 << "::" << dialect.getCppClassName() << ")\n";
296}
297
298static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
299 raw_ostream &os) {
300 emitSourceFileHeader(Desc: "Dialect Declarations", OS&: os, Record: recordKeeper);
301
302 auto dialectDefs = recordKeeper.getAllDerivedDefinitions(ClassName: "Dialect");
303 if (dialectDefs.empty())
304 return false;
305
306 SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
307 std::optional<Dialect> dialect = findDialectToGenerate(dialects);
308 if (!dialect)
309 return true;
310 emitDialectDecl(dialect&: *dialect, os);
311 return false;
312}
313
314//===----------------------------------------------------------------------===//
315// GEN: Dialect definitions
316//===----------------------------------------------------------------------===//
317
/// Template (for llvm::formatv) of the out-of-line dialect constructor.
///
/// {0}: The name of the dialect class.
/// {1}: Code placed in the constructor body before `initialize()` runs, such
///      as loading of dependent dialects.
/// {2}: The dialect base class (`Dialect` or `ExtensibleDialect`).
/// {3}: Additional member initializers, each beginning with `, `.
static constexpr const char *dialectConstructorStr = R"(
{0}::{0}(::mlir::MLIRContext *context)
    : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>())
    {3}
     {{
  {1}
  initialize();
}
)";

/// Template (for llvm::formatv) of the defaulted out-of-line destructor.
///
/// {0}: The name of the dialect class.
static constexpr const char *dialectDestructorStr = R"(
{0}::~{0}() = default;

)";
342
343static void emitDialectDef(Dialect &dialect,
344 const llvm::RecordKeeper &recordKeeper,
345 raw_ostream &os) {
346 std::string cppClassName = dialect.getCppClassName();
347
348 // Emit the TypeID explicit specializations to have a single symbol def.
349 if (!dialect.getCppNamespace().empty())
350 os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace()
351 << "::" << cppClassName << ")\n";
352
353 // Emit all nested namespaces.
354 NamespaceEmitter nsEmitter(os, dialect);
355
356 /// Build the list of dependent dialects.
357 std::string dependentDialectRegistrations;
358 {
359 llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
360 llvm::interleave(
361 c: dialect.getDependentDialects(), os&: dialectsOs,
362 each_fn: [&](StringRef dependentDialect) {
363 dialectsOs << llvm::formatv(Fmt: dialectRegistrationTemplate,
364 Vals&: dependentDialect);
365 },
366 separator: "\n ");
367 }
368
369 // Emit the constructor and destructor.
370 StringRef superClassName =
371 dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
372
373 llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
374 SmallVector<std::pair<std::string, std::string>> discardableAttributes;
375 populateDiscardableAttributes(dialect, discardableAttrDag,
376 discardableAttributes);
377 std::string discardableAttributesInit;
378 for (const auto &attrPair : discardableAttributes) {
379 std::string camelName = llvm::convertToCamelFromSnakeCase(
380 input: attrPair.first, /*capitalizeFirst=*/false);
381 llvm::raw_string_ostream os(discardableAttributesInit);
382 os << ", " << camelName << "AttrName(context)";
383 }
384
385 os << llvm::formatv(Fmt: dialectConstructorStr, Vals&: cppClassName,
386 Vals&: dependentDialectRegistrations, Vals&: superClassName,
387 Vals&: discardableAttributesInit);
388 if (!dialect.hasNonDefaultDestructor())
389 os << llvm::formatv(Fmt: dialectDestructorStr, Vals&: cppClassName);
390}
391
392static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
393 raw_ostream &os) {
394 emitSourceFileHeader(Desc: "Dialect Definitions", OS&: os, Record: recordKeeper);
395
396 auto dialectDefs = recordKeeper.getAllDerivedDefinitions(ClassName: "Dialect");
397 if (dialectDefs.empty())
398 return false;
399
400 SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
401 std::optional<Dialect> dialect = findDialectToGenerate(dialects);
402 if (!dialect)
403 return true;
404 emitDialectDef(dialect&: *dialect, recordKeeper, os);
405 return false;
406}
407
408//===----------------------------------------------------------------------===//
409// GEN: Dialect registration hooks
410//===----------------------------------------------------------------------===//
411
412static mlir::GenRegistration
413 genDialectDecls("gen-dialect-decls", "Generate dialect declarations",
414 [](const llvm::RecordKeeper &records, raw_ostream &os) {
415 return emitDialectDecls(recordKeeper: records, os);
416 });
417
418static mlir::GenRegistration
419 genDialectDefs("gen-dialect-defs", "Generate dialect definitions",
420 [](const llvm::RecordKeeper &records, raw_ostream &os) {
421 return emitDialectDefs(recordKeeper: records, os);
422 });
423

source code of mlir/tools/mlir-tblgen/DialectGen.cpp