1//===- DialectGen.cpp - MLIR dialect definitions generator ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// DialectGen uses the description of dialects to generate C++ definitions.
10//
11//===----------------------------------------------------------------------===//
12
13#include "CppGenUtilities.h"
14#include "DialectGenUtilities.h"
15#include "mlir/TableGen/Class.h"
16#include "mlir/TableGen/CodeGenHelpers.h"
17#include "mlir/TableGen/Format.h"
18#include "mlir/TableGen/GenInfo.h"
19#include "mlir/TableGen/Interfaces.h"
20#include "mlir/TableGen/Operator.h"
21#include "mlir/TableGen/Trait.h"
22#include "llvm/ADT/Sequence.h"
23#include "llvm/ADT/StringExtras.h"
24#include "llvm/Support/CommandLine.h"
25#include "llvm/Support/Signals.h"
26#include "llvm/TableGen/Error.h"
27#include "llvm/TableGen/Record.h"
28#include "llvm/TableGen/TableGenBackend.h"
29
// Debug type used by LLVM_DEBUG / `-debug-only=` for this backend. It is named
// after this generator instead of reusing the op-definition generator's name.
#define DEBUG_TYPE "mlir-tblgen-dialectgen"
31
32using namespace mlir;
33using namespace mlir::tblgen;
34using llvm::Record;
35using llvm::RecordKeeper;
36
37static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*");
38static llvm::cl::opt<std::string>
39 selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
40 llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
41
/// Utility iterator used for filtering records for a specific dialect.
///
/// Walks an `ArrayRef<Record *>` and skips every record rejected by the
/// wrapped predicate; filterForDialect builds ranges of this type.
namespace {
using DialectFilterIterator =
    llvm::filter_iterator<ArrayRef<Record *>::iterator,
                          std::function<bool(const Record *)>>;
} // namespace
48
49static void populateDiscardableAttributes(
50 Dialect &dialect, const llvm::DagInit *discardableAttrDag,
51 SmallVector<std::pair<std::string, std::string>> &discardableAttributes) {
52 for (int i : llvm::seq<int>(Begin: 0, End: discardableAttrDag->getNumArgs())) {
53 const llvm::Init *arg = discardableAttrDag->getArg(Num: i);
54
55 StringRef givenName = discardableAttrDag->getArgNameStr(Num: i);
56 if (givenName.empty())
57 PrintFatalError(ErrorLoc: dialect.getDef()->getLoc(),
58 Msg: "discardable attributes must be named");
59 discardableAttributes.push_back(
60 Elt: {givenName.str(), arg->getAsUnquotedString()});
61 }
62}
63
64/// Given a set of records for a T, filter the ones that correspond to
65/// the given dialect.
66template <typename T>
67static iterator_range<DialectFilterIterator>
68filterForDialect(ArrayRef<Record *> records, Dialect &dialect) {
69 auto filterFn = [&](const Record *record) {
70 return T(record).getDialect() == dialect;
71 };
72 return {DialectFilterIterator(records.begin(), records.end(), filterFn),
73 DialectFilterIterator(records.end(), records.end(), filterFn)};
74}
75
76std::optional<Dialect>
77tblgen::findDialectToGenerate(ArrayRef<Dialect> dialects) {
78 if (dialects.empty()) {
79 llvm::errs() << "no dialect was found\n";
80 return std::nullopt;
81 }
82
83 // Select the dialect to gen for.
84 if (dialects.size() == 1 && selectedDialect.getNumOccurrences() == 0)
85 return dialects.front();
86
87 if (selectedDialect.getNumOccurrences() == 0) {
88 llvm::errs() << "when more than 1 dialect is present, one must be selected "
89 "via '-dialect'\n";
90 return std::nullopt;
91 }
92
93 const auto *dialectIt = llvm::find_if(Range&: dialects, P: [](const Dialect &dialect) {
94 return dialect.getName() == selectedDialect;
95 });
96 if (dialectIt == dialects.end()) {
97 llvm::errs() << "selected dialect with '-dialect' does not exist\n";
98 return std::nullopt;
99 }
100 return *dialectIt;
101}
102
103//===----------------------------------------------------------------------===//
104// GEN: Dialect declarations
105//===----------------------------------------------------------------------===//
106
/// The code block for the start of a dialect class declaration.
///
/// {0}: The name of the dialect class.
/// {1}: The dialect namespace.
/// {2}: The dialect parent class.
/// {3}: The summary and description comments.
///
/// This is a formatv template, so every literal opening brace is written as
/// `{{`. Without the escape, formatv only treats the brace as text through a
/// lenient fallback, because another `{` happens to precede the next `}`.
static const char *const dialectDeclBeginStr = R"(
{3}
class {0} : public ::mlir::{2} {{
  explicit {0}(::mlir::MLIRContext *context);

