| 1 | //===- SPIRVParsingUtils.h - MLIR SPIR-V Dialect Parsing Utilities --------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" |
| 10 | #include "mlir/IR/Builders.h" |
| 11 | #include "mlir/IR/OpDefinition.h" |
| 12 | #include "mlir/IR/OpImplementation.h" |
| 13 | |
| 14 | #include "llvm/ADT/ArrayRef.h" |
| 15 | #include "llvm/ADT/FunctionExtras.h" |
| 16 | #include "llvm/ADT/SmallVector.h" |
| 17 | #include "llvm/ADT/StringRef.h" |
| 18 | |
| 19 | #include <type_traits> |
| 20 | |
| 21 | namespace mlir::spirv { |
namespace AttrNames {

// Attribute names that have no ODS generation, so they are spelled out by
// hand here.
inline constexpr char kClusterSize[] = "cluster_size";
inline constexpr char kControl[] = "control";
inline constexpr char kFnNameAttrName[] = "fn";
inline constexpr char kSpecIdAttrName[] = "spec_id";

} // namespace AttrNames
| 30 | |
| 31 | template <typename Ty> |
| 32 | ArrayAttr getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues, |
| 33 | function_ref<StringRef(Ty)> stringifyFn) { |
| 34 | if (enumValues.empty()) { |
| 35 | return nullptr; |
| 36 | } |
| 37 | SmallVector<StringRef, 1> enumValStrs; |
| 38 | enumValStrs.reserve(N: enumValues.size()); |
| 39 | for (auto val : enumValues) { |
| 40 | enumValStrs.emplace_back(stringifyFn(val)); |
| 41 | } |
| 42 | return builder.getStrArrayAttr(enumValStrs); |
| 43 | } |
| 44 | |
| 45 | /// Parses the next keyword in `parser` as an enumerant of the given |
| 46 | /// `EnumClass`. |
| 47 | template <typename EnumClass, typename ParserType> |
| 48 | ParseResult |
| 49 | parseEnumKeywordAttr(EnumClass &value, ParserType &parser, |
| 50 | StringRef attrName = spirv::attributeName<EnumClass>()) { |
| 51 | StringRef keyword; |
| 52 | auto loc = parser.getCurrentLocation(); |
| 53 | if (parser.parseKeyword(&keyword)) |
| 54 | return failure(); |
| 55 | |
| 56 | if (std::optional<EnumClass> attr = |
| 57 | spirv::symbolizeEnum<EnumClass>(keyword)) { |
| 58 | value = *attr; |
| 59 | return success(); |
| 60 | } |
| 61 | return parser.emitError(loc, "invalid " ) |
| 62 | << attrName << " attribute specification: " << keyword; |
| 63 | } |
| 64 | |
| 65 | /// Parses the next string attribute in `parser` as an enumerant of the given |
| 66 | /// `EnumClass`. |
| 67 | template <typename EnumClass> |
| 68 | ParseResult |
| 69 | parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, |
| 70 | StringRef attrName = spirv::attributeName<EnumClass>()) { |
| 71 | static_assert(std::is_enum_v<EnumClass>); |
| 72 | Attribute attrVal; |
| 73 | NamedAttrList attr; |
| 74 | auto loc = parser.getCurrentLocation(); |
| 75 | if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(), |
| 76 | attrName, attr)) |
| 77 | return failure(); |
| 78 | if (!llvm::isa<StringAttr>(Val: attrVal)) |
| 79 | return parser.emitError(loc, message: "expected " ) |
| 80 | << attrName << " attribute specified as string" ; |
| 81 | auto attrOptional = spirv::symbolizeEnum<EnumClass>( |
| 82 | llvm::cast<StringAttr>(attrVal).getValue()); |
| 83 | if (!attrOptional) |
| 84 | return parser.emitError(loc, message: "invalid " ) |
| 85 | << attrName << " attribute specification: " << attrVal; |
| 86 | value = *attrOptional; |
| 87 | return success(); |
| 88 | } |
| 89 | |
| 90 | /// Parses the next string attribute in `parser` as an enumerant of the given |
| 91 | /// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer |
| 92 | /// attribute with the enum class's name as attribute name. |
| 93 | template <typename EnumAttrClass, |
| 94 | typename EnumClass = typename EnumAttrClass::ValueType> |
| 95 | ParseResult |
| 96 | parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state, |
| 97 | StringRef attrName = spirv::attributeName<EnumClass>()) { |
| 98 | static_assert(std::is_enum_v<EnumClass>); |
| 99 | if (parseEnumStrAttr(value, parser, attrName)) |
| 100 | return failure(); |
| 101 | state.addAttribute(attrName, |
| 102 | parser.getBuilder().getAttr<EnumAttrClass>(value)); |
| 103 | return success(); |
| 104 | } |
| 105 | |
| 106 | /// Parses the next keyword in `parser` as an enumerant of the given `EnumClass` |
| 107 | /// and inserts the enumerant into `state` as an 32-bit integer attribute with |
| 108 | /// the enum class's name as attribute name. |
| 109 | template <typename EnumAttrClass, |
| 110 | typename EnumClass = typename EnumAttrClass::ValueType> |
| 111 | ParseResult |
| 112 | parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser, |
| 113 | OperationState &state, |
| 114 | StringRef attrName = spirv::attributeName<EnumClass>()) { |
| 115 | static_assert(std::is_enum_v<EnumClass>); |
| 116 | if (parseEnumKeywordAttr(value, parser)) |
| 117 | return failure(); |
| 118 | state.addAttribute(attrName, |
| 119 | parser.getBuilder().getAttr<EnumAttrClass>(value)); |
| 120 | return success(); |
| 121 | } |
| 122 | |
| 123 | ParseResult parseVariableDecorations(OpAsmParser &parser, |
| 124 | OperationState &state); |
| 125 | |
| 126 | } // namespace mlir::spirv |
| 127 | |