1 | //===- SPIRVParsingUtils.h - MLIR SPIR-V Dialect Parsing Utilities --------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/FunctionExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"

#include <optional>
#include <type_traits>
20 | |
21 | namespace mlir::spirv { |
namespace AttrNames {

// Attribute names that are not produced by ODS generation; each one is
// spelled out by hand here instead.
inline constexpr char kClusterSize[]{"cluster_size"};
inline constexpr char kControl[]{"control"};
inline constexpr char kFnNameAttrName[]{"fn"};
inline constexpr char kSpecIdAttrName[]{"spec_id"};

} // namespace AttrNames
30 | |
31 | template <typename Ty> |
32 | ArrayAttr getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues, |
33 | function_ref<StringRef(Ty)> stringifyFn) { |
34 | if (enumValues.empty()) { |
35 | return nullptr; |
36 | } |
37 | SmallVector<StringRef, 1> enumValStrs; |
38 | enumValStrs.reserve(N: enumValues.size()); |
39 | for (auto val : enumValues) { |
40 | enumValStrs.emplace_back(stringifyFn(val)); |
41 | } |
42 | return builder.getStrArrayAttr(enumValStrs); |
43 | } |
44 | |
45 | /// Parses the next keyword in `parser` as an enumerant of the given |
46 | /// `EnumClass`. |
47 | template <typename EnumClass, typename ParserType> |
48 | ParseResult |
49 | parseEnumKeywordAttr(EnumClass &value, ParserType &parser, |
50 | StringRef attrName = spirv::attributeName<EnumClass>()) { |
51 | StringRef keyword; |
52 | SmallVector<NamedAttribute, 1> attr; |
53 | auto loc = parser.getCurrentLocation(); |
54 | if (parser.parseKeyword(&keyword)) |
55 | return failure(); |
56 | |
57 | if (std::optional<EnumClass> attr = |
58 | spirv::symbolizeEnum<EnumClass>(keyword)) { |
59 | value = *attr; |
60 | return success(); |
61 | } |
62 | return parser.emitError(loc, "invalid " ) |
63 | << attrName << " attribute specification: " << keyword; |
64 | } |
65 | |
66 | /// Parses the next string attribute in `parser` as an enumerant of the given |
67 | /// `EnumClass`. |
68 | template <typename EnumClass> |
69 | ParseResult |
70 | parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, |
71 | StringRef attrName = spirv::attributeName<EnumClass>()) { |
72 | static_assert(std::is_enum_v<EnumClass>); |
73 | Attribute attrVal; |
74 | NamedAttrList attr; |
75 | auto loc = parser.getCurrentLocation(); |
76 | if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(), |
77 | attrName, attr)) |
78 | return failure(); |
79 | if (!llvm::isa<StringAttr>(Val: attrVal)) |
80 | return parser.emitError(loc, message: "expected " ) |
81 | << attrName << " attribute specified as string" ; |
82 | auto attrOptional = spirv::symbolizeEnum<EnumClass>( |
83 | llvm::cast<StringAttr>(attrVal).getValue()); |
84 | if (!attrOptional) |
85 | return parser.emitError(loc, message: "invalid " ) |
86 | << attrName << " attribute specification: " << attrVal; |
87 | value = *attrOptional; |
88 | return success(); |
89 | } |
90 | |
91 | /// Parses the next string attribute in `parser` as an enumerant of the given |
92 | /// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer |
93 | /// attribute with the enum class's name as attribute name. |
94 | template <typename EnumAttrClass, |
95 | typename EnumClass = typename EnumAttrClass::ValueType> |
96 | ParseResult |
97 | parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state, |
98 | StringRef attrName = spirv::attributeName<EnumClass>()) { |
99 | static_assert(std::is_enum_v<EnumClass>); |
100 | if (parseEnumStrAttr(value, parser, attrName)) |
101 | return failure(); |
102 | state.addAttribute(attrName, |
103 | parser.getBuilder().getAttr<EnumAttrClass>(value)); |
104 | return success(); |
105 | } |
106 | |
107 | /// Parses the next keyword in `parser` as an enumerant of the given `EnumClass` |
108 | /// and inserts the enumerant into `state` as an 32-bit integer attribute with |
109 | /// the enum class's name as attribute name. |
110 | template <typename EnumAttrClass, |
111 | typename EnumClass = typename EnumAttrClass::ValueType> |
112 | ParseResult |
113 | parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser, |
114 | OperationState &state, |
115 | StringRef attrName = spirv::attributeName<EnumClass>()) { |
116 | static_assert(std::is_enum_v<EnumClass>); |
117 | if (parseEnumKeywordAttr(value, parser)) |
118 | return failure(); |
119 | state.addAttribute(attrName, |
120 | parser.getBuilder().getAttr<EnumAttrClass>(value)); |
121 | return success(); |
122 | } |
123 | |
124 | ParseResult parseVariableDecorations(OpAsmParser &parser, |
125 | OperationState &state); |
126 | |
127 | } // namespace mlir::spirv |
128 | |