  void initialize();
  friend class ::mlir::MLIRContext;
public:
  ~{0}() override;
  static constexpr ::llvm::StringLiteral getDialectNamespace() {{
    return ::llvm::StringLiteral("{1}");
  }
)";
126
/// Registration for a single dependent dialect.
///
/// {0}: The C++ class name of the dependent dialect.
///
/// One instance is formatted per dependent dialect. The results are joined and
/// substituted into placeholder {1} of dialectConstructorStr, so each
/// dependent dialect is loaded in the constructor body before initialize().
static const char *const dialectRegistrationTemplate =
    "getContext()->loadDialect<{0}>();";
131
/// The code block for the attribute parser/printer hooks.
///
/// Streamed verbatim (not through formatv) into the dialect class declaration
/// when the dialect requests the default attribute printer and parser
/// (Dialect::useDefaultAttributePrinterParser).
static const char *const attrParserDecl = R"(
  /// Parse an attribute registered to this dialect.
  ::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser,
                                   ::mlir::Type type) const override;

  /// Print an attribute registered to this dialect.
  void printAttribute(::mlir::Attribute attr,
                      ::mlir::DialectAsmPrinter &os) const override;
)";
142
/// The code block for the type parser/printer hooks.
///
/// Streamed verbatim (not through formatv) into the dialect class declaration
/// when the dialect requests the default type printer and parser
/// (Dialect::useDefaultTypePrinterParser).
static const char *const typeParserDecl = R"(
  /// Parse a type registered to this dialect.
  ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;

  /// Print a type registered to this dialect.
  void printType(::mlir::Type type,
                 ::mlir::DialectAsmPrinter &os) const override;
)";
152
/// The code block for the canonicalization pattern registration hook.
///
/// Streamed verbatim into the dialect class declaration when
/// Dialect::hasCanonicalizer() is true.
static const char *const canonicalizerDecl = R"(
  /// Register canonicalization patterns.
  void getCanonicalizationPatterns(
      ::mlir::RewritePatternSet &results) const override;
)";
159
/// The code block for the constant materializer hook.
///
/// Streamed verbatim into the dialect class declaration when
/// Dialect::hasConstantMaterializer() is true.
static const char *const constantMaterializerDecl = R"(
  /// Materialize a single constant operation from a given attribute value with
  /// the desired resultant type.
  ::mlir::Operation *materializeConstant(::mlir::OpBuilder &builder,
                                         ::mlir::Attribute value,
                                         ::mlir::Type type,
                                         ::mlir::Location loc) override;
)";
169
/// The code block for the operation attribute verifier hook.
///
/// Streamed verbatim into the dialect class declaration when
/// Dialect::hasOperationAttrVerify() is true.
static const char *const opAttrVerifierDecl = R"(
  /// Provides a hook for verifying dialect attributes attached to the given
  /// op.
  ::llvm::LogicalResult verifyOperationAttribute(
      ::mlir::Operation *op, ::mlir::NamedAttribute attribute) override;
)";
177
/// The code block for the region argument attribute verifier hook.
///
/// Streamed verbatim into the dialect class declaration when
/// Dialect::hasRegionArgAttrVerify() is true.
static const char *const regionArgAttrVerifierDecl = R"(
  /// Provides a hook for verifying dialect attributes attached to the given
  /// op's region argument.
  ::llvm::LogicalResult verifyRegionArgAttribute(
      ::mlir::Operation *op, unsigned regionIndex, unsigned argIndex,
      ::mlir::NamedAttribute attribute) override;
)";
186
/// The code block for the region result attribute verifier hook.
///
/// Streamed verbatim into the dialect class declaration when
/// Dialect::hasRegionResultAttrVerify() is true.
static const char *const regionResultAttrVerifierDecl = R"(
  /// Provides a hook for verifying dialect attributes attached to the given
  /// op's region result.
  ::llvm::LogicalResult verifyRegionResultAttribute(
      ::mlir::Operation *op, unsigned regionIndex, unsigned resultIndex,
      ::mlir::NamedAttribute attribute) override;
)";
195
/// The code block for the op interface fallback hook.
///
/// Streamed verbatim into the dialect class declaration when
/// Dialect::hasOperationInterfaceFallback() is true. The MLIR types are fully
/// qualified with a leading `::`, as in the other hook templates. Otherwise the
/// declaration would resolve to the wrong types if the dialect's own C++
/// namespace contains a nested `mlir` namespace.
static const char *const operationInterfaceFallbackDecl = R"(
  /// Provides a hook for op interface.
  void *getRegisteredInterfaceForOp(::mlir::TypeID interfaceID,
                                    ::mlir::OperationName opName) override;
)";
202
/// The code block for the discardable attribute helper.
///
/// {0}: The attribute name in UpperCamelCase (helper class / getter suffix).
/// {1}: The attribute name as written in the dialect definition.
/// {2}: The C++ type of the attribute.
/// {3}: The attribute name in lowerCamelCase (member name prefix).
/// {4}: The dialect name, used as the prefix of the attribute's full name.
///
/// This is a formatv template, so every literal opening brace is written as
/// `{{`, including the one opening the `get{0}AttrHelper()` body.
static const char *const discardableAttrHelperDecl = R"(
  /// Helper to manage the discardable attribute `{1}`.
  class {0}AttrHelper {{
    ::mlir::StringAttr name;
  public:
    static constexpr ::llvm::StringLiteral getNameStr() {{
      return "{4}.{1}";
    }
    constexpr ::mlir::StringAttr getName() {{
      return name;
    }

    {0}AttrHelper(::mlir::MLIRContext *ctx)
      : name(::mlir::StringAttr::get(ctx, getNameStr())) {{}

    {2} getAttr(::mlir::Operation *op) {{
      return op->getAttrOfType<{2}>(name);
    }
    void setAttr(::mlir::Operation *op, {2} val) {{
      op->setAttr(name, val);
    }
    bool isAttrPresent(::mlir::Operation *op) {{
      return op->hasAttrOfType<{2}>(name);
    }
    void removeAttr(::mlir::Operation *op) {{
      assert(op->hasAttrOfType<{2}>(name));
      op->removeAttr(name);
    }
  };
  {0}AttrHelper get{0}AttrHelper() {{
    return {3}AttrName;
  }
private:
  {0}AttrHelper {3}AttrName;
public:
)";
240
241/// Generate the declaration for the given dialect class.
242static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
243 // Emit all nested namespaces.
244 {
245 NamespaceEmitter nsEmitter(os, dialect);
246
247 // Emit the start of the decl.
248 std::string cppName = dialect.getCppClassName();
249 StringRef superClassName =
250 dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
251
252 std::string comments = tblgen::emitSummaryAndDescComments(
253 summary: dialect.getSummary(), description: dialect.getDescription());
254 os << llvm::formatv(Fmt: dialectDeclBeginStr, Vals&: cppName, Vals: dialect.getName(),
255 Vals&: superClassName, Vals&: comments);
256
257 // If the dialect requested the default attribute printer and parser, emit
258 // the declarations for the hooks.
259 if (dialect.useDefaultAttributePrinterParser())
260 os << attrParserDecl;
261 // If the dialect requested the default type printer and parser, emit the
262 // delcarations for the hooks.
263 if (dialect.useDefaultTypePrinterParser())
264 os << typeParserDecl;
265
266 // Add the decls for the various features of the dialect.
267 if (dialect.hasCanonicalizer())
268 os << canonicalizerDecl;
269 if (dialect.hasConstantMaterializer())
270 os << constantMaterializerDecl;
271 if (dialect.hasOperationAttrVerify())
272 os << opAttrVerifierDecl;
273 if (dialect.hasRegionArgAttrVerify())
274 os << regionArgAttrVerifierDecl;
275 if (dialect.hasRegionResultAttrVerify())
276 os << regionResultAttrVerifierDecl;
277 if (dialect.hasOperationInterfaceFallback())
278 os << operationInterfaceFallbackDecl;
279
280 const llvm::DagInit *discardableAttrDag =
281 dialect.getDiscardableAttributes();
282 SmallVector<std::pair<std::string, std::string>> discardableAttributes;
283 populateDiscardableAttributes(dialect, discardableAttrDag,
284 discardableAttributes);
285
286 for (const auto &attrPair : discardableAttributes) {
287 std::string camelNameUpper = llvm::convertToCamelFromSnakeCase(
288 input: attrPair.first, /*capitalizeFirst=*/true);
289 std::string camelName = llvm::convertToCamelFromSnakeCase(
290 input: attrPair.first, /*capitalizeFirst=*/false);
291 os << llvm::formatv(Fmt: discardableAttrHelperDecl, Vals&: camelNameUpper,
292 Vals: attrPair.first, Vals: attrPair.second, Vals&: camelName,
293 Vals: dialect.getName());
294 }
295
296 if (std::optional<StringRef> extraDecl = dialect.getExtraClassDeclaration())
297 os << *extraDecl;
298
299 // End the dialect decl.
300 os << "};\n";
301 }
302 if (!dialect.getCppNamespace().empty())
303 os << "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace()
304 << "::" << dialect.getCppClassName() << ")\n";
305}
306
307static bool emitDialectDecls(const RecordKeeper &records, raw_ostream &os) {
308 emitSourceFileHeader(Desc: "Dialect Declarations", OS&: os, Record: records);
309
310 auto dialectDefs = records.getAllDerivedDefinitions(ClassName: "Dialect");
311 if (dialectDefs.empty())
312 return false;
313
314 SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
315 std::optional<Dialect> dialect = findDialectToGenerate(dialects);
316 if (!dialect)
317 return true;
318 emitDialectDecl(dialect&: *dialect, os);
319 return false;
320}
321
322//===----------------------------------------------------------------------===//
323// GEN: Dialect definitions
324//===----------------------------------------------------------------------===//
325
/// The code block to generate a dialect constructor definition.
///
/// {0}: The name of the dialect class.
/// {1}: Initialization code that is emitted in the ctor body before calling
///      initialize(), such as dependent dialect registration.
/// {2}: The dialect parent class.
/// {3}: Extra members to initialize. When non-empty it is expected to start
///      with ", " so it continues the base-class initializer list.
///
/// This is a formatv template; the literal opening brace of the body is
/// escaped as `{{`.
static const char *const dialectConstructorStr = R"(
{0}::{0}(::mlir::MLIRContext *context)
    : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>())
    {3}
{{
  {1}
  initialize();
}
)";
342
/// The code block to generate a default destructor definition.
///
/// {0}: The name of the dialect class.
///
/// Only emitted when the dialect does not declare a non-default destructor
/// (Dialect::hasNonDefaultDestructor() is false).
static const char *const dialectDestructorStr = R"(
{0}::~{0}() = default;

)";
350
351static void emitDialectDef(Dialect &dialect, const RecordKeeper &records,
352 raw_ostream &os) {
353 std::string cppClassName = dialect.getCppClassName();
354
355 // Emit the TypeID explicit specializations to have a single symbol def.
356 if (!dialect.getCppNamespace().empty())
357 os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace()
358 << "::" << cppClassName << ")\n";
359
360 // Emit all nested namespaces.
361 NamespaceEmitter nsEmitter(os, dialect);
362
363 /// Build the list of dependent dialects.
364 std::string dependentDialectRegistrations;
365 {
366 llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
367 llvm::interleave(
368 c: dialect.getDependentDialects(), os&: dialectsOs,
369 each_fn: [&](StringRef dependentDialect) {
370 dialectsOs << llvm::formatv(Fmt: dialectRegistrationTemplate,
371 Vals&: dependentDialect);
372 },
373 separator: "\n ");
374 }
375
376 // Emit the constructor and destructor.
377 StringRef superClassName =
378 dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
379
380 const llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
381 SmallVector<std::pair<std::string, std::string>> discardableAttributes;
382 populateDiscardableAttributes(dialect, discardableAttrDag,
383 discardableAttributes);
384 std::string discardableAttributesInit;
385 for (const auto &attrPair : discardableAttributes) {
386 std::string camelName = llvm::convertToCamelFromSnakeCase(
387 input: attrPair.first, /*capitalizeFirst=*/false);
388 llvm::raw_string_ostream os(discardableAttributesInit);
389 os << ", " << camelName << "AttrName(context)";
390 }
391
392 os << llvm::formatv(Fmt: dialectConstructorStr, Vals&: cppClassName,
393 Vals&: dependentDialectRegistrations, Vals&: superClassName,
394 Vals&: discardableAttributesInit);
395 if (!dialect.hasNonDefaultDestructor())
396 os << llvm::formatv(Fmt: dialectDestructorStr, Vals&: cppClassName);
397}
398
399static bool emitDialectDefs(const RecordKeeper &records, raw_ostream &os) {
400 emitSourceFileHeader(Desc: "Dialect Definitions", OS&: os, Record: records);
401
402 auto dialectDefs = records.getAllDerivedDefinitions(ClassName: "Dialect");
403 if (dialectDefs.empty())
404 return false;
405
406 SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
407 std::optional<Dialect> dialect = findDialectToGenerate(dialects);
408 if (!dialect)
409 return true;
410 emitDialectDef(dialect&: *dialect, records, os);
411 return false;
412}
413
414//===----------------------------------------------------------------------===//
415// GEN: Dialect registration hooks
416//===----------------------------------------------------------------------===//
417
418static mlir::GenRegistration
419 genDialectDecls("gen-dialect-decls", "Generate dialect declarations",
420 [](const RecordKeeper &records, raw_ostream &os) {
421 return emitDialectDecls(records, os);
422 });
423
424static mlir::GenRegistration
425 genDialectDefs("gen-dialect-defs", "Generate dialect definitions",
426 [](const RecordKeeper &records, raw_ostream &os) {
427 return emitDialectDefs(records, os);
428 });
429

source code of mlir/tools/mlir-tblgen/DialectGen.cpp