1 | //===- DialectGen.cpp - MLIR dialect definitions generator ----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // DialectGen uses the description of dialects to generate C++ definitions. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "CppGenUtilities.h" |
14 | #include "DialectGenUtilities.h" |
15 | #include "mlir/TableGen/Class.h" |
16 | #include "mlir/TableGen/CodeGenHelpers.h" |
17 | #include "mlir/TableGen/Format.h" |
18 | #include "mlir/TableGen/GenInfo.h" |
19 | #include "mlir/TableGen/Interfaces.h" |
20 | #include "mlir/TableGen/Operator.h" |
21 | #include "mlir/TableGen/Trait.h" |
22 | #include "llvm/ADT/Sequence.h" |
23 | #include "llvm/ADT/StringExtras.h" |
24 | #include "llvm/Support/CommandLine.h" |
25 | #include "llvm/Support/Signals.h" |
26 | #include "llvm/TableGen/Error.h" |
27 | #include "llvm/TableGen/Record.h" |
28 | #include "llvm/TableGen/TableGenBackend.h" |
29 | |
// Debug category for this backend (`-debug-only=mlir-tblgen-dialectgen`).
#define DEBUG_TYPE "mlir-tblgen-dialectgen"
31 | |
32 | using namespace mlir; |
33 | using namespace mlir::tblgen; |
34 | using llvm::Record; |
35 | using llvm::RecordKeeper; |
36 | |
/// Command-line option category grouping the options of the `-gen-dialect-*`
/// generators.
static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*");
/// `-dialect=<name>`: selects which dialect to generate code for. Only needed
/// when the input defines more than one dialect (see findDialectToGenerate).
// NOTE(review): `CommaSeparated` on a scalar `cl::opt` looks unintended, since
// only a single dialect name is ever compared against; confirm before relying
// on comma-separated values here.
static llvm::cl::opt<std::string>
    selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
                    llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
41 | |
42 | /// Utility iterator used for filtering records for a specific dialect. |
43 | namespace { |
44 | using DialectFilterIterator = |
45 | llvm::filter_iterator<ArrayRef<Record *>::iterator, |
46 | std::function<bool(const Record *)>>; |
47 | } // namespace |
48 | |
49 | static void populateDiscardableAttributes( |
50 | Dialect &dialect, const llvm::DagInit *discardableAttrDag, |
51 | SmallVector<std::pair<std::string, std::string>> &discardableAttributes) { |
52 | for (int i : llvm::seq<int>(Begin: 0, End: discardableAttrDag->getNumArgs())) { |
53 | const llvm::Init *arg = discardableAttrDag->getArg(Num: i); |
54 | |
55 | StringRef givenName = discardableAttrDag->getArgNameStr(Num: i); |
56 | if (givenName.empty()) |
57 | PrintFatalError(ErrorLoc: dialect.getDef()->getLoc(), |
58 | Msg: "discardable attributes must be named" ); |
59 | discardableAttributes.push_back( |
60 | Elt: {givenName.str(), arg->getAsUnquotedString()}); |
61 | } |
62 | } |
63 | |
64 | /// Given a set of records for a T, filter the ones that correspond to |
65 | /// the given dialect. |
66 | template <typename T> |
67 | static iterator_range<DialectFilterIterator> |
68 | filterForDialect(ArrayRef<Record *> records, Dialect &dialect) { |
69 | auto filterFn = [&](const Record *record) { |
70 | return T(record).getDialect() == dialect; |
71 | }; |
72 | return {DialectFilterIterator(records.begin(), records.end(), filterFn), |
73 | DialectFilterIterator(records.end(), records.end(), filterFn)}; |
74 | } |
75 | |
76 | std::optional<Dialect> |
77 | tblgen::findDialectToGenerate(ArrayRef<Dialect> dialects) { |
78 | if (dialects.empty()) { |
79 | llvm::errs() << "no dialect was found\n" ; |
80 | return std::nullopt; |
81 | } |
82 | |
83 | // Select the dialect to gen for. |
84 | if (dialects.size() == 1 && selectedDialect.getNumOccurrences() == 0) |
85 | return dialects.front(); |
86 | |
87 | if (selectedDialect.getNumOccurrences() == 0) { |
88 | llvm::errs() << "when more than 1 dialect is present, one must be selected " |
89 | "via '-dialect'\n" ; |
90 | return std::nullopt; |
91 | } |
92 | |
93 | const auto *dialectIt = llvm::find_if(Range&: dialects, P: [](const Dialect &dialect) { |
94 | return dialect.getName() == selectedDialect; |
95 | }); |
96 | if (dialectIt == dialects.end()) { |
97 | llvm::errs() << "selected dialect with '-dialect' does not exist\n" ; |
98 | return std::nullopt; |
99 | } |
100 | return *dialectIt; |
101 | } |
102 | |
103 | //===----------------------------------------------------------------------===// |
104 | // GEN: Dialect declarations |
105 | //===----------------------------------------------------------------------===// |
106 | |
/// The code block for the start of a dialect class declaration. It is expanded
/// with llvm::formatv, so literal braces are escaped as `{{`.
///
/// {0}: The name of the dialect class.
/// {1}: The dialect namespace.
/// {2}: The dialect parent class.
/// {3}: The summary and description comments.
static const char *const dialectDeclBeginStr = R"(
{3}
class {0} : public ::mlir::{2} {{
  explicit {0}(::mlir::MLIRContext *context);

