1 | //===- DialectGen.cpp - MLIR dialect definitions generator ----------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // DialectGen uses the description of dialects to generate C++ definitions. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "DialectGenUtilities.h" |
14 | #include "mlir/TableGen/Class.h" |
15 | #include "mlir/TableGen/CodeGenHelpers.h" |
16 | #include "mlir/TableGen/Format.h" |
17 | #include "mlir/TableGen/GenInfo.h" |
18 | #include "mlir/TableGen/Interfaces.h" |
19 | #include "mlir/TableGen/Operator.h" |
20 | #include "mlir/TableGen/Trait.h" |
21 | #include "llvm/ADT/Sequence.h" |
22 | #include "llvm/ADT/StringExtras.h" |
23 | #include "llvm/Support/CommandLine.h" |
24 | #include "llvm/Support/Signals.h" |
25 | #include "llvm/TableGen/Error.h" |
26 | #include "llvm/TableGen/Record.h" |
27 | #include "llvm/TableGen/TableGenBackend.h" |
28 | |
29 | #define DEBUG_TYPE "mlir-tblgen-opdefgen" |
30 | |
31 | using namespace mlir; |
32 | using namespace mlir::tblgen; |
33 | |
34 | static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*"); |
35 | llvm::cl::opt<std::string> |
36 | selectedDialect("dialect", llvm::cl::desc( "The dialect to gen for"), |
37 | llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated); |
38 | |
39 | /// Utility iterator used for filtering records for a specific dialect. |
40 | namespace { |
41 | using DialectFilterIterator = |
42 | llvm::filter_iterator<ArrayRef<llvm::Record *>::iterator, |
43 | std::function<bool(const llvm::Record *)>>; |
44 | } // namespace |
45 | |
46 | static void populateDiscardableAttributes( |
47 | Dialect &dialect, llvm::DagInit *discardableAttrDag, |
48 | SmallVector<std::pair<std::string, std::string>> &discardableAttributes) { |
49 | for (int i : llvm::seq<int>(Begin: 0, End: discardableAttrDag->getNumArgs())) { |
50 | llvm::Init *arg = discardableAttrDag->getArg(Num: i); |
51 | |
52 | StringRef givenName = discardableAttrDag->getArgNameStr(Num: i); |
53 | if (givenName.empty()) |
54 | PrintFatalError(ErrorLoc: dialect.getDef()->getLoc(), |
55 | Msg: "discardable attributes must be named"); |
56 | discardableAttributes.push_back( |
57 | Elt: {givenName.str(), arg->getAsUnquotedString()}); |
58 | } |
59 | } |
60 | |
61 | /// Given a set of records for a T, filter the ones that correspond to |
62 | /// the given dialect. |
63 | template <typename T> |
64 | static iterator_range<DialectFilterIterator> |
65 | filterForDialect(ArrayRef<llvm::Record *> records, Dialect &dialect) { |
66 | auto filterFn = [&](const llvm::Record *record) { |
67 | return T(record).getDialect() == dialect; |
68 | }; |
69 | return {DialectFilterIterator(records.begin(), records.end(), filterFn), |
70 | DialectFilterIterator(records.end(), records.end(), filterFn)}; |
71 | } |
72 | |
73 | std::optional<Dialect> |
74 | tblgen::findDialectToGenerate(ArrayRef<Dialect> dialects) { |
75 | if (dialects.empty()) { |
76 | llvm::errs() << "no dialect was found\n"; |
77 | return std::nullopt; |
78 | } |
79 | |
80 | // Select the dialect to gen for. |
81 | if (dialects.size() == 1 && selectedDialect.getNumOccurrences() == 0) |
82 | return dialects.front(); |
83 | |
84 | if (selectedDialect.getNumOccurrences() == 0) { |
85 | llvm::errs() << "when more than 1 dialect is present, one must be selected " |
86 | "via '-dialect'\n"; |
87 | return std::nullopt; |
88 | } |
89 | |
90 | const auto *dialectIt = llvm::find_if(Range&: dialects, P: [](const Dialect &dialect) { |
91 | return dialect.getName() == selectedDialect; |
92 | }); |
93 | if (dialectIt == dialects.end()) { |
94 | llvm::errs() << "selected dialect with '-dialect' does not exist\n"; |
95 | return std::nullopt; |
96 | } |
97 | return *dialectIt; |
98 | } |
99 | |
100 | //===----------------------------------------------------------------------===// |
101 | // GEN: Dialect declarations |
102 | //===----------------------------------------------------------------------===// |
103 | |
/// The code block for the start of a dialect class declaration.
///
/// This is an `llvm::formatv` template. Literal opening braces are written as
/// `{{` so that they can never be parsed as the start of a replacement field;
/// closing braces need no escaping.
///
/// {0}: The name of the dialect class.
/// {1}: The dialect namespace.
/// {2}: The dialect parent class.
static const char *const dialectDeclBeginStr = R"(
class {0} : public ::mlir::{2} {{
  explicit {0}(::mlir::MLIRContext *context);

