1 | //===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file contains attributes defined by the TestDialect for testing various |
10 | // features of MLIR. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "TestAttributes.h" |
15 | #include "TestDialect.h" |
16 | #include "mlir/IR/Builders.h" |
17 | #include "mlir/IR/DialectImplementation.h" |
18 | #include "mlir/IR/ExtensibleDialect.h" |
19 | #include "mlir/IR/Types.h" |
20 | #include "mlir/Support/LogicalResult.h" |
21 | #include "llvm/ADT/Hashing.h" |
22 | #include "llvm/ADT/StringExtras.h" |
23 | #include "llvm/ADT/TypeSwitch.h" |
24 | #include "llvm/ADT/bit.h" |
25 | #include "llvm/Support/ErrorHandling.h" |
26 | #include "llvm/Support/raw_ostream.h" |
27 | |
28 | using namespace mlir; |
29 | using namespace test; |
30 | |
31 | //===----------------------------------------------------------------------===// |
32 | // CompoundAAttr |
33 | //===----------------------------------------------------------------------===// |
34 | |
35 | Attribute CompoundAAttr::parse(AsmParser &parser, Type type) { |
36 | int widthOfSomething; |
37 | Type oneType; |
38 | SmallVector<int, 4> arrayOfInts; |
39 | if (parser.parseLess() || parser.parseInteger(widthOfSomething) || |
40 | parser.parseComma() || parser.parseType(oneType) || parser.parseComma() || |
41 | parser.parseLSquare()) |
42 | return Attribute(); |
43 | |
44 | int intVal; |
45 | while (!*parser.parseOptionalInteger(intVal)) { |
46 | arrayOfInts.push_back(intVal); |
47 | if (parser.parseOptionalComma()) |
48 | break; |
49 | } |
50 | |
51 | if (parser.parseRSquare() || parser.parseGreater()) |
52 | return Attribute(); |
53 | return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts); |
54 | } |
55 | |
56 | void CompoundAAttr::print(AsmPrinter &printer) const { |
57 | printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", [" ; |
58 | llvm::interleaveComma(getArrayOfInts(), printer); |
59 | printer << "]>" ; |
60 | } |
61 | |
62 | //===----------------------------------------------------------------------===// |
// TestI64ElementsAttr
64 | //===----------------------------------------------------------------------===// |
65 | |
66 | Attribute TestI64ElementsAttr::parse(AsmParser &parser, Type type) { |
67 | SmallVector<uint64_t> elements; |
68 | if (parser.parseLess() || parser.parseLSquare()) |
69 | return Attribute(); |
70 | uint64_t intVal; |
71 | while (succeeded(*parser.parseOptionalInteger(intVal))) { |
72 | elements.push_back(intVal); |
73 | if (parser.parseOptionalComma()) |
74 | break; |
75 | } |
76 | |
77 | if (parser.parseRSquare() || parser.parseGreater()) |
78 | return Attribute(); |
79 | return parser.getChecked<TestI64ElementsAttr>( |
80 | parser.getContext(), llvm::cast<ShapedType>(type), elements); |
81 | } |
82 | |
83 | void TestI64ElementsAttr::print(AsmPrinter &printer) const { |
84 | printer << "<[" ; |
85 | llvm::interleaveComma(getElements(), printer); |
86 | printer << "]>" ; |
87 | } |
88 | |
89 | LogicalResult |
90 | TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, |
91 | ShapedType type, ArrayRef<uint64_t> elements) { |
92 | if (type.getNumElements() != static_cast<int64_t>(elements.size())) { |
93 | return emitError() |
94 | << "number of elements does not match the provided shape type, got: " |
95 | << elements.size() << ", but expected: " << type.getNumElements(); |
96 | } |
97 | if (type.getRank() != 1 || !type.getElementType().isSignlessInteger(64)) |
98 | return emitError() << "expected single rank 64-bit shape type, but got: " |
99 | << type; |
100 | return success(); |
101 | } |
102 | |
103 | LogicalResult TestAttrWithFormatAttr::verify( |
104 | function_ref<InFlightDiagnostic()> emitError, int64_t one, std::string two, |
105 | IntegerAttr three, ArrayRef<int> four, uint64_t five, ArrayRef<int> six, |
106 | ArrayRef<AttrWithTypeBuilderAttr> arrayOfAttrs) { |
107 | if (four.size() != static_cast<unsigned>(one)) |
108 | return emitError() << "expected 'one' to equal 'four.size()'" ; |
109 | return success(); |
110 | } |
111 | |
112 | //===----------------------------------------------------------------------===// |
113 | // Utility Functions for Generated Attributes |
114 | //===----------------------------------------------------------------------===// |
115 | |
116 | static FailureOr<SmallVector<int>> parseIntArray(AsmParser &parser) { |
117 | SmallVector<int> ints; |
118 | if (parser.parseLSquare() || parser.parseCommaSeparatedList(parseElementFn: [&]() { |
119 | ints.push_back(Elt: 0); |
120 | return parser.parseInteger(result&: ints.back()); |
121 | }) || |
122 | parser.parseRSquare()) |
123 | return failure(); |
124 | return ints; |
125 | } |
126 | |
127 | static void printIntArray(AsmPrinter &printer, ArrayRef<int> ints) { |
128 | printer << '['; |
129 | llvm::interleaveComma(c: ints, os&: printer); |
130 | printer << ']'; |
131 | } |
132 | |
133 | //===----------------------------------------------------------------------===// |
134 | // TestSubElementsAccessAttr |
135 | //===----------------------------------------------------------------------===// |
136 | |
137 | Attribute TestSubElementsAccessAttr::parse(::mlir::AsmParser &parser, |
138 | ::mlir::Type type) { |
139 | Attribute first, second, third; |
140 | if (parser.parseLess() || parser.parseAttribute(first) || |
141 | parser.parseComma() || parser.parseAttribute(second) || |
142 | parser.parseComma() || parser.parseAttribute(third) || |
143 | parser.parseGreater()) { |
144 | return {}; |
145 | } |
146 | return get(parser.getContext(), first, second, third); |
147 | } |
148 | |
149 | void TestSubElementsAccessAttr::print(::mlir::AsmPrinter &printer) const { |
150 | printer << "<" << getFirst() << ", " << getSecond() << ", " << getThird() |
151 | << ">" ; |
152 | } |
153 | |
154 | //===----------------------------------------------------------------------===// |
155 | // TestExtern1DI64ElementsAttr |
156 | //===----------------------------------------------------------------------===// |
157 | |
158 | ArrayRef<uint64_t> TestExtern1DI64ElementsAttr::getElements() const { |
159 | if (auto *blob = getHandle().getBlob()) |
160 | return blob->getDataAs<uint64_t>(); |
161 | return std::nullopt; |
162 | } |
163 | |
164 | //===----------------------------------------------------------------------===// |
165 | // TestCustomAnchorAttr |
166 | //===----------------------------------------------------------------------===// |
167 | |
168 | static ParseResult parseTrueFalse(AsmParser &p, std::optional<int> &result) { |
169 | bool b; |
170 | if (p.parseInteger(result&: b)) |
171 | return failure(); |
172 | result = b; |
173 | return success(); |
174 | } |
175 | |
176 | static void printTrueFalse(AsmPrinter &p, std::optional<int> result) { |
177 | p << (*result ? "true" : "false" ); |
178 | } |
179 | |
180 | //===----------------------------------------------------------------------===// |
181 | // CopyCountAttr Implementation |
182 | //===----------------------------------------------------------------------===// |
183 | |
184 | CopyCount::CopyCount(const CopyCount &rhs) : value(rhs.value) { |
185 | CopyCount::counter++; |
186 | } |
187 | |
188 | CopyCount &CopyCount::operator=(const CopyCount &rhs) { |
189 | CopyCount::counter++; |
190 | value = rhs.value; |
191 | return *this; |
192 | } |
193 | |
194 | int CopyCount::counter; |
195 | |
196 | static bool operator==(const test::CopyCount &lhs, const test::CopyCount &rhs) { |
197 | return lhs.value == rhs.value; |
198 | } |
199 | |
200 | llvm::raw_ostream &test::operator<<(llvm::raw_ostream &os, |
201 | const test::CopyCount &value) { |
202 | return os << value.value; |
203 | } |
204 | |
205 | template <> |
206 | struct mlir::FieldParser<test::CopyCount> { |
207 | static FailureOr<test::CopyCount> parse(AsmParser &parser) { |
208 | std::string value; |
209 | if (parser.parseKeyword(keyword: value)) |
210 | return failure(); |
211 | return test::CopyCount(value); |
212 | } |
213 | }; |
214 | namespace test { |
215 | llvm::hash_code hash_value(const test::CopyCount ©Count) { |
216 | return llvm::hash_value(arg: copyCount.value); |
217 | } |
218 | } // namespace test |
219 | |
220 | //===----------------------------------------------------------------------===// |
221 | // TestConditionalAliasAttr |
222 | //===----------------------------------------------------------------------===// |
223 | |
224 | /// Attempt to parse the conditionally-aliased string attribute as a keyword or |
225 | /// string, else try to parse an alias. |
226 | static ParseResult parseConditionalAlias(AsmParser &p, StringAttr &value) { |
227 | std::string str; |
228 | if (succeeded(result: p.parseOptionalKeywordOrString(result: &str))) { |
229 | value = StringAttr::get(p.getContext(), str); |
230 | return success(); |
231 | } |
232 | return p.parseAttribute(result&: value); |
233 | } |
234 | |
235 | /// Print the string attribute as an alias if it has one, otherwise print it as |
236 | /// a keyword if possible. |
237 | static void printConditionalAlias(AsmPrinter &p, StringAttr value) { |
238 | if (succeeded(p.printAlias(value))) |
239 | return; |
240 | p.printKeywordOrString(keyword: value); |
241 | } |
242 | |
243 | //===----------------------------------------------------------------------===// |
244 | // Tablegen Generated Definitions |
245 | //===----------------------------------------------------------------------===// |
246 | |
247 | #include "TestAttrInterfaces.cpp.inc" |
248 | #include "TestOpEnums.cpp.inc" |
249 | #define GET_ATTRDEF_CLASSES |
250 | #include "TestAttrDefs.cpp.inc" |
251 | |
252 | //===----------------------------------------------------------------------===// |
253 | // Dynamic Attributes |
254 | //===----------------------------------------------------------------------===// |
255 | |
256 | /// Define a singleton dynamic attribute. |
257 | static std::unique_ptr<DynamicAttrDefinition> |
258 | getDynamicSingletonAttr(TestDialect *testDialect) { |
259 | return DynamicAttrDefinition::get( |
260 | "dynamic_singleton" , testDialect, |
261 | [](function_ref<InFlightDiagnostic()> emitError, |
262 | ArrayRef<Attribute> args) { |
263 | if (!args.empty()) { |
264 | emitError() << "expected 0 attribute arguments, but had " |
265 | << args.size(); |
266 | return failure(); |
267 | } |
268 | return success(); |
269 | }); |
270 | } |
271 | |
272 | /// Define a dynamic attribute representing a pair or attributes. |
273 | static std::unique_ptr<DynamicAttrDefinition> |
274 | getDynamicPairAttr(TestDialect *testDialect) { |
275 | return DynamicAttrDefinition::get( |
276 | "dynamic_pair" , testDialect, |
277 | [](function_ref<InFlightDiagnostic()> emitError, |
278 | ArrayRef<Attribute> args) { |
279 | if (args.size() != 2) { |
280 | emitError() << "expected 2 attribute arguments, but had " |
281 | << args.size(); |
282 | return failure(); |
283 | } |
284 | return success(); |
285 | }); |
286 | } |
287 | |
288 | static std::unique_ptr<DynamicAttrDefinition> |
289 | getDynamicCustomAssemblyFormatAttr(TestDialect *testDialect) { |
290 | auto verifier = [](function_ref<InFlightDiagnostic()> emitError, |
291 | ArrayRef<Attribute> args) { |
292 | if (args.size() != 2) { |
293 | emitError() << "expected 2 attribute arguments, but had " << args.size(); |
294 | return failure(); |
295 | } |
296 | return success(); |
297 | }; |
298 | |
299 | auto parser = [](AsmParser &parser, |
300 | llvm::SmallVectorImpl<Attribute> &parsedParams) { |
301 | Attribute leftAttr, rightAttr; |
302 | if (parser.parseLess() || parser.parseAttribute(result&: leftAttr) || |
303 | parser.parseColon() || parser.parseAttribute(result&: rightAttr) || |
304 | parser.parseGreater()) |
305 | return failure(); |
306 | parsedParams.push_back(Elt: leftAttr); |
307 | parsedParams.push_back(Elt: rightAttr); |
308 | return success(); |
309 | }; |
310 | |
311 | auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) { |
312 | printer << "<" << params[0] << ":" << params[1] << ">" ; |
313 | }; |
314 | |
315 | return DynamicAttrDefinition::get("dynamic_custom_assembly_format" , |
316 | testDialect, std::move(verifier), |
317 | std::move(parser), std::move(printer)); |
318 | } |
319 | |
320 | //===----------------------------------------------------------------------===// |
321 | // TestDialect |
322 | //===----------------------------------------------------------------------===// |
323 | |
/// Register all attributes of the test dialect: the TableGen-generated
/// attribute classes, followed by the three dynamically defined attributes
/// built in this file.
void TestDialect::registerAttributes() {
  // With GET_ATTRDEF_LIST defined, the generated .inc file is spliced into
  // the template argument list below; presumably it expands to the
  // comma-separated list of generated attribute classes.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "TestAttrDefs.cpp.inc"
      >();
  // Dynamic attributes are defined at runtime rather than by TableGen.
  registerDynamicAttr(getDynamicSingletonAttr(this));
  registerDynamicAttr(getDynamicPairAttr(this));
  registerDynamicAttr(getDynamicCustomAssemblyFormatAttr(this));
}
333 | |