1//===- EnumsGen.cpp - MLIR enum utility generator -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// EnumsGen generates common utility functions for enums.
10//
11//===----------------------------------------------------------------------===//
12
13#include "FormatGen.h"
14#include "mlir/TableGen/Attribute.h"
15#include "mlir/TableGen/EnumInfo.h"
16#include "mlir/TableGen/Format.h"
17#include "mlir/TableGen/GenInfo.h"
18#include "llvm/ADT/BitVector.h"
19#include "llvm/ADT/SmallVector.h"
20#include "llvm/ADT/StringExtras.h"
21#include "llvm/Support/FormatVariadic.h"
22#include "llvm/Support/raw_ostream.h"
23#include "llvm/TableGen/Error.h"
24#include "llvm/TableGen/Record.h"
25#include "llvm/TableGen/TableGenBackend.h"
26
27using llvm::formatv;
28using llvm::isDigit;
29using llvm::PrintFatalError;
30using llvm::Record;
31using llvm::RecordKeeper;
32using namespace mlir;
33using mlir::tblgen::Attribute;
34using mlir::tblgen::EnumCase;
35using mlir::tblgen::EnumInfo;
36using mlir::tblgen::FmtContext;
37using mlir::tblgen::tgfmt;
38
39static std::string makeIdentifier(StringRef str) {
40 if (!str.empty() && isDigit(C: static_cast<unsigned char>(str.front()))) {
41 std::string newStr = std::string("_") + str.str();
42 return newStr;
43 }
44 return str.str();
45}
46
47static void emitEnumClass(const Record &enumDef, StringRef enumName,
48 StringRef underlyingType, StringRef description,
49 const std::vector<EnumCase> &enumerants,
50 raw_ostream &os) {
51 os << "// " << description << "\n";
52 os << "enum class " << enumName;
53
54 if (!underlyingType.empty())
55 os << " : " << underlyingType;
56 os << " {\n";
57
58 for (const auto &enumerant : enumerants) {
59 auto symbol = makeIdentifier(str: enumerant.getSymbol());
60 auto value = enumerant.getValue();
61 if (value >= 0) {
62 os << formatv(Fmt: " {0} = {1},\n", Vals&: symbol, Vals&: value);
63 } else {
64 os << formatv(Fmt: " {0},\n", Vals&: symbol);
65 }
66 }
67 os << "};\n\n";
68}
69
70static void emitParserPrinter(const EnumInfo &enumInfo, StringRef qualName,
71 StringRef cppNamespace, raw_ostream &os) {
72 std::optional<Attribute> enumAttrInfo = enumInfo.asEnumAttr();
73 if (enumInfo.getUnderlyingType().empty() ||
74 (enumAttrInfo && enumAttrInfo->getConstBuilderTemplate().empty()))
75 return;
76 auto cases = enumInfo.getAllCases();
77
78 // Check which cases shouldn't be printed using a keyword.
79 llvm::BitVector nonKeywordCases(cases.size());
80 std::string casesList;
81 llvm::raw_string_ostream caseListOs(casesList);
82 caseListOs << "[";
83 llvm::interleaveComma(c: llvm::enumerate(First&: cases), os&: caseListOs,
84 each_fn: [&](auto enumerant) {
85 StringRef name = enumerant.value().getStr();
86 if (!mlir::tblgen::canFormatStringAsKeyword(value: name)) {
87 nonKeywordCases.set(enumerant.index());
88 caseListOs << "\\\"" << name << "\\\"";
89 }
90 caseListOs << name;
91 });
92 caseListOs << "]";
93
94 // Generate the parser and the start of the printer for the enum, excluding
95 // non-quoted bit enums.
96 const char *parsedAndPrinterStart = R"(
97namespace mlir {
98template <typename T, typename>
99struct FieldParser;
100
101template<>
102struct FieldParser<{0}, {0}> {{
103 template <typename ParserT>
104 static FailureOr<{0}> parse(ParserT &parser) {{
105 // Parse the keyword/string containing the enum.
106 std::string enumKeyword;
107 auto loc = parser.getCurrentLocation();
108 if (failed(parser.parseOptionalKeywordOrString(&enumKeyword)))
109 return parser.emitError(loc, "expected keyword for {2}");
110
111 // Symbolize the keyword.
112 if (::std::optional<{0}> attr = {1}::symbolizeEnum<{0}>(enumKeyword))
113 return *attr;
114 return parser.emitError(loc, "expected one of {3} for {2}, got: ") << enumKeyword;
115 }
116};
117
118/// Support for std::optional, useful in attribute/type definition where the enum is
119/// used as:
120///
121/// let parameters = (ins OptionalParameter<"std::optional<TheEnumName>">:$value);
122template<>
123struct FieldParser<std::optional<{0}>, std::optional<{0}>> {{
124 template <typename ParserT>
125 static FailureOr<std::optional<{0}>> parse(ParserT &parser) {{
126 // Parse the keyword/string containing the enum.
127 std::string enumKeyword;
128 auto loc = parser.getCurrentLocation();
129 if (failed(parser.parseOptionalKeywordOrString(&enumKeyword)))
130 return std::optional<{0}>{{};
131
132 // Symbolize the keyword.
133 if (::std::optional<{0}> attr = {1}::symbolizeEnum<{0}>(enumKeyword))
134 return attr;
135 return parser.emitError(loc, "expected one of {3} for {2}, got: ") << enumKeyword;
136 }
137};
138} // namespace mlir
139
140namespace llvm {
141inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
142 auto valueStr = stringifyEnum(value);
143)";
144
145 const char *parsedAndPrinterStartUnquotedBitEnum = R"(
146 namespace mlir {
147 template <typename T, typename>
148 struct FieldParser;
149
150 template<>
151 struct FieldParser<{0}, {0}> {{
152 template <typename ParserT>
153 static FailureOr<{0}> parse(ParserT &parser) {{
154 {0} flags = {{};
155 do {{
156 // Parse the keyword containing a part of the enum.
157 ::llvm::StringRef enumKeyword;
158 auto loc = parser.getCurrentLocation();
159 if (failed(parser.parseOptionalKeyword(&enumKeyword))) {{
160 return parser.emitError(loc, "expected keyword for {2}");
161 }
162
163 // Symbolize the keyword.
164 if (::std::optional<{0}> flag = {1}::symbolizeEnum<{0}>(enumKeyword)) {{
165 flags = flags | *flag;
166 } else {{
167 return parser.emitError(loc, "expected one of {3} for {2}, got: ") << enumKeyword;
168 }
169 } while (::mlir::succeeded(parser.{5}()));
170 return flags;
171 }
172 };
173
174 /// Support for std::optional, useful in attribute/type definition where the enum is
175 /// used as:
176 ///
177 /// let parameters = (ins OptionalParameter<"std::optional<TheEnumName>">:$value);
178 template<>
179 struct FieldParser<std::optional<{0}>, std::optional<{0}>> {{
180 template <typename ParserT>
181 static FailureOr<std::optional<{0}>> parse(ParserT &parser) {{
182 {0} flags = {{};
183 bool firstIter = true;
184 do {{
185 // Parse the keyword containing a part of the enum.
186 ::llvm::StringRef enumKeyword;
187 auto loc = parser.getCurrentLocation();
188 if (failed(parser.parseOptionalKeyword(&enumKeyword))) {{
189 if (firstIter)
190 return std::optional<{0}>{{};
191 return parser.emitError(loc, "expected keyword for {2} after '{4}'");
192 }
193 firstIter = false;
194
195 // Symbolize the keyword.
196 if (::std::optional<{0}> flag = {1}::symbolizeEnum<{0}>(enumKeyword)) {{
197 flags = flags | *flag;
198 } else {{
199 return parser.emitError(loc, "expected one of {3} for {2}, got: ") << enumKeyword;
200 }
201 } while(::mlir::succeeded(parser.{5}()));
202 return std::optional<{0}>{{flags};
203 }
204 };
205 } // namespace mlir
206
207 namespace llvm {
208 inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
209 auto valueStr = stringifyEnum(value);
210 )";
211
212 bool isNewStyleBitEnum =
213 enumInfo.isBitEnum() && !enumInfo.printBitEnumQuoted();
214
215 if (isNewStyleBitEnum) {
216 if (nonKeywordCases.any())
217 return PrintFatalError(
218 Msg: "bit enum " + qualName +
219 " cannot be printed unquoted with cases that cannot be keywords");
220 StringRef separator = enumInfo.getDef().getValueAsString(FieldName: "separator");
221 StringRef parseSeparatorFn =
222 llvm::StringSwitch<StringRef>(separator.trim())
223 .Case(S: "|", Value: "parseOptionalVerticalBar")
224 .Case(S: ",", Value: "parseOptionalComma")
225 .Default(Value: "error, enum seperator must be '|' or ','");
226 os << formatv(Fmt: parsedAndPrinterStartUnquotedBitEnum, Vals&: qualName, Vals&: cppNamespace,
227 Vals: enumInfo.getSummary(), Vals&: casesList, Vals&: separator,
228 Vals&: parseSeparatorFn);
229 } else {
230 os << formatv(Fmt: parsedAndPrinterStart, Vals&: qualName, Vals&: cppNamespace,
231 Vals: enumInfo.getSummary(), Vals&: casesList);
232 }
233
234 // If all cases require a string, always wrap.
235 if (nonKeywordCases.all()) {
236 os << " return p << '\"' << valueStr << '\"';\n"
237 "}\n"
238 "} // namespace llvm\n";
239 return;
240 }
241
242 // If there are any cases that can't be used with a keyword, switch on the
243 // case value to determine when to print in the string form.
244 if (nonKeywordCases.any()) {
245 os << " switch (value) {\n";
246 for (auto it : llvm::enumerate(First&: cases)) {
247 if (nonKeywordCases.test(Idx: it.index()))
248 continue;
249 StringRef symbol = it.value().getSymbol();
250 os << llvm::formatv(Fmt: " case {0}::{1}:\n", Vals&: qualName,
251 Vals: makeIdentifier(str: symbol));
252 }
253 os << " break;\n"
254 " default:\n"
255 " return p << '\"' << valueStr << '\"';\n"
256 " }\n";
257
258 // If this is a bit enum, conservatively print the string form if the value
259 // is not a power of two (i.e. not a single bit case) and not a known case.
260 // Only do this if we're using the old-style parser that parses the enum as
261 // one keyword, as opposed to the new form, where we can print the value
262 // as-is.
263 } else if (enumInfo.isBitEnum() && !isNewStyleBitEnum) {
264 // Process the known multi-bit cases that use valid keywords.
265 SmallVector<EnumCase *> validMultiBitCases;
266 for (auto [index, caseVal] : llvm::enumerate(First&: cases)) {
267 uint64_t value = caseVal.getValue();
268 if (value && !llvm::has_single_bit(Value: value) && !nonKeywordCases.test(Idx: index))
269 validMultiBitCases.push_back(Elt: &caseVal);
270 }
271 if (!validMultiBitCases.empty()) {
272 os << " switch (value) {\n";
273 for (EnumCase *caseVal : validMultiBitCases) {
274 StringRef symbol = caseVal->getSymbol();
275 os << llvm::formatv(Fmt: " case {0}::{1}:\n", Vals&: qualName,
276 Vals: llvm::isDigit(C: symbol.front()) ? ("_" + symbol)
277 : symbol);
278 }
279 os << " return p << valueStr;\n"
280 " default:\n"
281 " break;\n"
282 " }\n";
283 }
284
285 // All other multi-bit cases should be printed as strings.
286 os << formatv(Fmt: " auto underlyingValue = "
287 "static_cast<std::make_unsigned_t<{0}>>(value);\n",
288 Vals&: qualName);
289 os << " if (underlyingValue && !llvm::has_single_bit(underlyingValue))\n"
290 " return p << '\"' << valueStr << '\"';\n";
291 }
292 os << " return p << valueStr;\n"
293 "}\n"
294 "} // namespace llvm\n";
295}
296
297static void emitDenseMapInfo(StringRef qualName, std::string underlyingType,
298 StringRef cppNamespace, raw_ostream &os) {
299 if (underlyingType.empty())
300 underlyingType =
301 std::string(formatv(Fmt: "std::underlying_type_t<{0}>", Vals&: qualName));
302
303 const char *const mapInfo = R"(
304namespace llvm {
305template<> struct DenseMapInfo<{0}> {{
306 using StorageInfo = ::llvm::DenseMapInfo<{1}>;
307
308 static inline {0} getEmptyKey() {{
309 return static_cast<{0}>(StorageInfo::getEmptyKey());
310 }
311
312 static inline {0} getTombstoneKey() {{
313 return static_cast<{0}>(StorageInfo::getTombstoneKey());
314 }
315
316 static unsigned getHashValue(const {0} &val) {{
317 return StorageInfo::getHashValue(static_cast<{1}>(val));
318 }
319
320 static bool isEqual(const {0} &lhs, const {0} &rhs) {{
321 return lhs == rhs;
322 }
323};
324})";
325 os << formatv(Fmt: mapInfo, Vals&: qualName, Vals&: underlyingType);
326 os << "\n\n";
327}
328
329static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) {
330 EnumInfo enumInfo(enumDef);
331 StringRef maxEnumValFnName = enumInfo.getMaxEnumValFnName();
332 auto enumerants = enumInfo.getAllCases();
333
334 unsigned maxEnumVal = 0;
335 for (const auto &enumerant : enumerants) {
336 int64_t value = enumerant.getValue();
337 // Avoid generating the max value function if there is an enumerant without
338 // explicit value.
339 if (value < 0)
340 return;
341
342 maxEnumVal = std::max(a: maxEnumVal, b: static_cast<unsigned>(value));
343 }
344
345 // Emit the function to return the max enum value
346 os << formatv(Fmt: "inline constexpr unsigned {0}() {{\n", Vals&: maxEnumValFnName);
347 os << formatv(Fmt: " return {0};\n", Vals&: maxEnumVal);
348 os << "}\n\n";
349}
350
351// Returns the EnumCase whose value is zero if exists; returns std::nullopt
352// otherwise.
353static std::optional<EnumCase>
354getAllBitsUnsetCase(llvm::ArrayRef<EnumCase> cases) {
355 for (auto attrCase : cases) {
356 if (attrCase.getValue() == 0)
357 return attrCase;
358 }
359 return std::nullopt;
360}
361
362// Emits the following inline function for bit enums:
363//
364// inline constexpr <enum-type> operator|(<enum-type> a, <enum-type> b);
365// inline constexpr <enum-type> operator&(<enum-type> a, <enum-type> b);
366// inline constexpr <enum-type> operator^(<enum-type> a, <enum-type> b);
367// inline constexpr <enum-type> operator~(<enum-type> bits);
368// inline constexpr bool bitEnumContainsAll(<enum-type> bits, <enum-type> bit);
369// inline constexpr bool bitEnumContainsAny(<enum-type> bits, <enum-type> bit);
370// inline constexpr <enum-type> bitEnumClear(<enum-type> bits, <enum-type> bit);
371// inline constexpr <enum-type> bitEnumSet(<enum-type> bits, <enum-type> bit,
372// bool value=true);
373static void emitOperators(const Record &enumDef, raw_ostream &os) {
374 EnumInfo enumInfo(enumDef);
375 StringRef enumName = enumInfo.getEnumClassName();
376 std::string underlyingType = std::string(enumInfo.getUnderlyingType());
377 int64_t validBits = enumDef.getValueAsInt(FieldName: "validBits");
378 const char *const operators = R"(
379inline constexpr {0} operator|({0} a, {0} b) {{
380 return static_cast<{0}>(static_cast<{1}>(a) | static_cast<{1}>(b));
381}
382inline constexpr {0} operator&({0} a, {0} b) {{
383 return static_cast<{0}>(static_cast<{1}>(a) & static_cast<{1}>(b));
384}
385inline constexpr {0} operator^({0} a, {0} b) {{
386 return static_cast<{0}>(static_cast<{1}>(a) ^ static_cast<{1}>(b));
387}
388inline constexpr {0} operator~({0} bits) {{
389 // Ensure only bits that can be present in the enum are set
390 return static_cast<{0}>(~static_cast<{1}>(bits) & static_cast<{1}>({2}u));
391}
392inline constexpr bool bitEnumContainsAll({0} bits, {0} bit) {{
393 return (bits & bit) == bit;
394}
395inline constexpr bool bitEnumContainsAny({0} bits, {0} bit) {{
396 return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0;
397}
398inline constexpr {0} bitEnumClear({0} bits, {0} bit) {{
399 return bits & ~bit;
400}
401inline constexpr {0} bitEnumSet({0} bits, {0} bit, /*optional*/bool value=true) {{
402 return value ? (bits | bit) : bitEnumClear(bits, bit);
403}
404 )";
405 os << formatv(Fmt: operators, Vals&: enumName, Vals&: underlyingType, Vals&: validBits);
406}
407
408static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) {
409 EnumInfo enumInfo(enumDef);
410 StringRef enumName = enumInfo.getEnumClassName();
411 StringRef symToStrFnName = enumInfo.getSymbolToStringFnName();
412 StringRef symToStrFnRetType = enumInfo.getSymbolToStringFnRetType();
413 auto enumerants = enumInfo.getAllCases();
414
415 os << formatv(Fmt: "{2} {1}({0} val) {{\n", Vals&: enumName, Vals&: symToStrFnName,
416 Vals&: symToStrFnRetType);
417 os << " switch (val) {\n";
418 for (const auto &enumerant : enumerants) {
419 auto symbol = enumerant.getSymbol();
420 auto str = enumerant.getStr();
421 os << formatv(Fmt: " case {0}::{1}: return \"{2}\";\n", Vals&: enumName,
422 Vals: makeIdentifier(str: symbol), Vals&: str);
423 }
424 os << " }\n";
425 os << " return \"\";\n";
426 os << "}\n\n";
427}
428
429static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
430 EnumInfo enumInfo(enumDef);
431 StringRef enumName = enumInfo.getEnumClassName();
432 StringRef symToStrFnName = enumInfo.getSymbolToStringFnName();
433 StringRef symToStrFnRetType = enumInfo.getSymbolToStringFnRetType();
434 StringRef separator = enumDef.getValueAsString(FieldName: "separator");
435 auto enumerants = enumInfo.getAllCases();
436 auto allBitsUnsetCase = getAllBitsUnsetCase(cases: enumerants);
437
438 os << formatv(Fmt: "{2} {1}({0} symbol) {{\n", Vals&: enumName, Vals&: symToStrFnName,
439 Vals&: symToStrFnRetType);
440
441 os << formatv(Fmt: " auto val = static_cast<{0}>(symbol);\n",
442 Vals: enumInfo.getUnderlyingType());
443 // If we have unknown bit set, return an empty string to signal errors.
444 int64_t validBits = enumDef.getValueAsInt(FieldName: "validBits");
445 os << formatv(Fmt: " assert({0}u == ({0}u | val) && \"invalid bits set in bit "
446 "enum\");\n",
447 Vals&: validBits);
448 if (allBitsUnsetCase) {
449 os << " // Special case for all bits unset.\n";
450 os << formatv(Fmt: " if (val == 0) return \"{0}\";\n\n",
451 Vals: allBitsUnsetCase->getStr());
452 }
453 os << " ::llvm::SmallVector<::llvm::StringRef, 2> strs;\n";
454
455 // Add case string if the value has all case bits, and remove them to avoid
456 // printing again. Used only for groups, when printBitEnumPrimaryGroups is 1.
457 const char *const formatCompareRemove = R"(
458 if ({0}u == ({0}u & val)) {{
459 strs.push_back("{1}");
460 val &= ~static_cast<{2}>({0});
461 }
462)";
463 // Add case string if the value has all case bits. Used for individual bit
464 // cases, and for groups when printBitEnumPrimaryGroups is 0.
465 const char *const formatCompare = R"(
466 if ({0}u == ({0}u & val))
467 strs.push_back("{1}");
468)";
469 // Optionally elide bits that are members of groups that will also be printed
470 // for more concise output.
471 if (enumInfo.printBitEnumPrimaryGroups()) {
472 os << " // Print bit enum groups before individual bits\n";
473 // Emit comparisons for group bit cases in reverse tablegen declaration
474 // order, removing bits for groups with all bits present.
475 for (const auto &enumerant : llvm::reverse(C&: enumerants)) {
476 if ((enumerant.getValue() != 0) &&
477 (enumerant.getDef().isSubClassOf(Name: "BitEnumCaseGroup") ||
478 enumerant.getDef().isSubClassOf(Name: "BitEnumAttrCaseGroup"))) {
479 os << formatv(Fmt: formatCompareRemove, Vals: enumerant.getValue(),
480 Vals: enumerant.getStr(), Vals: enumInfo.getUnderlyingType());
481 }
482 }
483 // Emit comparisons for individual bit cases in tablegen declaration order.
484 for (const auto &enumerant : enumerants) {
485 if ((enumerant.getValue() != 0) &&
486 (enumerant.getDef().isSubClassOf(Name: "BitEnumCaseBit") ||
487 enumerant.getDef().isSubClassOf(Name: "BitEnumAttrCaseBit")))
488 os << formatv(Fmt: formatCompare, Vals: enumerant.getValue(), Vals: enumerant.getStr());
489 }
490 } else {
491 // Emit comparisons for ALL nonzero cases (individual bits and groups) in
492 // tablegen declaration order.
493 for (const auto &enumerant : enumerants) {
494 if (enumerant.getValue() != 0)
495 os << formatv(Fmt: formatCompare, Vals: enumerant.getValue(), Vals: enumerant.getStr());
496 }
497 }
498 os << formatv(Fmt: " return ::llvm::join(strs, \"{0}\");\n", Vals&: separator);
499
500 os << "}\n\n";
501}
502
503static void emitStrToSymFnForIntEnum(const Record &enumDef, raw_ostream &os) {
504 EnumInfo enumInfo(enumDef);
505 StringRef enumName = enumInfo.getEnumClassName();
506 StringRef strToSymFnName = enumInfo.getStringToSymbolFnName();
507 auto enumerants = enumInfo.getAllCases();
508
509 os << formatv(Fmt: "::std::optional<{0}> {1}(::llvm::StringRef str) {{\n",
510 Vals&: enumName, Vals&: strToSymFnName);
511 os << formatv(Fmt: " return ::llvm::StringSwitch<::std::optional<{0}>>(str)\n",
512 Vals&: enumName);
513 for (const auto &enumerant : enumerants) {
514 auto symbol = enumerant.getSymbol();
515 auto str = enumerant.getStr();
516 os << formatv(Fmt: " .Case(\"{1}\", {0}::{2})\n", Vals&: enumName, Vals&: str,
517 Vals: makeIdentifier(str: symbol));
518 }
519 os << " .Default(::std::nullopt);\n";
520 os << "}\n";
521}
522
523static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
524 EnumInfo enumInfo(enumDef);
525 StringRef enumName = enumInfo.getEnumClassName();
526 std::string underlyingType = std::string(enumInfo.getUnderlyingType());
527 StringRef strToSymFnName = enumInfo.getStringToSymbolFnName();
528 StringRef separator = enumDef.getValueAsString(FieldName: "separator");
529 StringRef separatorTrimmed = separator.trim();
530 auto enumerants = enumInfo.getAllCases();
531 auto allBitsUnsetCase = getAllBitsUnsetCase(cases: enumerants);
532
533 os << formatv(Fmt: "::std::optional<{0}> {1}(::llvm::StringRef str) {{\n",
534 Vals&: enumName, Vals&: strToSymFnName);
535
536 if (allBitsUnsetCase) {
537 os << " // Special case for all bits unset.\n";
538 StringRef caseSymbol = allBitsUnsetCase->getSymbol();
539 os << formatv(Fmt: " if (str == \"{1}\") return {0}::{2};\n\n", Vals&: enumName,
540 Vals: allBitsUnsetCase->getStr(), Vals: makeIdentifier(str: caseSymbol));
541 }
542
543 // Split the string to get symbols for all the bits.
544 os << " ::llvm::SmallVector<::llvm::StringRef, 2> symbols;\n";
545 // Remove whitespace from the separator string when parsing.
546 os << formatv(Fmt: " str.split(symbols, \"{0}\");\n\n", Vals&: separatorTrimmed);
547
548 os << formatv(Fmt: " {0} val = 0;\n", Vals&: underlyingType);
549 os << " for (auto symbol : symbols) {\n";
550
551 // Convert each symbol to the bit ordinal and set the corresponding bit.
552 os << formatv(Fmt: " auto bit = "
553 "llvm::StringSwitch<::std::optional<{0}>>(symbol.trim())\n",
554 Vals&: underlyingType);
555 for (const auto &enumerant : enumerants) {
556 // Skip the special enumerant for None.
557 if (auto val = enumerant.getValue())
558 os.indent(NumSpaces: 6) << formatv(Fmt: ".Case(\"{0}\", {1})\n", Vals: enumerant.getStr(), Vals&: val);
559 }
560 os.indent(NumSpaces: 6) << ".Default(::std::nullopt);\n";
561
562 os << " if (bit) { val |= *bit; } else { return ::std::nullopt; }\n";
563 os << " }\n";
564
565 os << formatv(Fmt: " return static_cast<{0}>(val);\n", Vals&: enumName);
566 os << "}\n\n";
567}
568
569static void emitUnderlyingToSymFnForIntEnum(const Record &enumDef,
570 raw_ostream &os) {
571 EnumInfo enumInfo(enumDef);
572 StringRef enumName = enumInfo.getEnumClassName();
573 std::string underlyingType = std::string(enumInfo.getUnderlyingType());
574 StringRef underlyingToSymFnName = enumInfo.getUnderlyingToSymbolFnName();
575 auto enumerants = enumInfo.getAllCases();
576
577 // Avoid generating the underlying value to symbol conversion function if
578 // there is an enumerant without explicit value.
579 if (llvm::any_of(Range&: enumerants,
580 P: [](EnumCase enumerant) { return enumerant.getValue() < 0; }))
581 return;
582
583 os << formatv(Fmt: "::std::optional<{0}> {1}({2} value) {{\n", Vals&: enumName,
584 Vals&: underlyingToSymFnName,
585 Vals: underlyingType.empty() ? std::string("unsigned")
586 : underlyingType)
587 << " switch (value) {\n";
588 for (const auto &enumerant : enumerants) {
589 auto symbol = enumerant.getSymbol();
590 auto value = enumerant.getValue();
591 os << formatv(Fmt: " case {0}: return {1}::{2};\n", Vals&: value, Vals&: enumName,
592 Vals: makeIdentifier(str: symbol));
593 }
594 os << " default: return ::std::nullopt;\n"
595 << " }\n"
596 << "}\n\n";
597}
598
599static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
600 EnumInfo enumInfo(enumDef);
601 StringRef enumName = enumInfo.getEnumClassName();
602 StringRef attrClassName = enumInfo.getSpecializedAttrClassName();
603 const Record *baseAttrDef = enumInfo.getBaseAttrClass();
604 Attribute baseAttr(baseAttrDef);
605
606 // Emit classof method
607
608 os << formatv(Fmt: "bool {0}::classof(::mlir::Attribute attr) {{\n",
609 Vals&: attrClassName);
610
611 mlir::tblgen::Pred baseAttrPred = baseAttr.getPredicate();
612 if (baseAttrPred.isNull())
613 PrintFatalError(Msg: "ERROR: baseAttrClass for EnumAttr has no Predicate\n");
614
615 std::string condition = baseAttrPred.getCondition();
616 FmtContext verifyCtx;
617 verifyCtx.withSelf(subst: "attr");
618 os << tgfmt(fmt: " return $0;\n", /*ctx=*/nullptr, vals: tgfmt(fmt: condition, ctx: &verifyCtx));
619
620 os << "}\n";
621
622 // Emit get method
623
624 os << formatv(Fmt: "{0} {0}::get(::mlir::MLIRContext *context, {1} val) {{\n",
625 Vals&: attrClassName, Vals&: enumName);
626
627 StringRef underlyingType = enumInfo.getUnderlyingType();
628
629 // Assuming that it is IntegerAttr constraint
630 int64_t bitwidth = 64;
631 if (baseAttrDef->getValue(Name: "valueType")) {
632 auto *valueTypeDef = baseAttrDef->getValueAsDef(FieldName: "valueType");
633 if (valueTypeDef->getValue(Name: "bitwidth"))
634 bitwidth = valueTypeDef->getValueAsInt(FieldName: "bitwidth");
635 }
636
637 os << formatv(Fmt: " ::mlir::IntegerType intType = "
638 "::mlir::IntegerType::get(context, {0});\n",
639 Vals&: bitwidth);
640 os << formatv(Fmt: " ::mlir::IntegerAttr baseAttr = "
641 "::mlir::IntegerAttr::get(intType, static_cast<{0}>(val));\n",
642 Vals&: underlyingType);
643 os << formatv(Fmt: " return ::llvm::cast<{0}>(baseAttr);\n", Vals&: attrClassName);
644
645 os << "}\n";
646
647 // Emit getValue method
648
649 os << formatv(Fmt: "{0} {1}::getValue() const {{\n", Vals&: enumName, Vals&: attrClassName);
650
651 os << formatv(Fmt: " return static_cast<{0}>(::mlir::IntegerAttr::getInt());\n",
652 Vals&: enumName);
653
654 os << "}\n";
655}
656
657static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
658 raw_ostream &os) {
659 EnumInfo enumInfo(enumDef);
660 StringRef enumName = enumInfo.getEnumClassName();
661 std::string underlyingType = std::string(enumInfo.getUnderlyingType());
662 StringRef underlyingToSymFnName = enumInfo.getUnderlyingToSymbolFnName();
663 auto enumerants = enumInfo.getAllCases();
664 auto allBitsUnsetCase = getAllBitsUnsetCase(cases: enumerants);
665
666 os << formatv(Fmt: "::std::optional<{0}> {1}({2} value) {{\n", Vals&: enumName,
667 Vals&: underlyingToSymFnName, Vals&: underlyingType);
668 if (allBitsUnsetCase) {
669 os << " // Special case for all bits unset.\n";
670 os << formatv(Fmt: " if (value == 0) return {0}::{1};\n\n", Vals&: enumName,
671 Vals: makeIdentifier(str: allBitsUnsetCase->getSymbol()));
672 }
673 int64_t validBits = enumDef.getValueAsInt(FieldName: "validBits");
674 os << formatv(Fmt: " if (value & ~static_cast<{0}>({1}u)) return std::nullopt;\n",
675 Vals&: underlyingType, Vals&: validBits);
676 os << formatv(Fmt: " return static_cast<{0}>(value);\n", Vals&: enumName);
677 os << "}\n";
678}
679
680static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
681 EnumInfo enumInfo(enumDef);
682 StringRef enumName = enumInfo.getEnumClassName();
683 StringRef cppNamespace = enumInfo.getCppNamespace();
684 std::string underlyingType = std::string(enumInfo.getUnderlyingType());
685 StringRef description = enumInfo.getSummary();
686 StringRef strToSymFnName = enumInfo.getStringToSymbolFnName();
687 StringRef symToStrFnName = enumInfo.getSymbolToStringFnName();
688 StringRef symToStrFnRetType = enumInfo.getSymbolToStringFnRetType();
689 StringRef underlyingToSymFnName = enumInfo.getUnderlyingToSymbolFnName();
690 auto enumerants = enumInfo.getAllCases();
691
692 SmallVector<StringRef, 2> namespaces;
693 llvm::SplitString(Source: cppNamespace, OutFragments&: namespaces, Delimiters: "::");
694
695 for (auto ns : namespaces)
696 os << "namespace " << ns << " {\n";
697
698 // Emit the enum class definition
699 emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os);
700
701 // Emit conversion function declarations
702 if (llvm::all_of(Range&: enumerants, P: [](EnumCase enumerant) {
703 return enumerant.getValue() >= 0;
704 })) {
705 os << formatv(
706 Fmt: "::std::optional<{0}> {1}({2});\n", Vals&: enumName, Vals&: underlyingToSymFnName,
707 Vals: underlyingType.empty() ? std::string("unsigned") : underlyingType);
708 }
709 os << formatv(Fmt: "{2} {1}({0});\n", Vals&: enumName, Vals&: symToStrFnName, Vals&: symToStrFnRetType);
710 os << formatv(Fmt: "::std::optional<{0}> {1}(::llvm::StringRef);\n", Vals&: enumName,
711 Vals&: strToSymFnName);
712
713 if (enumInfo.isBitEnum()) {
714 emitOperators(enumDef, os);
715 } else {
716 emitMaxValueFn(enumDef, os);
717 }
718
719 // Generate a generic `stringifyEnum` function that forwards to the method
720 // specified by the user.
721 const char *const stringifyEnumStr = R"(
722inline {0} stringifyEnum({1} enumValue) {{
723 return {2}(enumValue);
724}
725)";
726 os << formatv(Fmt: stringifyEnumStr, Vals&: symToStrFnRetType, Vals&: enumName, Vals&: symToStrFnName);
727
728 // Generate a generic `symbolizeEnum` function that forwards to the method
729 // specified by the user.
730 const char *const symbolizeEnumStr = R"(
731template <typename EnumType>
732::std::optional<EnumType> symbolizeEnum(::llvm::StringRef);
733
734template <>
735inline ::std::optional<{0}> symbolizeEnum<{0}>(::llvm::StringRef str) {
736 return {1}(str);
737}
738)";
739 os << formatv(Fmt: symbolizeEnumStr, Vals&: enumName, Vals&: strToSymFnName);
740
741 const char *const attrClassDecl = R"(
742class {1} : public ::mlir::{2} {
743public:
744 using ValueType = {0};
745 using ::mlir::{2}::{2};
746 static bool classof(::mlir::Attribute attr);
747 static {1} get(::mlir::MLIRContext *context, {0} val);
748 {0} getValue() const;
749};
750)";
751 if (enumInfo.genSpecializedAttr()) {
752 StringRef attrClassName = enumInfo.getSpecializedAttrClassName();
753 StringRef baseAttrClassName = "IntegerAttr";
754 os << formatv(Fmt: attrClassDecl, Vals&: enumName, Vals&: attrClassName, Vals&: baseAttrClassName);
755 }
756
757 for (auto ns : llvm::reverse(C&: namespaces))
758 os << "} // namespace " << ns << "\n";
759
760 // Generate a generic parser and printer for the enum.
761 std::string qualName =
762 std::string(formatv(Fmt: "{0}::{1}", Vals&: cppNamespace, Vals&: enumName));
763 emitParserPrinter(enumInfo, qualName, cppNamespace, os);
764
765 // Emit DenseMapInfo for this enum class
766 emitDenseMapInfo(qualName, underlyingType, cppNamespace, os);
767}
768
769static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) {
770 llvm::emitSourceFileHeader(Desc: "Enum Utility Declarations", OS&: os, Record: records);
771
772 for (const Record *def :
773 records.getAllDerivedDefinitionsIfDefined(ClassName: "EnumInfo"))
774 emitEnumDecl(enumDef: *def, os);
775
776 return false;
777}
778
779static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
780 EnumInfo enumInfo(enumDef);
781 StringRef cppNamespace = enumInfo.getCppNamespace();
782
783 SmallVector<StringRef, 2> namespaces;
784 llvm::SplitString(Source: cppNamespace, OutFragments&: namespaces, Delimiters: "::");
785
786 for (auto ns : namespaces)
787 os << "namespace " << ns << " {\n";
788
789 if (enumInfo.isBitEnum()) {
790 emitSymToStrFnForBitEnum(enumDef, os);
791 emitStrToSymFnForBitEnum(enumDef, os);
792 emitUnderlyingToSymFnForBitEnum(enumDef, os);
793 } else {
794 emitSymToStrFnForIntEnum(enumDef, os);
795 emitStrToSymFnForIntEnum(enumDef, os);
796 emitUnderlyingToSymFnForIntEnum(enumDef, os);
797 }
798
799 if (enumInfo.genSpecializedAttr())
800 emitSpecializedAttrDef(enumDef, os);
801
802 for (auto ns : llvm::reverse(C&: namespaces))
803 os << "} // namespace " << ns << "\n";
804 os << "\n";
805}
806
807static bool emitEnumDefs(const RecordKeeper &records, raw_ostream &os) {
808 llvm::emitSourceFileHeader(Desc: "Enum Utility Definitions", OS&: os, Record: records);
809
810 for (const Record *def :
811 records.getAllDerivedDefinitionsIfDefined(ClassName: "EnumInfo"))
812 emitEnumDef(enumDef: *def, os);
813
814 return false;
815}
816
817// Registers the enum utility generator to mlir-tblgen.
818static mlir::GenRegistration
819 genEnumDecls("gen-enum-decls", "Generate enum utility declarations",
820 [](const RecordKeeper &records, raw_ostream &os) {
821 return emitEnumDecls(records, os);
822 });
823
824// Registers the enum utility generator to mlir-tblgen.
825static mlir::GenRegistration
826 genEnumDefs("gen-enum-defs", "Generate enum utility definitions",
827 [](const RecordKeeper &records, raw_ostream &os) {
828 return emitEnumDefs(records, os);
829 });
830

source code of mlir/tools/mlir-tblgen/EnumsGen.cpp