1 | //===- EnumsGen.cpp - MLIR enum utility generator -------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // EnumsGen generates common utility functions for enums. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "FormatGen.h" |
14 | #include "mlir/TableGen/Attribute.h" |
15 | #include "mlir/TableGen/Format.h" |
16 | #include "mlir/TableGen/GenInfo.h" |
17 | #include "llvm/ADT/BitVector.h" |
18 | #include "llvm/ADT/SmallVector.h" |
19 | #include "llvm/ADT/StringExtras.h" |
20 | #include "llvm/Support/FormatVariadic.h" |
21 | #include "llvm/Support/raw_ostream.h" |
22 | #include "llvm/TableGen/Error.h" |
23 | #include "llvm/TableGen/Record.h" |
24 | #include "llvm/TableGen/TableGenBackend.h" |
25 | |
26 | using llvm::formatv; |
27 | using llvm::isDigit; |
28 | using llvm::PrintFatalError; |
29 | using llvm::raw_ostream; |
30 | using llvm::Record; |
31 | using llvm::RecordKeeper; |
32 | using llvm::StringRef; |
33 | using mlir::tblgen::Attribute; |
34 | using mlir::tblgen::EnumAttr; |
35 | using mlir::tblgen::EnumAttrCase; |
36 | using mlir::tblgen::FmtContext; |
37 | using mlir::tblgen::tgfmt; |
38 | |
39 | static std::string makeIdentifier(StringRef str) { |
40 | if (!str.empty() && isDigit(C: static_cast<unsigned char>(str.front()))) { |
41 | std::string newStr = std::string("_" ) + str.str(); |
42 | return newStr; |
43 | } |
44 | return str.str(); |
45 | } |
46 | |
47 | static void emitEnumClass(const Record &enumDef, StringRef enumName, |
48 | StringRef underlyingType, StringRef description, |
49 | const std::vector<EnumAttrCase> &enumerants, |
50 | raw_ostream &os) { |
51 | os << "// " << description << "\n" ; |
52 | os << "enum class " << enumName; |
53 | |
54 | if (!underlyingType.empty()) |
55 | os << " : " << underlyingType; |
56 | os << " {\n" ; |
57 | |
58 | for (const auto &enumerant : enumerants) { |
59 | auto symbol = makeIdentifier(str: enumerant.getSymbol()); |
60 | auto value = enumerant.getValue(); |
61 | if (value >= 0) { |
62 | os << formatv(Fmt: " {0} = {1},\n" , Vals&: symbol, Vals&: value); |
63 | } else { |
64 | os << formatv(Fmt: " {0},\n" , Vals&: symbol); |
65 | } |
66 | } |
67 | os << "};\n\n" ; |
68 | } |
69 | |
70 | static void emitParserPrinter(const EnumAttr &enumAttr, StringRef qualName, |
71 | StringRef cppNamespace, raw_ostream &os) { |
72 | if (enumAttr.getUnderlyingType().empty() || |
73 | enumAttr.getConstBuilderTemplate().empty()) |
74 | return; |
75 | auto cases = enumAttr.getAllCases(); |
76 | |
77 | // Check which cases shouldn't be printed using a keyword. |
78 | llvm::BitVector nonKeywordCases(cases.size()); |
79 | for (auto [index, caseVal] : llvm::enumerate(First&: cases)) |
80 | if (!mlir::tblgen::canFormatStringAsKeyword(value: caseVal.getStr())) |
81 | nonKeywordCases.set(index); |
82 | |
83 | // Generate the parser and the start of the printer for the enum. |
84 | const char *parsedAndPrinterStart = R"( |
85 | namespace mlir { |
86 | template <typename T, typename> |
87 | struct FieldParser; |
88 | |
89 | template<> |
90 | struct FieldParser<{0}, {0}> {{ |
91 | template <typename ParserT> |
92 | static FailureOr<{0}> parse(ParserT &parser) {{ |
93 | // Parse the keyword/string containing the enum. |
94 | std::string enumKeyword; |
95 | auto loc = parser.getCurrentLocation(); |
96 | if (failed(parser.parseOptionalKeywordOrString(&enumKeyword))) |
97 | return parser.emitError(loc, "expected keyword for {2}"); |
98 | |
99 | // Symbolize the keyword. |
100 | if (::std::optional<{0}> attr = {1}::symbolizeEnum<{0}>(enumKeyword)) |
101 | return *attr; |
102 | return parser.emitError(loc, "invalid {2} specification: ") << enumKeyword; |
103 | } |
104 | }; |
105 | } // namespace mlir |
106 | |
107 | namespace llvm { |
108 | inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{ |
109 | auto valueStr = stringifyEnum(value); |
110 | )" ; |
111 | os << formatv(Fmt: parsedAndPrinterStart, Vals&: qualName, Vals&: cppNamespace, |
112 | Vals: enumAttr.getSummary()); |
113 | |
114 | // If all cases require a string, always wrap. |
115 | if (nonKeywordCases.all()) { |
116 | os << " return p << '\"' << valueStr << '\"';\n" |
117 | "}\n" |
118 | "} // namespace llvm\n" ; |
119 | return; |
120 | } |
121 | |
122 | // If there are any cases that can't be used with a keyword, switch on the |
123 | // case value to determine when to print in the string form. |
124 | if (nonKeywordCases.any()) { |
125 | os << " switch (value) {\n" ; |
126 | for (auto it : llvm::enumerate(First&: cases)) { |
127 | if (nonKeywordCases.test(Idx: it.index())) |
128 | continue; |
129 | StringRef symbol = it.value().getSymbol(); |
130 | os << llvm::formatv(Fmt: " case {0}::{1}:\n" , Vals&: qualName, |
131 | Vals: makeIdentifier(str: symbol)); |
132 | } |
133 | os << " break;\n" |
134 | " default:\n" |
135 | " return p << '\"' << valueStr << '\"';\n" |
136 | " }\n" ; |
137 | |
138 | // If this is a bit enum, conservatively print the string form if the value |
139 | // is not a power of two (i.e. not a single bit case) and not a known case. |
140 | } else if (enumAttr.isBitEnum()) { |
141 | // Process the known multi-bit cases that use valid keywords. |
142 | llvm::SmallVector<EnumAttrCase *> validMultiBitCases; |
143 | for (auto [index, caseVal] : llvm::enumerate(First&: cases)) { |
144 | uint64_t value = caseVal.getValue(); |
145 | if (value && !llvm::has_single_bit(Value: value) && !nonKeywordCases.test(Idx: index)) |
146 | validMultiBitCases.push_back(Elt: &caseVal); |
147 | } |
148 | if (!validMultiBitCases.empty()) { |
149 | os << " switch (value) {\n" ; |
150 | for (EnumAttrCase *caseVal : validMultiBitCases) { |
151 | StringRef symbol = caseVal->getSymbol(); |
152 | os << llvm::formatv(Fmt: " case {0}::{1}:\n" , Vals&: qualName, |
153 | Vals: llvm::isDigit(C: symbol.front()) ? ("_" + symbol) |
154 | : symbol); |
155 | } |
156 | os << " return p << valueStr;\n" |
157 | " default:\n" |
158 | " break;\n" |
159 | " }\n" ; |
160 | } |
161 | |
162 | // All other multi-bit cases should be printed as strings. |
163 | os << formatv(Fmt: " auto underlyingValue = " |
164 | "static_cast<std::make_unsigned_t<{0}>>(value);\n" , |
165 | Vals&: qualName); |
166 | os << " if (underlyingValue && !llvm::has_single_bit(underlyingValue))\n" |
167 | " return p << '\"' << valueStr << '\"';\n" ; |
168 | } |
169 | os << " return p << valueStr;\n" |
170 | "}\n" |
171 | "} // namespace llvm\n" ; |
172 | } |
173 | |
174 | static void emitDenseMapInfo(StringRef qualName, std::string underlyingType, |
175 | StringRef cppNamespace, raw_ostream &os) { |
176 | if (underlyingType.empty()) |
177 | underlyingType = |
178 | std::string(formatv(Fmt: "std::underlying_type_t<{0}>" , Vals&: qualName)); |
179 | |
180 | const char *const mapInfo = R"( |
181 | namespace llvm { |
182 | template<> struct DenseMapInfo<{0}> {{ |
183 | using StorageInfo = ::llvm::DenseMapInfo<{1}>; |
184 | |
185 | static inline {0} getEmptyKey() {{ |
186 | return static_cast<{0}>(StorageInfo::getEmptyKey()); |
187 | } |
188 | |
189 | static inline {0} getTombstoneKey() {{ |
190 | return static_cast<{0}>(StorageInfo::getTombstoneKey()); |
191 | } |
192 | |
193 | static unsigned getHashValue(const {0} &val) {{ |
194 | return StorageInfo::getHashValue(static_cast<{1}>(val)); |
195 | } |
196 | |
197 | static bool isEqual(const {0} &lhs, const {0} &rhs) {{ |
198 | return lhs == rhs; |
199 | } |
200 | }; |
201 | })" ; |
202 | os << formatv(Fmt: mapInfo, Vals&: qualName, Vals&: underlyingType); |
203 | os << "\n\n" ; |
204 | } |
205 | |
206 | static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) { |
207 | EnumAttr enumAttr(enumDef); |
208 | StringRef maxEnumValFnName = enumAttr.getMaxEnumValFnName(); |
209 | auto enumerants = enumAttr.getAllCases(); |
210 | |
211 | unsigned maxEnumVal = 0; |
212 | for (const auto &enumerant : enumerants) { |
213 | int64_t value = enumerant.getValue(); |
214 | // Avoid generating the max value function if there is an enumerant without |
215 | // explicit value. |
216 | if (value < 0) |
217 | return; |
218 | |
219 | maxEnumVal = std::max(a: maxEnumVal, b: static_cast<unsigned>(value)); |
220 | } |
221 | |
222 | // Emit the function to return the max enum value |
223 | os << formatv(Fmt: "inline constexpr unsigned {0}() {{\n" , Vals&: maxEnumValFnName); |
224 | os << formatv(Fmt: " return {0};\n" , Vals&: maxEnumVal); |
225 | os << "}\n\n" ; |
226 | } |
227 | |
228 | // Returns the EnumAttrCase whose value is zero if exists; returns std::nullopt |
229 | // otherwise. |
230 | static std::optional<EnumAttrCase> |
231 | getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases) { |
232 | for (auto attrCase : cases) { |
233 | if (attrCase.getValue() == 0) |
234 | return attrCase; |
235 | } |
236 | return std::nullopt; |
237 | } |
238 | |
239 | // Emits the following inline function for bit enums: |
240 | // |
241 | // inline constexpr <enum-type> operator|(<enum-type> a, <enum-type> b); |
242 | // inline constexpr <enum-type> operator&(<enum-type> a, <enum-type> b); |
243 | // inline constexpr <enum-type> operator^(<enum-type> a, <enum-type> b); |
244 | // inline constexpr <enum-type> operator~(<enum-type> bits); |
245 | // inline constexpr bool bitEnumContainsAll(<enum-type> bits, <enum-type> bit); |
246 | // inline constexpr bool bitEnumContainsAny(<enum-type> bits, <enum-type> bit); |
247 | // inline constexpr <enum-type> bitEnumClear(<enum-type> bits, <enum-type> bit); |
248 | // inline constexpr <enum-type> bitEnumSet(<enum-type> bits, <enum-type> bit, |
249 | // bool value=true); |
250 | static void emitOperators(const Record &enumDef, raw_ostream &os) { |
251 | EnumAttr enumAttr(enumDef); |
252 | StringRef enumName = enumAttr.getEnumClassName(); |
253 | std::string underlyingType = std::string(enumAttr.getUnderlyingType()); |
254 | int64_t validBits = enumDef.getValueAsInt(FieldName: "validBits" ); |
255 | const char *const operators = R"( |
256 | inline constexpr {0} operator|({0} a, {0} b) {{ |
257 | return static_cast<{0}>(static_cast<{1}>(a) | static_cast<{1}>(b)); |
258 | } |
259 | inline constexpr {0} operator&({0} a, {0} b) {{ |
260 | return static_cast<{0}>(static_cast<{1}>(a) & static_cast<{1}>(b)); |
261 | } |
262 | inline constexpr {0} operator^({0} a, {0} b) {{ |
263 | return static_cast<{0}>(static_cast<{1}>(a) ^ static_cast<{1}>(b)); |
264 | } |
265 | inline constexpr {0} operator~({0} bits) {{ |
266 | // Ensure only bits that can be present in the enum are set |
267 | return static_cast<{0}>(~static_cast<{1}>(bits) & static_cast<{1}>({2}u)); |
268 | } |
269 | inline constexpr bool bitEnumContainsAll({0} bits, {0} bit) {{ |
270 | return (bits & bit) == bit; |
271 | } |
272 | inline constexpr bool bitEnumContainsAny({0} bits, {0} bit) {{ |
273 | return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0; |
274 | } |
275 | inline constexpr {0} bitEnumClear({0} bits, {0} bit) {{ |
276 | return bits & ~bit; |
277 | } |
278 | inline constexpr {0} bitEnumSet({0} bits, {0} bit, /*optional*/bool value=true) {{ |
279 | return value ? (bits | bit) : bitEnumClear(bits, bit); |
280 | } |
281 | )" ; |
282 | os << formatv(Fmt: operators, Vals&: enumName, Vals&: underlyingType, Vals&: validBits); |
283 | } |
284 | |
285 | static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) { |
286 | EnumAttr enumAttr(enumDef); |
287 | StringRef enumName = enumAttr.getEnumClassName(); |
288 | StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); |
289 | StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType(); |
290 | auto enumerants = enumAttr.getAllCases(); |
291 | |
292 | os << formatv(Fmt: "{2} {1}({0} val) {{\n" , Vals&: enumName, Vals&: symToStrFnName, |
293 | Vals&: symToStrFnRetType); |
294 | os << " switch (val) {\n" ; |
295 | for (const auto &enumerant : enumerants) { |
296 | auto symbol = enumerant.getSymbol(); |
297 | auto str = enumerant.getStr(); |
298 | os << formatv(Fmt: " case {0}::{1}: return \"{2}\";\n" , Vals&: enumName, |
299 | Vals: makeIdentifier(str: symbol), Vals&: str); |
300 | } |
301 | os << " }\n" ; |
302 | os << " return \"\";\n" ; |
303 | os << "}\n\n" ; |
304 | } |
305 | |
306 | static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) { |
307 | EnumAttr enumAttr(enumDef); |
308 | StringRef enumName = enumAttr.getEnumClassName(); |
309 | StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); |
310 | StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType(); |
311 | StringRef separator = enumDef.getValueAsString(FieldName: "separator" ); |
312 | auto enumerants = enumAttr.getAllCases(); |
313 | auto allBitsUnsetCase = getAllBitsUnsetCase(cases: enumerants); |
314 | |
315 | os << formatv(Fmt: "{2} {1}({0} symbol) {{\n" , Vals&: enumName, Vals&: symToStrFnName, |
316 | Vals&: symToStrFnRetType); |
317 | |
318 | os << formatv(Fmt: " auto val = static_cast<{0}>(symbol);\n" , |
319 | Vals: enumAttr.getUnderlyingType()); |
320 | // If we have unknown bit set, return an empty string to signal errors. |
321 | int64_t validBits = enumDef.getValueAsInt(FieldName: "validBits" ); |
322 | os << formatv(Fmt: " assert({0}u == ({0}u | val) && \"invalid bits set in bit " |
323 | "enum\");\n" , |
324 | Vals&: validBits); |
325 | if (allBitsUnsetCase) { |
326 | os << " // Special case for all bits unset.\n" ; |
327 | os << formatv(Fmt: " if (val == 0) return \"{0}\";\n\n" , |
328 | Vals: allBitsUnsetCase->getStr()); |
329 | } |
330 | os << " ::llvm::SmallVector<::llvm::StringRef, 2> strs;\n" ; |
331 | |
332 | // Add case string if the value has all case bits, and remove them to avoid |
333 | // printing again. Used only for groups, when printBitEnumPrimaryGroups is 1. |
334 | const char *const formatCompareRemove = R"( |
335 | if ({0}u == ({0}u & val)) {{ |
336 | strs.push_back("{1}"); |
337 | val &= ~static_cast<{2}>({0}); |
338 | } |
339 | )" ; |
340 | // Add case string if the value has all case bits. Used for individual bit |
341 | // cases, and for groups when printBitEnumPrimaryGroups is 0. |
342 | const char *const formatCompare = R"( |
343 | if ({0}u == ({0}u & val)) |
344 | strs.push_back("{1}"); |
345 | )" ; |
346 | // Optionally elide bits that are members of groups that will also be printed |
347 | // for more concise output. |
348 | if (enumAttr.printBitEnumPrimaryGroups()) { |
349 | os << " // Print bit enum groups before individual bits\n" ; |
350 | // Emit comparisons for group bit cases in reverse tablegen declaration |
351 | // order, removing bits for groups with all bits present. |
352 | for (const auto &enumerant : llvm::reverse(C&: enumerants)) { |
353 | if ((enumerant.getValue() != 0) && |
354 | enumerant.getDef().isSubClassOf(Name: "BitEnumAttrCaseGroup" )) { |
355 | os << formatv(Fmt: formatCompareRemove, Vals: enumerant.getValue(), |
356 | Vals: enumerant.getStr(), Vals: enumAttr.getUnderlyingType()); |
357 | } |
358 | } |
359 | // Emit comparisons for individual bit cases in tablegen declaration order. |
360 | for (const auto &enumerant : enumerants) { |
361 | if ((enumerant.getValue() != 0) && |
362 | enumerant.getDef().isSubClassOf(Name: "BitEnumAttrCaseBit" )) |
363 | os << formatv(Fmt: formatCompare, Vals: enumerant.getValue(), Vals: enumerant.getStr()); |
364 | } |
365 | } else { |
366 | // Emit comparisons for ALL nonzero cases (individual bits and groups) in |
367 | // tablegen declaration order. |
368 | for (const auto &enumerant : enumerants) { |
369 | if (enumerant.getValue() != 0) |
370 | os << formatv(Fmt: formatCompare, Vals: enumerant.getValue(), Vals: enumerant.getStr()); |
371 | } |
372 | } |
373 | os << formatv(Fmt: " return ::llvm::join(strs, \"{0}\");\n" , Vals&: separator); |
374 | |
375 | os << "}\n\n" ; |
376 | } |
377 | |
378 | static void emitStrToSymFnForIntEnum(const Record &enumDef, raw_ostream &os) { |
379 | EnumAttr enumAttr(enumDef); |
380 | StringRef enumName = enumAttr.getEnumClassName(); |
381 | StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); |
382 | auto enumerants = enumAttr.getAllCases(); |
383 | |
384 | os << formatv(Fmt: "::std::optional<{0}> {1}(::llvm::StringRef str) {{\n" , |
385 | Vals&: enumName, Vals&: strToSymFnName); |
386 | os << formatv(Fmt: " return ::llvm::StringSwitch<::std::optional<{0}>>(str)\n" , |
387 | Vals&: enumName); |
388 | for (const auto &enumerant : enumerants) { |
389 | auto symbol = enumerant.getSymbol(); |
390 | auto str = enumerant.getStr(); |
391 | os << formatv(Fmt: " .Case(\"{1}\", {0}::{2})\n" , Vals&: enumName, Vals&: str, |
392 | Vals: makeIdentifier(str: symbol)); |
393 | } |
394 | os << " .Default(::std::nullopt);\n" ; |
395 | os << "}\n" ; |
396 | } |
397 | |
398 | static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) { |
399 | EnumAttr enumAttr(enumDef); |
400 | StringRef enumName = enumAttr.getEnumClassName(); |
401 | std::string underlyingType = std::string(enumAttr.getUnderlyingType()); |
402 | StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); |
403 | StringRef separator = enumDef.getValueAsString(FieldName: "separator" ); |
404 | StringRef separatorTrimmed = separator.trim(); |
405 | auto enumerants = enumAttr.getAllCases(); |
406 | auto allBitsUnsetCase = getAllBitsUnsetCase(cases: enumerants); |
407 | |
408 | os << formatv(Fmt: "::std::optional<{0}> {1}(::llvm::StringRef str) {{\n" , |
409 | Vals&: enumName, Vals&: strToSymFnName); |
410 | |
411 | if (allBitsUnsetCase) { |
412 | os << " // Special case for all bits unset.\n" ; |
413 | StringRef caseSymbol = allBitsUnsetCase->getSymbol(); |
414 | os << formatv(Fmt: " if (str == \"{1}\") return {0}::{2};\n\n" , Vals&: enumName, |
415 | Vals: allBitsUnsetCase->getStr(), Vals: makeIdentifier(str: caseSymbol)); |
416 | } |
417 | |
418 | // Split the string to get symbols for all the bits. |
419 | os << " ::llvm::SmallVector<::llvm::StringRef, 2> symbols;\n" ; |
420 | // Remove whitespace from the separator string when parsing. |
421 | os << formatv(Fmt: " str.split(symbols, \"{0}\");\n\n" , Vals&: separatorTrimmed); |
422 | |
423 | os << formatv(Fmt: " {0} val = 0;\n" , Vals&: underlyingType); |
424 | os << " for (auto symbol : symbols) {\n" ; |
425 | |
426 | // Convert each symbol to the bit ordinal and set the corresponding bit. |
427 | os << formatv(Fmt: " auto bit = " |
428 | "llvm::StringSwitch<::std::optional<{0}>>(symbol.trim())\n" , |
429 | Vals&: underlyingType); |
430 | for (const auto &enumerant : enumerants) { |
431 | // Skip the special enumerant for None. |
432 | if (auto val = enumerant.getValue()) |
433 | os.indent(NumSpaces: 6) << formatv(Fmt: ".Case(\"{0}\", {1})\n" , Vals: enumerant.getStr(), Vals&: val); |
434 | } |
435 | os.indent(NumSpaces: 6) << ".Default(::std::nullopt);\n" ; |
436 | |
437 | os << " if (bit) { val |= *bit; } else { return ::std::nullopt; }\n" ; |
438 | os << " }\n" ; |
439 | |
440 | os << formatv(Fmt: " return static_cast<{0}>(val);\n" , Vals&: enumName); |
441 | os << "}\n\n" ; |
442 | } |
443 | |
444 | static void emitUnderlyingToSymFnForIntEnum(const Record &enumDef, |
445 | raw_ostream &os) { |
446 | EnumAttr enumAttr(enumDef); |
447 | StringRef enumName = enumAttr.getEnumClassName(); |
448 | std::string underlyingType = std::string(enumAttr.getUnderlyingType()); |
449 | StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName(); |
450 | auto enumerants = enumAttr.getAllCases(); |
451 | |
452 | // Avoid generating the underlying value to symbol conversion function if |
453 | // there is an enumerant without explicit value. |
454 | if (llvm::any_of(Range&: enumerants, P: [](EnumAttrCase enumerant) { |
455 | return enumerant.getValue() < 0; |
456 | })) |
457 | return; |
458 | |
459 | os << formatv(Fmt: "::std::optional<{0}> {1}({2} value) {{\n" , Vals&: enumName, |
460 | Vals&: underlyingToSymFnName, |
461 | Vals: underlyingType.empty() ? std::string("unsigned" ) |
462 | : underlyingType) |
463 | << " switch (value) {\n" ; |
464 | for (const auto &enumerant : enumerants) { |
465 | auto symbol = enumerant.getSymbol(); |
466 | auto value = enumerant.getValue(); |
467 | os << formatv(Fmt: " case {0}: return {1}::{2};\n" , Vals&: value, Vals&: enumName, |
468 | Vals: makeIdentifier(str: symbol)); |
469 | } |
470 | os << " default: return ::std::nullopt;\n" |
471 | << " }\n" |
472 | << "}\n\n" ; |
473 | } |
474 | |
475 | static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) { |
476 | EnumAttr enumAttr(enumDef); |
477 | StringRef enumName = enumAttr.getEnumClassName(); |
478 | StringRef attrClassName = enumAttr.getSpecializedAttrClassName(); |
479 | llvm::Record *baseAttrDef = enumAttr.getBaseAttrClass(); |
480 | Attribute baseAttr(baseAttrDef); |
481 | |
482 | // Emit classof method |
483 | |
484 | os << formatv(Fmt: "bool {0}::classof(::mlir::Attribute attr) {{\n" , |
485 | Vals&: attrClassName); |
486 | |
487 | mlir::tblgen::Pred baseAttrPred = baseAttr.getPredicate(); |
488 | if (baseAttrPred.isNull()) |
489 | PrintFatalError(Msg: "ERROR: baseAttrClass for EnumAttr has no Predicate\n" ); |
490 | |
491 | std::string condition = baseAttrPred.getCondition(); |
492 | FmtContext verifyCtx; |
493 | verifyCtx.withSelf(subst: "attr" ); |
494 | os << tgfmt(fmt: " return $0;\n" , /*ctx=*/nullptr, vals: tgfmt(fmt: condition, ctx: &verifyCtx)); |
495 | |
496 | os << "}\n" ; |
497 | |
498 | // Emit get method |
499 | |
500 | os << formatv(Fmt: "{0} {0}::get(::mlir::MLIRContext *context, {1} val) {{\n" , |
501 | Vals&: attrClassName, Vals&: enumName); |
502 | |
503 | StringRef underlyingType = enumAttr.getUnderlyingType(); |
504 | |
505 | // Assuming that it is IntegerAttr constraint |
506 | int64_t bitwidth = 64; |
507 | if (baseAttrDef->getValue(Name: "valueType" )) { |
508 | auto *valueTypeDef = baseAttrDef->getValueAsDef(FieldName: "valueType" ); |
509 | if (valueTypeDef->getValue(Name: "bitwidth" )) |
510 | bitwidth = valueTypeDef->getValueAsInt(FieldName: "bitwidth" ); |
511 | } |
512 | |
513 | os << formatv(Fmt: " ::mlir::IntegerType intType = " |
514 | "::mlir::IntegerType::get(context, {0});\n" , |
515 | Vals&: bitwidth); |
516 | os << formatv(Fmt: " ::mlir::IntegerAttr baseAttr = " |
517 | "::mlir::IntegerAttr::get(intType, static_cast<{0}>(val));\n" , |
518 | Vals&: underlyingType); |
519 | os << formatv(Fmt: " return ::llvm::cast<{0}>(baseAttr);\n" , Vals&: attrClassName); |
520 | |
521 | os << "}\n" ; |
522 | |
523 | // Emit getValue method |
524 | |
525 | os << formatv(Fmt: "{0} {1}::getValue() const {{\n" , Vals&: enumName, Vals&: attrClassName); |
526 | |
527 | os << formatv(Fmt: " return static_cast<{0}>(::mlir::IntegerAttr::getInt());\n" , |
528 | Vals&: enumName); |
529 | |
530 | os << "}\n" ; |
531 | } |
532 | |
533 | static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef, |
534 | raw_ostream &os) { |
535 | EnumAttr enumAttr(enumDef); |
536 | StringRef enumName = enumAttr.getEnumClassName(); |
537 | std::string underlyingType = std::string(enumAttr.getUnderlyingType()); |
538 | StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName(); |
539 | auto enumerants = enumAttr.getAllCases(); |
540 | auto allBitsUnsetCase = getAllBitsUnsetCase(cases: enumerants); |
541 | |
542 | os << formatv(Fmt: "::std::optional<{0}> {1}({2} value) {{\n" , Vals&: enumName, |
543 | Vals&: underlyingToSymFnName, Vals&: underlyingType); |
544 | if (allBitsUnsetCase) { |
545 | os << " // Special case for all bits unset.\n" ; |
546 | os << formatv(Fmt: " if (value == 0) return {0}::{1};\n\n" , Vals&: enumName, |
547 | Vals: makeIdentifier(str: allBitsUnsetCase->getSymbol())); |
548 | } |
549 | int64_t validBits = enumDef.getValueAsInt(FieldName: "validBits" ); |
550 | os << formatv(Fmt: " if (value & ~static_cast<{0}>({1}u)) return std::nullopt;\n" , |
551 | Vals&: underlyingType, Vals&: validBits); |
552 | os << formatv(Fmt: " return static_cast<{0}>(value);\n" , Vals&: enumName); |
553 | os << "}\n" ; |
554 | } |
555 | |
556 | static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { |
557 | EnumAttr enumAttr(enumDef); |
558 | StringRef enumName = enumAttr.getEnumClassName(); |
559 | StringRef cppNamespace = enumAttr.getCppNamespace(); |
560 | std::string underlyingType = std::string(enumAttr.getUnderlyingType()); |
561 | StringRef description = enumAttr.getSummary(); |
562 | StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); |
563 | StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); |
564 | StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType(); |
565 | StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName(); |
566 | auto enumerants = enumAttr.getAllCases(); |
567 | |
568 | llvm::SmallVector<StringRef, 2> namespaces; |
569 | llvm::SplitString(Source: cppNamespace, OutFragments&: namespaces, Delimiters: "::" ); |
570 | |
571 | for (auto ns : namespaces) |
572 | os << "namespace " << ns << " {\n" ; |
573 | |
574 | // Emit the enum class definition |
575 | emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os); |
576 | |
577 | // Emit conversion function declarations |
578 | if (llvm::all_of(Range&: enumerants, P: [](EnumAttrCase enumerant) { |
579 | return enumerant.getValue() >= 0; |
580 | })) { |
581 | os << formatv( |
582 | Fmt: "::std::optional<{0}> {1}({2});\n" , Vals&: enumName, Vals&: underlyingToSymFnName, |
583 | Vals: underlyingType.empty() ? std::string("unsigned" ) : underlyingType); |
584 | } |
585 | os << formatv(Fmt: "{2} {1}({0});\n" , Vals&: enumName, Vals&: symToStrFnName, Vals&: symToStrFnRetType); |
586 | os << formatv(Fmt: "::std::optional<{0}> {1}(::llvm::StringRef);\n" , Vals&: enumName, |
587 | Vals&: strToSymFnName); |
588 | |
589 | if (enumAttr.isBitEnum()) { |
590 | emitOperators(enumDef, os); |
591 | } else { |
592 | emitMaxValueFn(enumDef, os); |
593 | } |
594 | |
595 | // Generate a generic `stringifyEnum` function that forwards to the method |
596 | // specified by the user. |
597 | const char *const stringifyEnumStr = R"( |
598 | inline {0} stringifyEnum({1} enumValue) {{ |
599 | return {2}(enumValue); |
600 | } |
601 | )" ; |
602 | os << formatv(Fmt: stringifyEnumStr, Vals&: symToStrFnRetType, Vals&: enumName, Vals&: symToStrFnName); |
603 | |
604 | // Generate a generic `symbolizeEnum` function that forwards to the method |
605 | // specified by the user. |
606 | const char *const symbolizeEnumStr = R"( |
607 | template <typename EnumType> |
608 | ::std::optional<EnumType> symbolizeEnum(::llvm::StringRef); |
609 | |
610 | template <> |
611 | inline ::std::optional<{0}> symbolizeEnum<{0}>(::llvm::StringRef str) { |
612 | return {1}(str); |
613 | } |
614 | )" ; |
615 | os << formatv(Fmt: symbolizeEnumStr, Vals&: enumName, Vals&: strToSymFnName); |
616 | |
617 | const char *const attrClassDecl = R"( |
618 | class {1} : public ::mlir::{2} { |
619 | public: |
620 | using ValueType = {0}; |
621 | using ::mlir::{2}::{2}; |
622 | static bool classof(::mlir::Attribute attr); |
623 | static {1} get(::mlir::MLIRContext *context, {0} val); |
624 | {0} getValue() const; |
625 | }; |
626 | )" ; |
627 | if (enumAttr.genSpecializedAttr()) { |
628 | StringRef attrClassName = enumAttr.getSpecializedAttrClassName(); |
629 | StringRef baseAttrClassName = "IntegerAttr" ; |
630 | os << formatv(Fmt: attrClassDecl, Vals&: enumName, Vals&: attrClassName, Vals&: baseAttrClassName); |
631 | } |
632 | |
633 | for (auto ns : llvm::reverse(C&: namespaces)) |
634 | os << "} // namespace " << ns << "\n" ; |
635 | |
636 | // Generate a generic parser and printer for the enum. |
637 | std::string qualName = |
638 | std::string(formatv(Fmt: "{0}::{1}" , Vals&: cppNamespace, Vals&: enumName)); |
639 | emitParserPrinter(enumAttr, qualName, cppNamespace, os); |
640 | |
641 | // Emit DenseMapInfo for this enum class |
642 | emitDenseMapInfo(qualName, underlyingType, cppNamespace, os); |
643 | } |
644 | |
645 | static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { |
646 | llvm::emitSourceFileHeader(Desc: "Enum Utility Declarations" , OS&: os, Record: recordKeeper); |
647 | |
648 | auto defs = recordKeeper.getAllDerivedDefinitionsIfDefined(ClassName: "EnumAttrInfo" ); |
649 | for (const auto *def : defs) |
650 | emitEnumDecl(enumDef: *def, os); |
651 | |
652 | return false; |
653 | } |
654 | |
655 | static void emitEnumDef(const Record &enumDef, raw_ostream &os) { |
656 | EnumAttr enumAttr(enumDef); |
657 | StringRef cppNamespace = enumAttr.getCppNamespace(); |
658 | |
659 | llvm::SmallVector<StringRef, 2> namespaces; |
660 | llvm::SplitString(Source: cppNamespace, OutFragments&: namespaces, Delimiters: "::" ); |
661 | |
662 | for (auto ns : namespaces) |
663 | os << "namespace " << ns << " {\n" ; |
664 | |
665 | if (enumAttr.isBitEnum()) { |
666 | emitSymToStrFnForBitEnum(enumDef, os); |
667 | emitStrToSymFnForBitEnum(enumDef, os); |
668 | emitUnderlyingToSymFnForBitEnum(enumDef, os); |
669 | } else { |
670 | emitSymToStrFnForIntEnum(enumDef, os); |
671 | emitStrToSymFnForIntEnum(enumDef, os); |
672 | emitUnderlyingToSymFnForIntEnum(enumDef, os); |
673 | } |
674 | |
675 | if (enumAttr.genSpecializedAttr()) |
676 | emitSpecializedAttrDef(enumDef, os); |
677 | |
678 | for (auto ns : llvm::reverse(C&: namespaces)) |
679 | os << "} // namespace " << ns << "\n" ; |
680 | os << "\n" ; |
681 | } |
682 | |
683 | static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) { |
684 | llvm::emitSourceFileHeader(Desc: "Enum Utility Definitions" , OS&: os, Record: recordKeeper); |
685 | |
686 | auto defs = recordKeeper.getAllDerivedDefinitionsIfDefined(ClassName: "EnumAttrInfo" ); |
687 | for (const auto *def : defs) |
688 | emitEnumDef(enumDef: *def, os); |
689 | |
690 | return false; |
691 | } |
692 | |
693 | // Registers the enum utility generator to mlir-tblgen. |
694 | static mlir::GenRegistration |
695 | genEnumDecls("gen-enum-decls" , "Generate enum utility declarations" , |
696 | [](const RecordKeeper &records, raw_ostream &os) { |
697 | return emitEnumDecls(recordKeeper: records, os); |
698 | }); |
699 | |
700 | // Registers the enum utility generator to mlir-tblgen. |
701 | static mlir::GenRegistration |
702 | genEnumDefs("gen-enum-defs" , "Generate enum utility definitions" , |
703 | [](const RecordKeeper &records, raw_ostream &os) { |
704 | return emitEnumDefs(recordKeeper: records, os); |
705 | }); |
706 | |