1 | //===- SPIRVParsingUtils.h - MLIR SPIR-V Dialect Parsing Utilities --------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" |
10 | #include "mlir/IR/Builders.h" |
11 | #include "mlir/IR/OpDefinition.h" |
12 | #include "mlir/IR/OpImplementation.h" |
13 | |
14 | #include "llvm/ADT/ArrayRef.h" |
15 | #include "llvm/ADT/FunctionExtras.h" |
16 | #include "llvm/ADT/SmallVector.h" |
17 | #include "llvm/ADT/StringRef.h" |
18 | |
19 | #include <type_traits> |
20 | |
21 | namespace mlir::spirv { |
namespace AttrNames {

// Hand-written attribute name strings. Each of the names below is marked as
// having no ODS generation, which is why it is spelled out here.

/// Attribute name "cluster_size".
inline constexpr char kClusterSize[] = "cluster_size";

/// Attribute name "control".
inline constexpr char kControl[] = "control";

/// Attribute name "fn".
inline constexpr char kFnNameAttrName[] = "fn";

/// Attribute name "spec_id".
inline constexpr char kSpecIdAttrName[] = "spec_id";

} // namespace AttrNames
30 | |
31 | template <typename Ty> |
32 | ArrayAttr getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues, |
33 | function_ref<StringRef(Ty)> stringifyFn) { |
34 | if (enumValues.empty()) { |
35 | return nullptr; |
36 | } |
37 | SmallVector<StringRef, 1> enumValStrs; |
38 | enumValStrs.reserve(N: enumValues.size()); |
39 | for (auto val : enumValues) { |
40 | enumValStrs.emplace_back(stringifyFn(val)); |
41 | } |
42 | return builder.getStrArrayAttr(enumValStrs); |
43 | } |
44 | |
45 | /// Parses the next keyword in `parser` as an enumerant of the given |
46 | /// `EnumClass`. |
47 | template <typename EnumClass, typename ParserType> |
48 | ParseResult |
49 | parseEnumKeywordAttr(EnumClass &value, ParserType &parser, |
50 | StringRef attrName = spirv::attributeName<EnumClass>()) { |
51 | StringRef keyword; |
52 | auto loc = parser.getCurrentLocation(); |
53 | if (parser.parseKeyword(&keyword)) |
54 | return failure(); |
55 | |
56 | if (std::optional<EnumClass> attr = |
57 | spirv::symbolizeEnum<EnumClass>(keyword)) { |
58 | value = *attr; |
59 | return success(); |
60 | } |
61 | return parser.emitError(loc, "invalid " ) |
62 | << attrName << " attribute specification: " << keyword; |
63 | } |
64 | |
65 | /// Parses the next string attribute in `parser` as an enumerant of the given |
66 | /// `EnumClass`. |
67 | template <typename EnumClass> |
68 | ParseResult |
69 | parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, |
70 | StringRef attrName = spirv::attributeName<EnumClass>()) { |
71 | static_assert(std::is_enum_v<EnumClass>); |
72 | Attribute attrVal; |
73 | NamedAttrList attr; |
74 | auto loc = parser.getCurrentLocation(); |
75 | if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(), |
76 | attrName, attr)) |
77 | return failure(); |
78 | if (!llvm::isa<StringAttr>(Val: attrVal)) |
79 | return parser.emitError(loc, message: "expected " ) |
80 | << attrName << " attribute specified as string" ; |
81 | auto attrOptional = spirv::symbolizeEnum<EnumClass>( |
82 | llvm::cast<StringAttr>(attrVal).getValue()); |
83 | if (!attrOptional) |
84 | return parser.emitError(loc, message: "invalid " ) |
85 | << attrName << " attribute specification: " << attrVal; |
86 | value = *attrOptional; |
87 | return success(); |
88 | } |
89 | |
90 | /// Parses the next string attribute in `parser` as an enumerant of the given |
91 | /// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer |
92 | /// attribute with the enum class's name as attribute name. |
93 | template <typename EnumAttrClass, |
94 | typename EnumClass = typename EnumAttrClass::ValueType> |
95 | ParseResult |
96 | parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state, |
97 | StringRef attrName = spirv::attributeName<EnumClass>()) { |
98 | static_assert(std::is_enum_v<EnumClass>); |
99 | if (parseEnumStrAttr(value, parser, attrName)) |
100 | return failure(); |
101 | state.addAttribute(attrName, |
102 | parser.getBuilder().getAttr<EnumAttrClass>(value)); |
103 | return success(); |
104 | } |
105 | |
106 | /// Parses the next keyword in `parser` as an enumerant of the given `EnumClass` |
107 | /// and inserts the enumerant into `state` as an 32-bit integer attribute with |
108 | /// the enum class's name as attribute name. |
109 | template <typename EnumAttrClass, |
110 | typename EnumClass = typename EnumAttrClass::ValueType> |
111 | ParseResult |
112 | parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser, |
113 | OperationState &state, |
114 | StringRef attrName = spirv::attributeName<EnumClass>()) { |
115 | static_assert(std::is_enum_v<EnumClass>); |
116 | if (parseEnumKeywordAttr(value, parser)) |
117 | return failure(); |
118 | state.addAttribute(attrName, |
119 | parser.getBuilder().getAttr<EnumAttrClass>(value)); |
120 | return success(); |
121 | } |
122 | |
/// Declaration only; the definition is not part of this header.
/// NOTE(review): judging by the name and signature, this presumably parses the
/// decorations of a variable op from `parser` and records them on `state` --
/// confirm against the definition.
ParseResult parseVariableDecorations(OpAsmParser &parser,
                                     OperationState &state);
125 | |
126 | } // namespace mlir::spirv |
127 | |