1//===- EnumsGen.cpp - MLIR enum utility generator -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// EnumsGen generates common utility functions for enums.
10//
11//===----------------------------------------------------------------------===//
12
13#include "FormatGen.h"
14#include "mlir/TableGen/Attribute.h"
15#include "mlir/TableGen/Format.h"
16#include "mlir/TableGen/GenInfo.h"
17#include "llvm/ADT/BitVector.h"
18#include "llvm/ADT/SmallVector.h"
19#include "llvm/ADT/StringExtras.h"
20#include "llvm/Support/FormatVariadic.h"
21#include "llvm/Support/raw_ostream.h"
22#include "llvm/TableGen/Error.h"
23#include "llvm/TableGen/Record.h"
24#include "llvm/TableGen/TableGenBackend.h"
25
26using llvm::formatv;
27using llvm::isDigit;
28using llvm::PrintFatalError;
29using llvm::raw_ostream;
30using llvm::Record;
31using llvm::RecordKeeper;
32using llvm::StringRef;
33using mlir::tblgen::Attribute;
34using mlir::tblgen::EnumAttr;
35using mlir::tblgen::EnumAttrCase;
36using mlir::tblgen::FmtContext;
37using mlir::tblgen::tgfmt;
38
39static std::string makeIdentifier(StringRef str) {
40 if (!str.empty() && isDigit(C: static_cast<unsigned char>(str.front()))) {
41 std::string newStr = std::string("_") + str.str();
42 return newStr;
43 }
44 return str.str();
45}
46
47static void emitEnumClass(const Record &enumDef, StringRef enumName,
48 StringRef underlyingType, StringRef description,
49 const std::vector<EnumAttrCase> &enumerants,
50 raw_ostream &os) {
51 os << "// " << description << "\n";
52 os << "enum class " << enumName;
53
54 if (!underlyingType.empty())
55 os << " : " << underlyingType;
56 os << " {\n";
57
58 for (const auto &enumerant : enumerants) {
59 auto symbol = makeIdentifier(str: enumerant.getSymbol());
60 auto value = enumerant.getValue();
61 if (value >= 0) {
62 os << formatv(Fmt: " {0} = {1},\n", Vals&: symbol, Vals&: value);
63 } else {
64 os << formatv(Fmt: " {0},\n", Vals&: symbol);
65 }
66 }
67 os << "};\n\n";
68}
69
70static void emitParserPrinter(const EnumAttr &enumAttr, StringRef qualName,
71 StringRef cppNamespace, raw_ostream &os) {
72 if (enumAttr.getUnderlyingType().empty() ||
73 enumAttr.getConstBuilderTemplate().empty())
74 return;
75 auto cases = enumAttr.getAllCases();
76
77 // Check which cases shouldn't be printed using a keyword.
78 llvm::BitVector nonKeywordCases(cases.size());
79 for (auto [index, caseVal] : llvm::enumerate(First&: cases))
80 if (!mlir::tblgen::canFormatStringAsKeyword(value: caseVal.getStr()))
81 nonKeywordCases.set(index);
82
83 // Generate the parser and the start of the printer for the enum.
84 const char *parsedAndPrinterStart = R"(
85namespace mlir {
86template <typename T, typename>
87struct FieldParser;
88
89template<>
90struct FieldParser<{0}, {0}> {{
91 template <typename ParserT>
92 static FailureOr<{0}> parse(ParserT &parser) {{
93 // Parse the keyword/string containing the enum.
94 std::string enumKeyword;
95 auto loc = parser.getCurrentLocation();
96 if (failed(parser.parseOptionalKeywordOrString(&enumKeyword)))
97 return parser.emitError(loc, "expected keyword for {2}");
98
99 // Symbolize the keyword.
100 if (::std::optional<{0}> attr = {1}::symbolizeEnum<{0}>(enumKeyword))
101 return *attr;
102 return parser.emitError(loc, "invalid {2} specification: ") << enumKeyword;
103 }
104};
105} // namespace mlir
106
107namespace llvm {
108inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
109 auto valueStr = stringifyEnum(value);
110)";
111 os << formatv(Fmt: parsedAndPrinterStart, Vals&: qualName, Vals&: cppNamespace,
112 Vals: enumAttr.getSummary());
113
114 // If all cases require a string, always wrap.
115 if (nonKeywordCases.all()) {
116 os << " return p << '\"' << valueStr << '\"';\n"
117 "}\n"
118 "} // namespace llvm\n";
119 return;
120 }
121
122 // If there are any cases that can't be used with a keyword, switch on the
123 // case value to determine when to print in the string form.
124 if (nonKeywordCases.any()) {
125 os << " switch (value) {\n";
126 for (auto it : llvm::enumerate(First&: cases)) {
127 if (nonKeywordCases.test(Idx: it.index()))
128 continue;
129 StringRef symbol = it.value().getSymbol();
130 os << llvm::formatv(Fmt: " case {0}::{1}:\n", Vals&: qualName,
131 Vals: makeIdentifier(str: symbol));
132 }
133 os << " break;\n"
134 " default:\n"
135 " return p << '\"' << valueStr << '\"';\n"
136 " }\n";
137
138 // If this is a bit enum, conservatively print the string form if the value
139 // is not a power of two (i.e. not a single bit case) and not a known case.
140 } else if (enumAttr.isBitEnum()) {
141 // Process the known multi-bit cases that use valid keywords.
142 llvm::SmallVector<EnumAttrCase *> validMultiBitCases;
143 for (auto [index, caseVal] : llvm::enumerate(First&: cases)) {
144 uint64_t value = caseVal.getValue();
145 if (value && !llvm::has_single_bit(Value: value) && !nonKeywordCases.test(Idx: index))
146 validMultiBitCases.push_back(Elt: &caseVal);
147 }
148 if (!validMultiBitCases.empty()) {
149 os << " switch (value) {\n";
150 for (EnumAttrCase *caseVal : validMultiBitCases) {
151 StringRef symbol = caseVal->getSymbol();
152 os << llvm::formatv(Fmt: " case {0}::{1}:\n", Vals&: qualName,
153 Vals: llvm::isDigit(C: symbol.front()) ? ("_" + symbol)
154 : symbol);
155 }
156 os << " return p << valueStr;\n"
157 " default:\n"
158 " break;\n"
159 " }\n";
160 }
161
162 // All other multi-bit cases should be printed as strings.
163 os << formatv(Fmt: " auto underlyingValue = "
164 "static_cast<std::make_unsigned_t<{0}>>(value);\n",
165 Vals&: qualName);
166 os << " if (underlyingValue && !llvm::has_single_bit(underlyingValue))\n"
167 " return p << '\"' << valueStr << '\"';\n";
168 }
169 os << " return p << valueStr;\n"
170 "}\n"
171 "} // namespace llvm\n";
172}
173
174static void emitDenseMapInfo(StringRef qualName, std::string underlyingType,
175 StringRef cppNamespace, raw_ostream &os) {
176 if (underlyingType.empty())
177 underlyingType =
178 std::string(formatv(Fmt: "std::underlying_type_t<{0}>", Vals&: qualName));
179
180 const char *const mapInfo = R"(
181namespace llvm {
182template<> struct DenseMapInfo<{0}> {{
183 using StorageInfo = ::llvm::DenseMapInfo<{1}>;
184
185 static inline {0} getEmptyKey() {{
186 return static_cast<{0}>(StorageInfo::getEmptyKey());
187 }
188
189 static inline {0} getTombstoneKey() {{
190 return static_cast<{0}>(StorageInfo::getTombstoneKey());
191 }
192
193 static unsigned getHashValue(const {0} &val) {{
194 return StorageInfo::getHashValue(static_cast<{1}>(val));
195 }
196
197 static bool isEqual(const {0} &lhs, const {0} &rhs) {{
198 return lhs == rhs;
199 }
200};
201})";
202 os << formatv(Fmt: mapInfo, Vals&: qualName, Vals&: underlyingType);
203 os << "\n\n";
204}
205
206static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) {
207 EnumAttr enumAttr(enumDef);
208 StringRef maxEnumValFnName = enumAttr.getMaxEnumValFnName();
209 auto enumerants = enumAttr.getAllCases();
210
211 unsigned maxEnumVal = 0;
212 for (const auto &enumerant : enumerants) {
213 int64_t value = enumerant.getValue();
214 // Avoid generating the max value function if there is an enumerant without
215 // explicit value.
216 if (value < 0)
217 return;
218
219 maxEnumVal = std::max(a: maxEnumVal, b: static_cast<unsigned>(value));
220 }
221
222 // Emit the function to return the max enum value
223 os << formatv(Fmt: "inline constexpr unsigned {0}() {{\n", Vals&: maxEnumValFnName);
224 os << formatv(Fmt: " return {0};\n", Vals&: maxEnumVal);
225 os << "}\n\n";
226}
227
228// Returns the EnumAttrCase whose value is zero if exists; returns std::nullopt
229// otherwise.
230static std::optional<EnumAttrCase>
231getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases) {
232 for (auto attrCase : cases) {
233 if (attrCase.getValue() == 0)
234 return attrCase;
235 }
236 return std::nullopt;
237}
238
239// Emits the following inline function for bit enums:
240//
241// inline constexpr <enum-type> operator|(<enum-type> a, <enum-type> b);
242// inline constexpr <enum-type> operator&(<enum-type> a, <enum-type> b);
243// inline constexpr <enum-type> operator^(<enum-type> a, <enum-type> b);
244// inline constexpr <enum-type> operator~(<enum-type> bits);
245// inline constexpr bool bitEnumContainsAll(<enum-type> bits, <enum-type> bit);
246// inline constexpr bool bitEnumContainsAny(<enum-type> bits, <enum-type> bit);
247// inline constexpr <enum-type> bitEnumClear(<enum-type> bits, <enum-type> bit);
248// inline constexpr <enum-type> bitEnumSet(<enum-type> bits, <enum-type> bit,
249// bool value=true);
250static void emitOperators(const Record &enumDef, raw_ostream &os) {
251 EnumAttr enumAttr(enumDef);
252 StringRef enumName = enumAttr.getEnumClassName();
253 std::string underlyingType = std::string(enumAttr.getUnderlyingType());
254 int64_t validBits = enumDef.getValueAsInt(FieldName: "validBits");
255 const char *const operators = R"(
256inline constexpr {0} operator|({0} a, {0} b) {{
257 return static_cast<{0}>(static_cast<{1}>(a) | static_cast<{1}>(b));
258}
259inline constexpr {0} operator&({0} a, {0} b) {{
260 return static_cast<{0}>(static_cast<{1}>(a) & static_cast<{1}>(b));
261}
262inline constexpr {0} operator^({0} a, {0} b) {{
263 return static_cast<{0}>(static_cast<{1}>(a) ^ static_cast<{1}>(b));
264}
265inline constexpr {0} operator~({0} bits) {{
266 // Ensure only bits that can be present in the enum are set
267 return static_cast<{0}>(~static_cast<{1}>(bits) & static_cast<{1}>({2}u));
268}
269inline constexpr bool bitEnumContainsAll({0} bits, {0} bit) {{
270 return (bits & bit) == bit;
271}
272inline constexpr bool bitEnumContainsAny({0} bits, {0} bit) {{
273 return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0;
274}
275inline constexpr {0} bitEnumClear({0} bits, {0} bit) {{
276 return bits & ~bit;
277}
278inline constexpr {0} bitEnumSet({0} bits, {0} bit, /*optional*/bool value=true) {{
279 return value ? (bits | bit) : bitEnumClear(bits, bit);
280}
281 )";
282 os << formatv(Fmt: operators, Vals&: enumName, Vals&: underlyingType, Vals&: validBits);
283}
284
285static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) {
286 EnumAttr enumAttr(enumDef);
287 StringRef enumName = enumAttr.getEnumClassName();
288 StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
289 StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
290 auto enumerants = enumAttr.getAllCases();
291
292 os << formatv(Fmt: "{2} {1}({0} val) {{\n", Vals&: enumName, Vals&: symToStrFnName,
293 Vals&: symToStrFnRetType);
294 os << " switch (val) {\n";
295 for (const auto &enumerant : enumerants) {
296 auto symbol = enumerant.getSymbol();
297 auto str = enumerant.getStr();
298 os << formatv(Fmt: " case {0}::{1}: return \"{2}\";\n", Vals&: enumName,
299 Vals: makeIdentifier(str: symbol), Vals&: str);
300 }
301 os << " }\n";
302 os << " return \"\";\n";
303 os << "}\n\n";
304}
305
306static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
307 EnumAttr enumAttr(enumDef);
308 StringRef enumName = enumAttr.getEnumClassName();
309 StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
310 StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
311 StringRef separator = enumDef.getValueAsString(FieldName: "separator");
312 auto enumerants = enumAttr.getAllCases();
313 auto allBitsUnsetCase = getAllBitsUnsetCase(cases: enumerants);
314
315 os << formatv(Fmt: "{2} {1}({0} symbol) {{\n", Vals&: enumName, Vals&: symToStrFnName,
316 Vals&: symToStrFnRetType);
317
318 os << formatv(Fmt: " auto val = static_cast<{0}>(symbol);\n",
319 Vals: enumAttr.getUnderlyingType());
320 // If we have unknown bit set, return an empty string to signal errors.
321 int64_t validBits = enumDef.getValueAsInt(FieldName: "validBits");
322 os << formatv(Fmt: " assert({0}u == ({0}u | val) && \"invalid bits set in bit "
323 "enum\");\n",
324 Vals&: validBits);
325 if (allBitsUnsetCase) {
326 os << " // Special case for all bits unset.\n";
327 os << formatv(Fmt: " if (val == 0) return \"{0}\";\n\n",
328 Vals: allBitsUnsetCase->getStr());
329 }
330 os << " ::llvm::SmallVector<::llvm::StringRef, 2> strs;\n";
331
332 // Add case string if the value has all case bits, and remove them to avoid
333 // printing again. Used only for groups, when printBitEnumPrimaryGroups is 1.
334 const char *const formatCompareRemove = R"(
335 if ({0}u == ({0}u & val)) {{
336 strs.push_back("{1}");
337 val &= ~static_cast<{2}>({0});
338 }
339)";
340 // Add case string if the value has all case bits. Used for individual bit
341 // cases, and for groups when printBitEnumPrimaryGroups is 0.
342 const char *const formatCompare = R"(
343 if ({0}u == ({0}u & val))
344 strs.push_back("{1}");
345)";
346 // Optionally elide bits that are members of groups that will also be printed
347 // for more concise output.
348 if (enumAttr.printBitEnumPrimaryGroups()) {
349 os << " // Print bit enum groups before individual bits\n";
350 // Emit comparisons for group bit cases in reverse tablegen declaration
351 // order, removing bits for groups with all bits present.
352 for (const auto &enumerant : llvm::reverse(C&: enumerants)) {
353 if ((enumerant.getValue() != 0) &&
354 enumerant.getDef().isSubClassOf(Name: "BitEnumAttrCaseGroup")) {
355 os << formatv(Fmt: formatCompareRemove, Vals: enumerant.getValue(),
356 Vals: enumerant.getStr(), Vals: enumAttr.getUnderlyingType());
357 }
358 }
359 // Emit comparisons for individual bit cases in tablegen declaration order.
360 for (const auto &enumerant : enumerants) {
361 if ((enumerant.getValue() != 0) &&
362 enumerant.getDef().isSubClassOf(Name: "BitEnumAttrCaseBit"))
363 os << formatv(Fmt: formatCompare, Vals: enumerant.getValue(), Vals: enumerant.getStr());
364 }
365 } else {
366 // Emit comparisons for ALL nonzero cases (individual bits and groups) in
367 // tablegen declaration order.
368 for (const auto &enumerant : enumerants) {
369 if (enumerant.getValue() != 0)
370 os << formatv(Fmt: formatCompare, Vals: enumerant.getValue(), Vals: enumerant.getStr());
371 }
372 }
373 os << formatv(Fmt: " return ::llvm::join(strs, \"{0}\");\n", Vals&: separator);
374
375 os << "}\n\n";
376}
377
378static void emitStrToSymFnForIntEnum(const Record &enumDef, raw_ostream &os) {
379 EnumAttr enumAttr(enumDef);
380 StringRef enumName = enumAttr.getEnumClassName();
381 StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
382 auto enumerants = enumAttr.getAllCases();
383
384 os << formatv(Fmt: "::std::optional<{0}> {1}(::llvm::StringRef str) {{\n",
385 Vals&: enumName, Vals&: strToSymFnName);
386 os << formatv(Fmt: " return ::llvm::StringSwitch<::std::optional<{0}>>(str)\n",
387 Vals&: enumName);
388 for (const auto &enumerant : enumerants) {
389 auto symbol = enumerant.getSymbol();
390 auto str = enumerant.getStr();
391 os << formatv(Fmt: " .Case(\"{1}\", {0}::{2})\n", Vals&: enumName, Vals&: str,
392 Vals: makeIdentifier(str: symbol));
393 }
394 os << " .Default(::std::nullopt);\n";
395 os << "}\n";
396}
397
398static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
399 EnumAttr enumAttr(enumDef);
400 StringRef enumName = enumAttr.getEnumClassName();
401 std::string underlyingType = std::string(enumAttr.getUnderlyingType());
402 StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
403 StringRef separator = enumDef.getValueAsString(FieldName: "separator");
404 StringRef separatorTrimmed = separator.trim();
405 auto enumerants = enumAttr.getAllCases();
406 auto allBitsUnsetCase = getAllBitsUnsetCase(cases: enumerants);
407
408 os << formatv(Fmt: "::std::optional<{0}> {1}(::llvm::StringRef str) {{\n",
409 Vals&: enumName, Vals&: strToSymFnName);
410
411 if (allBitsUnsetCase) {
412 os << " // Special case for all bits unset.\n";
413 StringRef caseSymbol = allBitsUnsetCase->getSymbol();
414 os << formatv(Fmt: " if (str == \"{1}\") return {0}::{2};\n\n", Vals&: enumName,
415 Vals: allBitsUnsetCase->getStr(), Vals: makeIdentifier(str: caseSymbol));
416 }
417
418 // Split the string to get symbols for all the bits.
419 os << " ::llvm::SmallVector<::llvm::StringRef, 2> symbols;\n";
420 // Remove whitespace from the separator string when parsing.
421 os << formatv(Fmt: " str.split(symbols, \"{0}\");\n\n", Vals&: separatorTrimmed);
422
423 os << formatv(Fmt: " {0} val = 0;\n", Vals&: underlyingType);
424 os << " for (auto symbol : symbols) {\n";
425
426 // Convert each symbol to the bit ordinal and set the corresponding bit.
427 os << formatv(Fmt: " auto bit = "
428 "llvm::StringSwitch<::std::optional<{0}>>(symbol.trim())\n",
429 Vals&: underlyingType);
430 for (const auto &enumerant : enumerants) {
431 // Skip the special enumerant for None.
432 if (auto val = enumerant.getValue())
433 os.indent(NumSpaces: 6) << formatv(Fmt: ".Case(\"{0}\", {1})\n", Vals: enumerant.getStr(), Vals&: val);
434 }
435 os.indent(NumSpaces: 6) << ".Default(::std::nullopt);\n";
436
437 os << " if (bit) { val |= *bit; } else { return ::std::nullopt; }\n";
438 os << " }\n";
439
440 os << formatv(Fmt: " return static_cast<{0}>(val);\n", Vals&: enumName);
441 os << "}\n\n";
442}
443
444static void emitUnderlyingToSymFnForIntEnum(const Record &enumDef,
445 raw_ostream &os) {
446 EnumAttr enumAttr(enumDef);
447 StringRef enumName = enumAttr.getEnumClassName();
448 std::string underlyingType = std::string(enumAttr.getUnderlyingType());
449 StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
450 auto enumerants = enumAttr.getAllCases();
451
452 // Avoid generating the underlying value to symbol conversion function if
453 // there is an enumerant without explicit value.
454 if (llvm::any_of(Range&: enumerants, P: [](EnumAttrCase enumerant) {
455 return enumerant.getValue() < 0;
456 }))
457 return;
458
459 os << formatv(Fmt: "::std::optional<{0}> {1}({2} value) {{\n", Vals&: enumName,
460 Vals&: underlyingToSymFnName,
461 Vals: underlyingType.empty() ? std::string("unsigned")
462 : underlyingType)
463 << " switch (value) {\n";
464 for (const auto &enumerant : enumerants) {
465 auto symbol = enumerant.getSymbol();
466 auto value = enumerant.getValue();
467 os << formatv(Fmt: " case {0}: return {1}::{2};\n", Vals&: value, Vals&: enumName,
468 Vals: makeIdentifier(str: symbol));
469 }
470 os << " default: return ::std::nullopt;\n"
471 << " }\n"
472 << "}\n\n";
473}
474
475static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
476 EnumAttr enumAttr(enumDef);
477 StringRef enumName = enumAttr.getEnumClassName();
478 StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
479 llvm::Record *baseAttrDef = enumAttr.getBaseAttrClass();
480 Attribute baseAttr(baseAttrDef);
481
482 // Emit classof method
483
484 os << formatv(Fmt: "bool {0}::classof(::mlir::Attribute attr) {{\n",
485 Vals&: attrClassName);
486
487 mlir::tblgen::Pred baseAttrPred = baseAttr.getPredicate();
488 if (baseAttrPred.isNull())
489 PrintFatalError(Msg: "ERROR: baseAttrClass for EnumAttr has no Predicate\n");
490
491 std::string condition = baseAttrPred.getCondition();
492 FmtContext verifyCtx;
493 verifyCtx.withSelf(subst: "attr");
494 os << tgfmt(fmt: " return $0;\n", /*ctx=*/nullptr, vals: tgfmt(fmt: condition, ctx: &verifyCtx));
495
496 os << "}\n";
497
498 // Emit get method
499
500 os << formatv(Fmt: "{0} {0}::get(::mlir::MLIRContext *context, {1} val) {{\n",
501 Vals&: attrClassName, Vals&: enumName);
502
503 StringRef underlyingType = enumAttr.getUnderlyingType();
504
505 // Assuming that it is IntegerAttr constraint
506 int64_t bitwidth = 64;
507 if (baseAttrDef->getValue(Name: "valueType")) {
508 auto *valueTypeDef = baseAttrDef->getValueAsDef(FieldName: "valueType");
509 if (valueTypeDef->getValue(Name: "bitwidth"))
510 bitwidth = valueTypeDef->getValueAsInt(FieldName: "bitwidth");
511 }
512
513 os << formatv(Fmt: " ::mlir::IntegerType intType = "
514 "::mlir::IntegerType::get(context, {0});\n",
515 Vals&: bitwidth);
516 os << formatv(Fmt: " ::mlir::IntegerAttr baseAttr = "
517 "::mlir::IntegerAttr::get(intType, static_cast<{0}>(val));\n",
518 Vals&: underlyingType);
519 os << formatv(Fmt: " return ::llvm::cast<{0}>(baseAttr);\n", Vals&: attrClassName);
520
521 os << "}\n";
522
523 // Emit getValue method
524
525 os << formatv(Fmt: "{0} {1}::getValue() const {{\n", Vals&: enumName, Vals&: attrClassName);
526
527 os << formatv(Fmt: " return static_cast<{0}>(::mlir::IntegerAttr::getInt());\n",
528 Vals&: enumName);
529
530 os << "}\n";
531}
532
533static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
534 raw_ostream &os) {
535 EnumAttr enumAttr(enumDef);
536 StringRef enumName = enumAttr.getEnumClassName();
537 std::string underlyingType = std::string(enumAttr.getUnderlyingType());
538 StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
539 auto enumerants = enumAttr.getAllCases();
540 auto allBitsUnsetCase = getAllBitsUnsetCase(cases: enumerants);
541
542 os << formatv(Fmt: "::std::optional<{0}> {1}({2} value) {{\n", Vals&: enumName,
543 Vals&: underlyingToSymFnName, Vals&: underlyingType);
544 if (allBitsUnsetCase) {
545 os << " // Special case for all bits unset.\n";
546 os << formatv(Fmt: " if (value == 0) return {0}::{1};\n\n", Vals&: enumName,
547 Vals: makeIdentifier(str: allBitsUnsetCase->getSymbol()));
548 }
549 int64_t validBits = enumDef.getValueAsInt(FieldName: "validBits");
550 os << formatv(Fmt: " if (value & ~static_cast<{0}>({1}u)) return std::nullopt;\n",
551 Vals&: underlyingType, Vals&: validBits);
552 os << formatv(Fmt: " return static_cast<{0}>(value);\n", Vals&: enumName);
553 os << "}\n";
554}
555
556static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
557 EnumAttr enumAttr(enumDef);
558 StringRef enumName = enumAttr.getEnumClassName();
559 StringRef cppNamespace = enumAttr.getCppNamespace();
560 std::string underlyingType = std::string(enumAttr.getUnderlyingType());
561 StringRef description = enumAttr.getSummary();
562 StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
563 StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
564 StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
565 StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
566 auto enumerants = enumAttr.getAllCases();
567
568 llvm::SmallVector<StringRef, 2> namespaces;
569 llvm::SplitString(Source: cppNamespace, OutFragments&: namespaces, Delimiters: "::");
570
571 for (auto ns : namespaces)
572 os << "namespace " << ns << " {\n";
573
574 // Emit the enum class definition
575 emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os);
576
577 // Emit conversion function declarations
578 if (llvm::all_of(Range&: enumerants, P: [](EnumAttrCase enumerant) {
579 return enumerant.getValue() >= 0;
580 })) {
581 os << formatv(
582 Fmt: "::std::optional<{0}> {1}({2});\n", Vals&: enumName, Vals&: underlyingToSymFnName,
583 Vals: underlyingType.empty() ? std::string("unsigned") : underlyingType);
584 }
585 os << formatv(Fmt: "{2} {1}({0});\n", Vals&: enumName, Vals&: symToStrFnName, Vals&: symToStrFnRetType);
586 os << formatv(Fmt: "::std::optional<{0}> {1}(::llvm::StringRef);\n", Vals&: enumName,
587 Vals&: strToSymFnName);
588
589 if (enumAttr.isBitEnum()) {
590 emitOperators(enumDef, os);
591 } else {
592 emitMaxValueFn(enumDef, os);
593 }
594
595 // Generate a generic `stringifyEnum` function that forwards to the method
596 // specified by the user.
597 const char *const stringifyEnumStr = R"(
598inline {0} stringifyEnum({1} enumValue) {{
599 return {2}(enumValue);
600}
601)";
602 os << formatv(Fmt: stringifyEnumStr, Vals&: symToStrFnRetType, Vals&: enumName, Vals&: symToStrFnName);
603
604 // Generate a generic `symbolizeEnum` function that forwards to the method
605 // specified by the user.
606 const char *const symbolizeEnumStr = R"(
607template <typename EnumType>
608::std::optional<EnumType> symbolizeEnum(::llvm::StringRef);
609
610template <>
611inline ::std::optional<{0}> symbolizeEnum<{0}>(::llvm::StringRef str) {
612 return {1}(str);
613}
614)";
615 os << formatv(Fmt: symbolizeEnumStr, Vals&: enumName, Vals&: strToSymFnName);
616
617 const char *const attrClassDecl = R"(
618class {1} : public ::mlir::{2} {
619public:
620 using ValueType = {0};
621 using ::mlir::{2}::{2};
622 static bool classof(::mlir::Attribute attr);
623 static {1} get(::mlir::MLIRContext *context, {0} val);
624 {0} getValue() const;
625};
626)";
627 if (enumAttr.genSpecializedAttr()) {
628 StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
629 StringRef baseAttrClassName = "IntegerAttr";
630 os << formatv(Fmt: attrClassDecl, Vals&: enumName, Vals&: attrClassName, Vals&: baseAttrClassName);
631 }
632
633 for (auto ns : llvm::reverse(C&: namespaces))
634 os << "} // namespace " << ns << "\n";
635
636 // Generate a generic parser and printer for the enum.
637 std::string qualName =
638 std::string(formatv(Fmt: "{0}::{1}", Vals&: cppNamespace, Vals&: enumName));
639 emitParserPrinter(enumAttr, qualName, cppNamespace, os);
640
641 // Emit DenseMapInfo for this enum class
642 emitDenseMapInfo(qualName, underlyingType, cppNamespace, os);
643}
644
645static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
646 llvm::emitSourceFileHeader(Desc: "Enum Utility Declarations", OS&: os, Record: recordKeeper);
647
648 auto defs = recordKeeper.getAllDerivedDefinitionsIfDefined(ClassName: "EnumAttrInfo");
649 for (const auto *def : defs)
650 emitEnumDecl(enumDef: *def, os);
651
652 return false;
653}
654
655static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
656 EnumAttr enumAttr(enumDef);
657 StringRef cppNamespace = enumAttr.getCppNamespace();
658
659 llvm::SmallVector<StringRef, 2> namespaces;
660 llvm::SplitString(Source: cppNamespace, OutFragments&: namespaces, Delimiters: "::");
661
662 for (auto ns : namespaces)
663 os << "namespace " << ns << " {\n";
664
665 if (enumAttr.isBitEnum()) {
666 emitSymToStrFnForBitEnum(enumDef, os);
667 emitStrToSymFnForBitEnum(enumDef, os);
668 emitUnderlyingToSymFnForBitEnum(enumDef, os);
669 } else {
670 emitSymToStrFnForIntEnum(enumDef, os);
671 emitStrToSymFnForIntEnum(enumDef, os);
672 emitUnderlyingToSymFnForIntEnum(enumDef, os);
673 }
674
675 if (enumAttr.genSpecializedAttr())
676 emitSpecializedAttrDef(enumDef, os);
677
678 for (auto ns : llvm::reverse(C&: namespaces))
679 os << "} // namespace " << ns << "\n";
680 os << "\n";
681}
682
683static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
684 llvm::emitSourceFileHeader(Desc: "Enum Utility Definitions", OS&: os, Record: recordKeeper);
685
686 auto defs = recordKeeper.getAllDerivedDefinitionsIfDefined(ClassName: "EnumAttrInfo");
687 for (const auto *def : defs)
688 emitEnumDef(enumDef: *def, os);
689
690 return false;
691}
692
693// Registers the enum utility generator to mlir-tblgen.
694static mlir::GenRegistration
695 genEnumDecls("gen-enum-decls", "Generate enum utility declarations",
696 [](const RecordKeeper &records, raw_ostream &os) {
697 return emitEnumDecls(recordKeeper: records, os);
698 });
699
700// Registers the enum utility generator to mlir-tblgen.
701static mlir::GenRegistration
702 genEnumDefs("gen-enum-defs", "Generate enum utility definitions",
703 [](const RecordKeeper &records, raw_ostream &os) {
704 return emitEnumDefs(recordKeeper: records, os);
705 });
706

source code of mlir/tools/mlir-tblgen/EnumsGen.cpp