  void initialize();
  friend class ::mlir::MLIRContext;
public:
  ~{0}() override;
  static constexpr ::llvm::StringLiteral getDialectNamespace() {{
    return ::llvm::StringLiteral("{1}");
  }
)";
121 | |
/// `llvm::formatv` template of the statement that loads one dependent
/// dialect; one instance per dependent dialect is placed in the body of the
/// generated dialect constructor.
///
/// {0}: The dependent dialect, as listed by the dialect definition.
static const char *const dialectRegistrationTemplate =
    "getContext()->loadDialect<{0}>();";
126 | |
/// Declarations of the `parseAttribute` / `printAttribute` overrides, emitted
/// verbatim into the dialect class when the dialect uses the default
/// attribute printer and parser.
static const char *const attrParserDecl = R"(
  /// Parse an attribute registered to this dialect.
  ::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser,
                                   ::mlir::Type type) const override;

  /// Print an attribute registered to this dialect.
  void printAttribute(::mlir::Attribute attr,
                      ::mlir::DialectAsmPrinter &os) const override;
)";
137 | |
/// Declarations of the `parseType` / `printType` overrides, emitted verbatim
/// into the dialect class when the dialect uses the default type printer and
/// parser.
static const char *const typeParserDecl = R"(
  /// Parse a type registered to this dialect.
  ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;

  /// Print a type registered to this dialect.
  void printType(::mlir::Type type,
                 ::mlir::DialectAsmPrinter &os) const override;
)";
147 | |
/// Declaration of the `getCanonicalizationPatterns` override, emitted
/// verbatim into the dialect class when the dialect has a canonicalizer.
static const char *const canonicalizerDecl = R"(
  /// Register canonicalization patterns.
  void getCanonicalizationPatterns(
      ::mlir::RewritePatternSet &results) const override;
)";
154 | |
/// Declaration of the `materializeConstant` override, emitted verbatim into
/// the dialect class when the dialect has a constant materializer.
static const char *const constantMaterializerDecl = R"(
  /// Materialize a single constant operation from a given attribute value with
  /// the desired resultant type.
  ::mlir::Operation *materializeConstant(::mlir::OpBuilder &builder,
                                         ::mlir::Attribute value,
                                         ::mlir::Type type,
                                         ::mlir::Location loc) override;
)";
164 | |
/// Declaration of the `verifyOperationAttribute` override, emitted verbatim
/// into the dialect class when the dialect verifies operation attributes.
static const char *const opAttrVerifierDecl = R"(
    /// Provides a hook for verifying dialect attributes attached to the given
    /// op.
    ::mlir::LogicalResult verifyOperationAttribute(
        ::mlir::Operation *op, ::mlir::NamedAttribute attribute) override;
)";
172 | |
/// Declaration of the `verifyRegionArgAttribute` override, emitted verbatim
/// into the dialect class when the dialect verifies region argument
/// attributes.
static const char *const regionArgAttrVerifierDecl = R"(
    /// Provides a hook for verifying dialect attributes attached to the given
    /// op's region argument.
    ::mlir::LogicalResult verifyRegionArgAttribute(
        ::mlir::Operation *op, unsigned regionIndex, unsigned argIndex,
        ::mlir::NamedAttribute attribute) override;
)";
181 | |
/// Declaration of the `verifyRegionResultAttribute` override, emitted
/// verbatim into the dialect class when the dialect verifies region result
/// attributes.
static const char *const regionResultAttrVerifierDecl = R"(
    /// Provides a hook for verifying dialect attributes attached to the given
    /// op's region result.
    ::mlir::LogicalResult verifyRegionResultAttribute(
        ::mlir::Operation *op, unsigned regionIndex, unsigned resultIndex,
        ::mlir::NamedAttribute attribute) override;
)";
190 | |
/// The code block for the op interface fallback hook, emitted verbatim into
/// the dialect class when the dialect has an operation interface fallback.
///
/// The MLIR types are spelled with a leading `::`, like in every other
/// generated hook, so that lookup cannot pick up a `mlir` namespace nested
/// inside the namespace of the generated dialect.
static const char *const operationInterfaceFallbackDecl = R"(
    /// Provides a hook for op interface.
    void *getRegisteredInterfaceForOp(::mlir::TypeID interfaceID,
                                      ::mlir::OperationName opName) override;
)";
197 | |
/// The code block for the discardable attribute helper.
///
/// This is an `llvm::formatv` template instantiated once per discardable
/// attribute. Every literal opening brace is written as `{{` so that it can
/// never be parsed as the start of a replacement field; closing braces need no
/// escaping.
///
/// {0}: The attribute name in camel case with a capitalized first letter.
/// {1}: The attribute name as given in the dialect definition.
/// {2}: The unquoted string form of the attribute's DAG argument, used as the
///      C++ attribute type of the accessors.
/// {3}: The attribute name in camel case with a lower-case first letter.
/// {4}: The dialect name.
static const char *const discardableAttrHelperDecl = R"(
    /// Helper to manage the discardable attribute `{1}`.
    class {0}AttrHelper {{
      ::mlir::StringAttr name;
    public:
      static constexpr ::llvm::StringLiteral getNameStr() {{
        return "{4}.{1}";
      }
      constexpr ::mlir::StringAttr getName() {{
        return name;
      }

      {0}AttrHelper(::mlir::MLIRContext *ctx)
        : name(::mlir::StringAttr::get(ctx, getNameStr())) {{}

      {2} getAttr(::mlir::Operation *op) {{
        return op->getAttrOfType<{2}>(name);
      }
      void setAttr(::mlir::Operation *op, {2} val) {{
        op->setAttr(name, val);
      }
      bool isAttrPresent(::mlir::Operation *op) {{
        return op->hasAttrOfType<{2}>(name);
      }
      void removeAttr(::mlir::Operation *op) {{
        assert(op->hasAttrOfType<{2}>(name));
        op->removeAttr(name);
      }
    };
    {0}AttrHelper get{0}AttrHelper() {{
      return {3}AttrName;
    }
  private:
    {0}AttrHelper {3}AttrName;
  public:
)";
235 | |
236 | /// Generate the declaration for the given dialect class. |
237 | static void emitDialectDecl(Dialect &dialect, raw_ostream &os) { |
238 | // Emit all nested namespaces. |
239 | { |
240 | NamespaceEmitter nsEmitter(os, dialect); |
241 | |
242 | // Emit the start of the decl. |
243 | std::string cppName = dialect.getCppClassName(); |
244 | StringRef superClassName = |
245 | dialect.isExtensible() ? "ExtensibleDialect": "Dialect"; |
246 | os << llvm::formatv(Fmt: dialectDeclBeginStr, Vals&: cppName, Vals: dialect.getName(), |
247 | Vals&: superClassName); |
248 | |
249 | // If the dialect requested the default attribute printer and parser, emit |
250 | // the declarations for the hooks. |
251 | if (dialect.useDefaultAttributePrinterParser()) |
252 | os << attrParserDecl; |
253 | // If the dialect requested the default type printer and parser, emit the |
254 | // delcarations for the hooks. |
255 | if (dialect.useDefaultTypePrinterParser()) |
256 | os << typeParserDecl; |
257 | |
258 | // Add the decls for the various features of the dialect. |
259 | if (dialect.hasCanonicalizer()) |
260 | os << canonicalizerDecl; |
261 | if (dialect.hasConstantMaterializer()) |
262 | os << constantMaterializerDecl; |
263 | if (dialect.hasOperationAttrVerify()) |
264 | os << opAttrVerifierDecl; |
265 | if (dialect.hasRegionArgAttrVerify()) |
266 | os << regionArgAttrVerifierDecl; |
267 | if (dialect.hasRegionResultAttrVerify()) |
268 | os << regionResultAttrVerifierDecl; |
269 | if (dialect.hasOperationInterfaceFallback()) |
270 | os << operationInterfaceFallbackDecl; |
271 | |
272 | llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes(); |
273 | SmallVector<std::pair<std::string, std::string>> discardableAttributes; |
274 | populateDiscardableAttributes(dialect, discardableAttrDag, |
275 | discardableAttributes); |
276 | |
277 | for (const auto &attrPair : discardableAttributes) { |
278 | std::string camelNameUpper = llvm::convertToCamelFromSnakeCase( |
279 | input: attrPair.first, /*capitalizeFirst=*/true); |
280 | std::string camelName = llvm::convertToCamelFromSnakeCase( |
281 | input: attrPair.first, /*capitalizeFirst=*/false); |
282 | os << llvm::formatv(Fmt: discardableAttrHelperDecl, Vals&: camelNameUpper, |
283 | Vals: attrPair.first, Vals: attrPair.second, Vals&: camelName, |
284 | Vals: dialect.getName()); |
285 | } |
286 | |
287 | if (std::optional<StringRef> extraDecl = dialect.getExtraClassDeclaration()) |
288 | os << *extraDecl; |
289 | |
290 | // End the dialect decl. |
291 | os << "};\n"; |
292 | } |
293 | if (!dialect.getCppNamespace().empty()) |
294 | os << "MLIR_DECLARE_EXPLICIT_TYPE_ID("<< dialect.getCppNamespace() |
295 | << "::"<< dialect.getCppClassName() << ")\n"; |
296 | } |
297 | |
298 | static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper, |
299 | raw_ostream &os) { |
300 | emitSourceFileHeader(Desc: "Dialect Declarations", OS&: os, Record: recordKeeper); |
301 | |
302 | auto dialectDefs = recordKeeper.getAllDerivedDefinitions(ClassName: "Dialect"); |
303 | if (dialectDefs.empty()) |
304 | return false; |
305 | |
306 | SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end()); |
307 | std::optional<Dialect> dialect = findDialectToGenerate(dialects); |
308 | if (!dialect) |
309 | return true; |
310 | emitDialectDecl(dialect&: *dialect, os); |
311 | return false; |
312 | } |
313 | |
314 | //===----------------------------------------------------------------------===// |
315 | // GEN: Dialect definitions |
316 | //===----------------------------------------------------------------------===// |
317 | |
/// `llvm::formatv` template of the dialect constructor definition.
///
/// {0}: The name of the dialect class.
/// {1}: Statements placed in the constructor body ahead of the call to
///      initialize(), such as the loading of dependent dialects.
/// {2}: The dialect parent class.
/// {3}: Additional entries of the member initializer list.
static const char *const dialectConstructorStr = R"(
{0}::{0}(::mlir::MLIRContext *context)
    : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>())
    {3}
    {{
  {1}
  initialize();
}
)";
334 | |
/// `llvm::formatv` template of a defaulted dialect destructor definition.
///
/// {0}: The name of the dialect class.
static const char *const dialectDestructorStr = R"(
{0}::~{0}() = default;

)";
342 | |
343 | static void emitDialectDef(Dialect &dialect, |
344 | const llvm::RecordKeeper &recordKeeper, |
345 | raw_ostream &os) { |
346 | std::string cppClassName = dialect.getCppClassName(); |
347 | |
348 | // Emit the TypeID explicit specializations to have a single symbol def. |
349 | if (!dialect.getCppNamespace().empty()) |
350 | os << "MLIR_DEFINE_EXPLICIT_TYPE_ID("<< dialect.getCppNamespace() |
351 | << "::"<< cppClassName << ")\n"; |
352 | |
353 | // Emit all nested namespaces. |
354 | NamespaceEmitter nsEmitter(os, dialect); |
355 | |
356 | /// Build the list of dependent dialects. |
357 | std::string dependentDialectRegistrations; |
358 | { |
359 | llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations); |
360 | llvm::interleave( |
361 | c: dialect.getDependentDialects(), os&: dialectsOs, |
362 | each_fn: [&](StringRef dependentDialect) { |
363 | dialectsOs << llvm::formatv(Fmt: dialectRegistrationTemplate, |
364 | Vals&: dependentDialect); |
365 | }, |
366 | separator: "\n "); |
367 | } |
368 | |
369 | // Emit the constructor and destructor. |
370 | StringRef superClassName = |
371 | dialect.isExtensible() ? "ExtensibleDialect": "Dialect"; |
372 | |
373 | llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes(); |
374 | SmallVector<std::pair<std::string, std::string>> discardableAttributes; |
375 | populateDiscardableAttributes(dialect, discardableAttrDag, |
376 | discardableAttributes); |
377 | std::string discardableAttributesInit; |
378 | for (const auto &attrPair : discardableAttributes) { |
379 | std::string camelName = llvm::convertToCamelFromSnakeCase( |
380 | input: attrPair.first, /*capitalizeFirst=*/false); |
381 | llvm::raw_string_ostream os(discardableAttributesInit); |
382 | os << ", "<< camelName << "AttrName(context)"; |
383 | } |
384 | |
385 | os << llvm::formatv(Fmt: dialectConstructorStr, Vals&: cppClassName, |
386 | Vals&: dependentDialectRegistrations, Vals&: superClassName, |
387 | Vals&: discardableAttributesInit); |
388 | if (!dialect.hasNonDefaultDestructor()) |
389 | os << llvm::formatv(Fmt: dialectDestructorStr, Vals&: cppClassName); |
390 | } |
391 | |
392 | static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper, |
393 | raw_ostream &os) { |
394 | emitSourceFileHeader(Desc: "Dialect Definitions", OS&: os, Record: recordKeeper); |
395 | |
396 | auto dialectDefs = recordKeeper.getAllDerivedDefinitions(ClassName: "Dialect"); |
397 | if (dialectDefs.empty()) |
398 | return false; |
399 | |
400 | SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end()); |
401 | std::optional<Dialect> dialect = findDialectToGenerate(dialects); |
402 | if (!dialect) |
403 | return true; |
404 | emitDialectDef(dialect&: *dialect, recordKeeper, os); |
405 | return false; |
406 | } |
407 | |
408 | //===----------------------------------------------------------------------===// |
409 | // GEN: Dialect registration hooks |
410 | //===----------------------------------------------------------------------===// |
411 | |
412 | static mlir::GenRegistration |
413 | genDialectDecls("gen-dialect-decls", "Generate dialect declarations", |
414 | [](const llvm::RecordKeeper &records, raw_ostream &os) { |
415 | return emitDialectDecls(recordKeeper: records, os); |
416 | }); |
417 | |
418 | static mlir::GenRegistration |
419 | genDialectDefs("gen-dialect-defs", "Generate dialect definitions", |
420 | [](const llvm::RecordKeeper &records, raw_ostream &os) { |
421 | return emitDialectDefs(recordKeeper: records, os); |
422 | }); |
423 |
Definitions
- dialectGenCat
- selectedDialect
- populateDiscardableAttributes
- filterForDialect
- findDialectToGenerate
- dialectDeclBeginStr
- dialectRegistrationTemplate
- attrParserDecl
- typeParserDecl
- canonicalizerDecl
- constantMaterializerDecl
- opAttrVerifierDecl
- regionArgAttrVerifierDecl
- regionResultAttrVerifierDecl
- operationInterfaceFallbackDecl
- discardableAttrHelperDecl
- emitDialectDecl
- emitDialectDecls
- dialectConstructorStr
- dialectDestructorStr
- emitDialectDef
- emitDialectDefs
- genDialectDecls
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more