  void initialize();
  friend class ::mlir::MLIRContext;
public:
  ~{0}() override;
  static constexpr ::llvm::StringLiteral getDialectNamespace() {{
    return ::llvm::StringLiteral("{1}");
  }
)";
126 | |
/// Registration for a single dependent dialect. It is expanded with
/// llvm::formatv once per dependent dialect, and the results are spliced into
/// the `{1}` slot of the dialect constructor body (see dialectConstructorStr).
///
/// {0}: The C++ type of the dependent dialect, as listed among the dialect's
///      dependent dialects.
static const char *const dialectRegistrationTemplate =
    "getContext()->loadDialect<{0}>();";
131 | |
/// The code block for the attribute parser/printer hooks. Streamed verbatim
/// (not through formatv) into the dialect class when the dialect uses the
/// default attribute printer/parser.
static const char *const attrParserDecl = R"(
  /// Parse an attribute registered to this dialect.
  ::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser,
                                   ::mlir::Type type) const override;

  /// Print an attribute registered to this dialect.
  void printAttribute(::mlir::Attribute attr,
                      ::mlir::DialectAsmPrinter &os) const override;
)";
142 | |
/// The code block for the type parser/printer hooks. Streamed verbatim into the
/// dialect class when the dialect uses the default type printer/parser.
static const char *const typeParserDecl = R"(
  /// Parse a type registered to this dialect.
  ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;

  /// Print a type registered to this dialect.
  void printType(::mlir::Type type,
                 ::mlir::DialectAsmPrinter &os) const override;
)";
152 | |
/// The code block for the canonicalization pattern registration hook. Streamed
/// verbatim into the dialect class when the dialect has a canonicalizer.
static const char *const canonicalizerDecl = R"(
  /// Register canonicalization patterns.
  void getCanonicalizationPatterns(
      ::mlir::RewritePatternSet &results) const override;
)";
159 | |
/// The code block for the constant materializer hook. Streamed verbatim into
/// the dialect class when the dialect has a constant materializer.
static const char *const constantMaterializerDecl = R"(
  /// Materialize a single constant operation from a given attribute value with
  /// the desired resultant type.
  ::mlir::Operation *materializeConstant(::mlir::OpBuilder &builder,
                                         ::mlir::Attribute value,
                                         ::mlir::Type type,
                                         ::mlir::Location loc) override;
)";
169 | |
/// The code block for the operation attribute verifier hook. Streamed verbatim
/// into the dialect class when the dialect requests operation attribute
/// verification.
static const char *const opAttrVerifierDecl = R"(
  /// Provides a hook for verifying dialect attributes attached to the given
  /// op.
  ::llvm::LogicalResult verifyOperationAttribute(
      ::mlir::Operation *op, ::mlir::NamedAttribute attribute) override;
)";
177 | |
/// The code block for the region argument attribute verifier hook. Streamed
/// verbatim into the dialect class when the dialect requests region argument
/// attribute verification.
static const char *const regionArgAttrVerifierDecl = R"(
  /// Provides a hook for verifying dialect attributes attached to the given
  /// op's region argument.
  ::llvm::LogicalResult verifyRegionArgAttribute(
      ::mlir::Operation *op, unsigned regionIndex, unsigned argIndex,
      ::mlir::NamedAttribute attribute) override;
)";
186 | |
/// The code block for the region result attribute verifier hook. Streamed
/// verbatim into the dialect class when the dialect requests region result
/// attribute verification.
static const char *const regionResultAttrVerifierDecl = R"(
  /// Provides a hook for verifying dialect attributes attached to the given
  /// op's region result.
  ::llvm::LogicalResult verifyRegionResultAttribute(
      ::mlir::Operation *op, unsigned regionIndex, unsigned resultIndex,
      ::mlir::NamedAttribute attribute) override;
)";
195 | |
/// The code block for the op interface fallback hook. Streamed verbatim into
/// the dialect class when the dialect has an operation interface fallback.
///
/// Types are spelled with a leading `::` like the other emitted hooks, so the
/// declaration is unaffected by any nested `mlir` namespace inside the
/// dialect's C++ namespace.
static const char *const operationInterfaceFallbackDecl = R"(
  /// Provides a hook for op interface.
  void *getRegisteredInterfaceForOp(::mlir::TypeID interfaceID,
                                    ::mlir::OperationName opName) override;
)";
202 | |
/// The code block for the discardable attribute helper. It is expanded with
/// llvm::formatv once per discardable attribute, so literal braces are escaped
/// as `{{`.
///
/// {0}: The attribute name converted to UpperCamelCase (helper class prefix).
/// {1}: The attribute name as declared among the discardable attributes.
/// {2}: The attribute's type spelling, used as a C++ type in the helper.
/// {3}: The attribute name converted to lowerCamelCase (member prefix).
/// {4}: The dialect name, prepended to form the full attribute name.
static const char *const discardableAttrHelperDecl = R"(
  /// Helper to manage the discardable attribute `{1}`.
  class {0}AttrHelper {{
    ::mlir::StringAttr name;
  public:
    static constexpr ::llvm::StringLiteral getNameStr() {{
      return "{4}.{1}";
    }
    constexpr ::mlir::StringAttr getName() {{
      return name;
    }

    {0}AttrHelper(::mlir::MLIRContext *ctx)
      : name(::mlir::StringAttr::get(ctx, getNameStr())) {{}

    {2} getAttr(::mlir::Operation *op) {{
      return op->getAttrOfType<{2}>(name);
    }
    void setAttr(::mlir::Operation *op, {2} val) {{
      op->setAttr(name, val);
    }
    bool isAttrPresent(::mlir::Operation *op) {{
      return op->hasAttrOfType<{2}>(name);
    }
    void removeAttr(::mlir::Operation *op) {{
      assert(op->hasAttrOfType<{2}>(name));
      op->removeAttr(name);
    }
  };
  {0}AttrHelper get{0}AttrHelper() {{
    return {3}AttrName;
  }
private:
  {0}AttrHelper {3}AttrName;
public:
)";
240 | |
241 | /// Generate the declaration for the given dialect class. |
242 | static void emitDialectDecl(Dialect &dialect, raw_ostream &os) { |
243 | // Emit all nested namespaces. |
244 | { |
245 | NamespaceEmitter nsEmitter(os, dialect); |
246 | |
247 | // Emit the start of the decl. |
248 | std::string cppName = dialect.getCppClassName(); |
249 | StringRef superClassName = |
250 | dialect.isExtensible() ? "ExtensibleDialect" : "Dialect" ; |
251 | |
252 | std::string = tblgen::emitSummaryAndDescComments( |
253 | summary: dialect.getSummary(), description: dialect.getDescription()); |
254 | os << llvm::formatv(Fmt: dialectDeclBeginStr, Vals&: cppName, Vals: dialect.getName(), |
255 | Vals&: superClassName, Vals&: comments); |
256 | |
257 | // If the dialect requested the default attribute printer and parser, emit |
258 | // the declarations for the hooks. |
259 | if (dialect.useDefaultAttributePrinterParser()) |
260 | os << attrParserDecl; |
261 | // If the dialect requested the default type printer and parser, emit the |
262 | // delcarations for the hooks. |
263 | if (dialect.useDefaultTypePrinterParser()) |
264 | os << typeParserDecl; |
265 | |
266 | // Add the decls for the various features of the dialect. |
267 | if (dialect.hasCanonicalizer()) |
268 | os << canonicalizerDecl; |
269 | if (dialect.hasConstantMaterializer()) |
270 | os << constantMaterializerDecl; |
271 | if (dialect.hasOperationAttrVerify()) |
272 | os << opAttrVerifierDecl; |
273 | if (dialect.hasRegionArgAttrVerify()) |
274 | os << regionArgAttrVerifierDecl; |
275 | if (dialect.hasRegionResultAttrVerify()) |
276 | os << regionResultAttrVerifierDecl; |
277 | if (dialect.hasOperationInterfaceFallback()) |
278 | os << operationInterfaceFallbackDecl; |
279 | |
280 | const llvm::DagInit *discardableAttrDag = |
281 | dialect.getDiscardableAttributes(); |
282 | SmallVector<std::pair<std::string, std::string>> discardableAttributes; |
283 | populateDiscardableAttributes(dialect, discardableAttrDag, |
284 | discardableAttributes); |
285 | |
286 | for (const auto &attrPair : discardableAttributes) { |
287 | std::string camelNameUpper = llvm::convertToCamelFromSnakeCase( |
288 | input: attrPair.first, /*capitalizeFirst=*/true); |
289 | std::string camelName = llvm::convertToCamelFromSnakeCase( |
290 | input: attrPair.first, /*capitalizeFirst=*/false); |
291 | os << llvm::formatv(Fmt: discardableAttrHelperDecl, Vals&: camelNameUpper, |
292 | Vals: attrPair.first, Vals: attrPair.second, Vals&: camelName, |
293 | Vals: dialect.getName()); |
294 | } |
295 | |
296 | if (std::optional<StringRef> = dialect.getExtraClassDeclaration()) |
297 | os << *extraDecl; |
298 | |
299 | // End the dialect decl. |
300 | os << "};\n" ; |
301 | } |
302 | if (!dialect.getCppNamespace().empty()) |
303 | os << "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace() |
304 | << "::" << dialect.getCppClassName() << ")\n" ; |
305 | } |
306 | |
307 | static bool emitDialectDecls(const RecordKeeper &records, raw_ostream &os) { |
308 | emitSourceFileHeader(Desc: "Dialect Declarations" , OS&: os, Record: records); |
309 | |
310 | auto dialectDefs = records.getAllDerivedDefinitions(ClassName: "Dialect" ); |
311 | if (dialectDefs.empty()) |
312 | return false; |
313 | |
314 | SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end()); |
315 | std::optional<Dialect> dialect = findDialectToGenerate(dialects); |
316 | if (!dialect) |
317 | return true; |
318 | emitDialectDecl(dialect&: *dialect, os); |
319 | return false; |
320 | } |
321 | |
322 | //===----------------------------------------------------------------------===// |
323 | // GEN: Dialect definitions |
324 | //===----------------------------------------------------------------------===// |
325 | |
/// The code block to generate a dialect constructor definition. It is expanded
/// with llvm::formatv, so the literal opening brace of the body is `{{`.
///
/// {0}: The name of the dialect class.
/// {1}: Initialization code that is emitted in the ctor body before calling
///      initialize(), such as dependent dialect registration.
/// {2}: The dialect parent class.
/// {3}: Extra member initializers, each already prefixed with ", " so they
///      extend the initializer list (e.g. the discardable attribute helpers).
static const char *const dialectConstructorStr = R"(
{0}::{0}(::mlir::MLIRContext *context)
    : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>())
    {3}
{{
  {1}
  initialize();
}
)";
342 | |
/// The code block to generate a default destructor definition. Only emitted
/// when the dialect does not request a non-default destructor.
///
/// {0}: The name of the dialect class.
static const char *const dialectDestructorStr = R"(
{0}::~{0}() = default;

)";
350 | |
351 | static void emitDialectDef(Dialect &dialect, const RecordKeeper &records, |
352 | raw_ostream &os) { |
353 | std::string cppClassName = dialect.getCppClassName(); |
354 | |
355 | // Emit the TypeID explicit specializations to have a single symbol def. |
356 | if (!dialect.getCppNamespace().empty()) |
357 | os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace() |
358 | << "::" << cppClassName << ")\n" ; |
359 | |
360 | // Emit all nested namespaces. |
361 | NamespaceEmitter nsEmitter(os, dialect); |
362 | |
363 | /// Build the list of dependent dialects. |
364 | std::string dependentDialectRegistrations; |
365 | { |
366 | llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations); |
367 | llvm::interleave( |
368 | c: dialect.getDependentDialects(), os&: dialectsOs, |
369 | each_fn: [&](StringRef dependentDialect) { |
370 | dialectsOs << llvm::formatv(Fmt: dialectRegistrationTemplate, |
371 | Vals&: dependentDialect); |
372 | }, |
373 | separator: "\n " ); |
374 | } |
375 | |
376 | // Emit the constructor and destructor. |
377 | StringRef superClassName = |
378 | dialect.isExtensible() ? "ExtensibleDialect" : "Dialect" ; |
379 | |
380 | const llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes(); |
381 | SmallVector<std::pair<std::string, std::string>> discardableAttributes; |
382 | populateDiscardableAttributes(dialect, discardableAttrDag, |
383 | discardableAttributes); |
384 | std::string discardableAttributesInit; |
385 | for (const auto &attrPair : discardableAttributes) { |
386 | std::string camelName = llvm::convertToCamelFromSnakeCase( |
387 | input: attrPair.first, /*capitalizeFirst=*/false); |
388 | llvm::raw_string_ostream os(discardableAttributesInit); |
389 | os << ", " << camelName << "AttrName(context)" ; |
390 | } |
391 | |
392 | os << llvm::formatv(Fmt: dialectConstructorStr, Vals&: cppClassName, |
393 | Vals&: dependentDialectRegistrations, Vals&: superClassName, |
394 | Vals&: discardableAttributesInit); |
395 | if (!dialect.hasNonDefaultDestructor()) |
396 | os << llvm::formatv(Fmt: dialectDestructorStr, Vals&: cppClassName); |
397 | } |
398 | |
399 | static bool emitDialectDefs(const RecordKeeper &records, raw_ostream &os) { |
400 | emitSourceFileHeader(Desc: "Dialect Definitions" , OS&: os, Record: records); |
401 | |
402 | auto dialectDefs = records.getAllDerivedDefinitions(ClassName: "Dialect" ); |
403 | if (dialectDefs.empty()) |
404 | return false; |
405 | |
406 | SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end()); |
407 | std::optional<Dialect> dialect = findDialectToGenerate(dialects); |
408 | if (!dialect) |
409 | return true; |
410 | emitDialectDef(dialect&: *dialect, records, os); |
411 | return false; |
412 | } |
413 | |
414 | //===----------------------------------------------------------------------===// |
415 | // GEN: Dialect registration hooks |
416 | //===----------------------------------------------------------------------===// |
417 | |
418 | static mlir::GenRegistration |
419 | genDialectDecls("gen-dialect-decls" , "Generate dialect declarations" , |
420 | [](const RecordKeeper &records, raw_ostream &os) { |
421 | return emitDialectDecls(records, os); |
422 | }); |
423 | |
424 | static mlir::GenRegistration |
425 | genDialectDefs("gen-dialect-defs" , "Generate dialect definitions" , |
426 | [](const RecordKeeper &records, raw_ostream &os) { |
427 | return emitDialectDefs(records, os); |
428 | }); |
429 | |