1 | //===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file contains attributes defined by the TestDialect for testing various |
10 | // features of MLIR. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "TestAttributes.h" |
15 | #include "TestDialect.h" |
16 | #include "TestTypes.h" |
17 | #include "mlir/IR/Attributes.h" |
18 | #include "mlir/IR/Builders.h" |
19 | #include "mlir/IR/DialectImplementation.h" |
20 | #include "mlir/IR/ExtensibleDialect.h" |
21 | #include "mlir/IR/OpImplementation.h" |
22 | #include "mlir/IR/Types.h" |
23 | #include "llvm/ADT/APFloat.h" |
24 | #include "llvm/ADT/Hashing.h" |
25 | #include "llvm/ADT/StringExtras.h" |
26 | #include "llvm/ADT/TypeSwitch.h" |
27 | #include "llvm/ADT/bit.h" |
28 | #include "llvm/Support/ErrorHandling.h" |
29 | #include "llvm/Support/raw_ostream.h" |
30 | |
31 | using namespace mlir; |
32 | using namespace test; |
33 | |
34 | //===----------------------------------------------------------------------===// |
35 | // CompoundAAttr |
36 | //===----------------------------------------------------------------------===// |
37 | |
38 | Attribute CompoundAAttr::parse(AsmParser &parser, Type type) { |
39 | int widthOfSomething; |
40 | Type oneType; |
41 | SmallVector<int, 4> arrayOfInts; |
42 | if (parser.parseLess() || parser.parseInteger(widthOfSomething) || |
43 | parser.parseComma() || parser.parseType(oneType) || parser.parseComma() || |
44 | parser.parseLSquare()) |
45 | return Attribute(); |
46 | |
47 | int intVal; |
48 | while (!*parser.parseOptionalInteger(intVal)) { |
49 | arrayOfInts.push_back(intVal); |
50 | if (parser.parseOptionalComma()) |
51 | break; |
52 | } |
53 | |
54 | if (parser.parseRSquare() || parser.parseGreater()) |
55 | return Attribute(); |
56 | return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts); |
57 | } |
58 | |
59 | void CompoundAAttr::print(AsmPrinter &printer) const { |
60 | printer << "<"<< getWidthOfSomething() << ", "<< getOneType() << ", ["; |
61 | llvm::interleaveComma(getArrayOfInts(), printer); |
62 | printer << "]>"; |
63 | } |
64 | |
65 | //===----------------------------------------------------------------------===// |
// TestDecimalShapeAttr
67 | //===----------------------------------------------------------------------===// |
68 | |
69 | Attribute TestDecimalShapeAttr::parse(AsmParser &parser, Type type) { |
70 | if (parser.parseLess()) { |
71 | return Attribute(); |
72 | } |
73 | SmallVector<int64_t> shape; |
74 | if (parser.parseOptionalGreater()) { |
75 | auto parseDecimal = [&]() { |
76 | shape.emplace_back(); |
77 | auto parseResult = parser.parseOptionalDecimalInteger(shape.back()); |
78 | if (!parseResult.has_value() || failed(*parseResult)) { |
79 | parser.emitError(parser.getCurrentLocation()) << "expected an integer"; |
80 | return failure(); |
81 | } |
82 | return success(); |
83 | }; |
84 | if (failed(parseDecimal())) { |
85 | return Attribute(); |
86 | } |
87 | while (failed(parser.parseOptionalGreater())) { |
88 | if (failed(parser.parseXInDimensionList()) || failed(parseDecimal())) { |
89 | return Attribute(); |
90 | } |
91 | } |
92 | } |
93 | return get(parser.getContext(), shape); |
94 | } |
95 | |
96 | void TestDecimalShapeAttr::print(AsmPrinter &printer) const { |
97 | printer << "<"; |
98 | llvm::interleave(getShape(), printer, "x"); |
99 | printer << ">"; |
100 | } |
101 | |
102 | Attribute TestI64ElementsAttr::parse(AsmParser &parser, Type type) { |
103 | SmallVector<uint64_t> elements; |
104 | if (parser.parseLess() || parser.parseLSquare()) |
105 | return Attribute(); |
106 | uint64_t intVal; |
107 | while (succeeded(*parser.parseOptionalInteger(intVal))) { |
108 | elements.push_back(intVal); |
109 | if (parser.parseOptionalComma()) |
110 | break; |
111 | } |
112 | |
113 | if (parser.parseRSquare() || parser.parseGreater()) |
114 | return Attribute(); |
115 | return parser.getChecked<TestI64ElementsAttr>( |
116 | parser.getContext(), llvm::cast<ShapedType>(type), elements); |
117 | } |
118 | |
119 | void TestI64ElementsAttr::print(AsmPrinter &printer) const { |
120 | printer << "<["; |
121 | llvm::interleaveComma(getElements(), printer); |
122 | printer << "]>"; |
123 | } |
124 | |
125 | LogicalResult |
126 | TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, |
127 | ShapedType type, ArrayRef<uint64_t> elements) { |
128 | if (type.getNumElements() != static_cast<int64_t>(elements.size())) { |
129 | return emitError() |
130 | << "number of elements does not match the provided shape type, got: " |
131 | << elements.size() << ", but expected: "<< type.getNumElements(); |
132 | } |
133 | if (type.getRank() != 1 || !type.getElementType().isSignlessInteger(64)) |
134 | return emitError() << "expected single rank 64-bit shape type, but got: " |
135 | << type; |
136 | return success(); |
137 | } |
138 | |
139 | LogicalResult TestAttrWithFormatAttr::verify( |
140 | function_ref<InFlightDiagnostic()> emitError, int64_t one, std::string two, |
141 | IntegerAttr three, ArrayRef<int> four, uint64_t five, ArrayRef<int> six, |
142 | ArrayRef<AttrWithTypeBuilderAttr> arrayOfAttrs) { |
143 | if (four.size() != static_cast<unsigned>(one)) |
144 | return emitError() << "expected 'one' to equal 'four.size()'"; |
145 | return success(); |
146 | } |
147 | |
148 | //===----------------------------------------------------------------------===// |
149 | // Utility Functions for Generated Attributes |
150 | //===----------------------------------------------------------------------===// |
151 | |
152 | static FailureOr<SmallVector<int>> parseIntArray(AsmParser &parser) { |
153 | SmallVector<int> ints; |
154 | if (parser.parseLSquare() || parser.parseCommaSeparatedList(parseElementFn: [&]() { |
155 | ints.push_back(Elt: 0); |
156 | return parser.parseInteger(result&: ints.back()); |
157 | }) || |
158 | parser.parseRSquare()) |
159 | return failure(); |
160 | return ints; |
161 | } |
162 | |
163 | static void printIntArray(AsmPrinter &printer, ArrayRef<int> ints) { |
164 | printer << '['; |
165 | llvm::interleaveComma(c: ints, os&: printer); |
166 | printer << ']'; |
167 | } |
168 | |
169 | //===----------------------------------------------------------------------===// |
170 | // TestSubElementsAccessAttr |
171 | //===----------------------------------------------------------------------===// |
172 | |
173 | Attribute TestSubElementsAccessAttr::parse(::mlir::AsmParser &parser, |
174 | ::mlir::Type type) { |
175 | Attribute first, second, third; |
176 | if (parser.parseLess() || parser.parseAttribute(first) || |
177 | parser.parseComma() || parser.parseAttribute(second) || |
178 | parser.parseComma() || parser.parseAttribute(third) || |
179 | parser.parseGreater()) { |
180 | return {}; |
181 | } |
182 | return get(parser.getContext(), first, second, third); |
183 | } |
184 | |
185 | void TestSubElementsAccessAttr::print(::mlir::AsmPrinter &printer) const { |
186 | printer << "<"<< getFirst() << ", "<< getSecond() << ", "<< getThird() |
187 | << ">"; |
188 | } |
189 | |
190 | //===----------------------------------------------------------------------===// |
191 | // TestExtern1DI64ElementsAttr |
192 | //===----------------------------------------------------------------------===// |
193 | |
194 | ArrayRef<uint64_t> TestExtern1DI64ElementsAttr::getElements() const { |
195 | if (auto *blob = getHandle().getBlob()) |
196 | return blob->getDataAs<uint64_t>(); |
197 | return std::nullopt; |
198 | } |
199 | |
200 | //===----------------------------------------------------------------------===// |
201 | // TestCustomAnchorAttr |
202 | //===----------------------------------------------------------------------===// |
203 | |
204 | static ParseResult parseTrueFalse(AsmParser &p, std::optional<int> &result) { |
205 | bool b; |
206 | if (p.parseInteger(result&: b)) |
207 | return failure(); |
208 | result = b; |
209 | return success(); |
210 | } |
211 | |
212 | static void printTrueFalse(AsmPrinter &p, std::optional<int> result) { |
213 | p << (*result ? "true": "false"); |
214 | } |
215 | |
216 | //===----------------------------------------------------------------------===// |
217 | // CopyCountAttr Implementation |
218 | //===----------------------------------------------------------------------===// |
219 | |
220 | CopyCount::CopyCount(const CopyCount &rhs) : value(rhs.value) { |
221 | CopyCount::counter++; |
222 | } |
223 | |
224 | CopyCount &CopyCount::operator=(const CopyCount &rhs) { |
225 | CopyCount::counter++; |
226 | value = rhs.value; |
227 | return *this; |
228 | } |
229 | |
230 | int CopyCount::counter; |
231 | |
232 | static bool operator==(const test::CopyCount &lhs, const test::CopyCount &rhs) { |
233 | return lhs.value == rhs.value; |
234 | } |
235 | |
236 | llvm::raw_ostream &test::operator<<(llvm::raw_ostream &os, |
237 | const test::CopyCount &value) { |
238 | return os << value.value; |
239 | } |
240 | |
241 | template <> |
242 | struct mlir::FieldParser<test::CopyCount> { |
243 | static FailureOr<test::CopyCount> parse(AsmParser &parser) { |
244 | std::string value; |
245 | if (parser.parseKeyword(keyword: value)) |
246 | return failure(); |
247 | return test::CopyCount(value); |
248 | } |
249 | }; |
250 | namespace test { |
251 | llvm::hash_code hash_value(const test::CopyCount ©Count) { |
252 | return llvm::hash_value(arg: copyCount.value); |
253 | } |
254 | } // namespace test |
255 | |
256 | //===----------------------------------------------------------------------===// |
257 | // TestConditionalAliasAttr |
258 | //===----------------------------------------------------------------------===// |
259 | |
260 | /// Attempt to parse the conditionally-aliased string attribute as a keyword or |
261 | /// string, else try to parse an alias. |
262 | static ParseResult parseConditionalAlias(AsmParser &p, StringAttr &value) { |
263 | std::string str; |
264 | if (succeeded(Result: p.parseOptionalKeywordOrString(result: &str))) { |
265 | value = StringAttr::get(p.getContext(), str); |
266 | return success(); |
267 | } |
268 | return p.parseAttribute(result&: value); |
269 | } |
270 | |
271 | /// Print the string attribute as an alias if it has one, otherwise print it as |
272 | /// a keyword if possible. |
273 | static void printConditionalAlias(AsmPrinter &p, StringAttr value) { |
274 | if (succeeded(p.printAlias(value))) |
275 | return; |
276 | p.printKeywordOrString(keyword: value); |
277 | } |
278 | |
279 | //===----------------------------------------------------------------------===// |
280 | // Custom Float Attribute |
281 | //===----------------------------------------------------------------------===// |
282 | |
283 | static void printCustomFloatAttr(AsmPrinter &p, StringAttr typeStrAttr, |
284 | APFloat value) { |
285 | p << typeStrAttr << " : "<< value; |
286 | } |
287 | |
288 | static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr, |
289 | FailureOr<APFloat> &value) { |
290 | |
291 | std::string str; |
292 | if (p.parseString(string: &str)) |
293 | return failure(); |
294 | |
295 | typeStrAttr = StringAttr::get(p.getContext(), str); |
296 | |
297 | if (p.parseColon()) |
298 | return failure(); |
299 | |
300 | const llvm::fltSemantics *semantics; |
301 | if (str == "float") |
302 | semantics = &llvm::APFloat::IEEEsingle(); |
303 | else if (str == "double") |
304 | semantics = &llvm::APFloat::IEEEdouble(); |
305 | else if (str == "fp80") |
306 | semantics = &llvm::APFloat::x87DoubleExtended(); |
307 | else |
308 | return p.emitError(loc: p.getCurrentLocation(), message: "unknown float type, expected " |
309 | "'float', 'double' or 'fp80'"); |
310 | |
311 | APFloat parsedValue(0.0); |
312 | if (p.parseFloat(semantics: *semantics, result&: parsedValue)) |
313 | return failure(); |
314 | |
315 | value.emplace(args&: parsedValue); |
316 | return success(); |
317 | } |
318 | |
319 | //===----------------------------------------------------------------------===// |
320 | // TestCustomStructAttr |
321 | //===----------------------------------------------------------------------===// |
322 | |
323 | static void printCustomStructAttr(AsmPrinter &p, int64_t value) { |
324 | if (ShapedType::isDynamic(value)) { |
325 | p << "?"; |
326 | } else { |
327 | p.printStrippedAttrOrType(attrOrType: value); |
328 | } |
329 | } |
330 | |
331 | static ParseResult parseCustomStructAttr(AsmParser &p, int64_t &value) { |
332 | if (succeeded(Result: p.parseOptionalQuestion())) { |
333 | value = ShapedType::kDynamic; |
334 | return success(); |
335 | } |
336 | return p.parseInteger(result&: value); |
337 | } |
338 | |
339 | static void printCustomOptStructFieldAttr(AsmPrinter &p, ArrayAttr attr) { |
340 | if (attr && attr.size() == 1 && isa<IntegerAttr>(attr[0])) { |
341 | p << cast<IntegerAttr>(attr[0]).getInt(); |
342 | } else { |
343 | p.printStrippedAttrOrType(attr); |
344 | } |
345 | } |
346 | |
347 | static ParseResult parseCustomOptStructFieldAttr(AsmParser &p, |
348 | ArrayAttr &attr) { |
349 | int64_t value; |
350 | OptionalParseResult result = p.parseOptionalInteger(result&: value); |
351 | if (result.has_value()) { |
352 | if (failed(Result: result.value())) |
353 | return failure(); |
354 | attr = ArrayAttr::get( |
355 | p.getContext(), |
356 | {IntegerAttr::get(IntegerType::get(p.getContext(), 64), value)}); |
357 | return success(); |
358 | } |
359 | return p.parseAttribute(result&: attr); |
360 | } |
361 | |
362 | //===----------------------------------------------------------------------===// |
363 | // TestOpAsmAttrInterfaceAttr |
364 | //===----------------------------------------------------------------------===// |
365 | |
366 | ::mlir::OpAsmDialectInterface::AliasResult |
367 | TestOpAsmAttrInterfaceAttr::getAlias(::llvm::raw_ostream &os) const { |
368 | os << "op_asm_attr_interface_"; |
369 | os << getValue().getValue(); |
370 | return ::mlir::OpAsmDialectInterface::AliasResult::FinalAlias; |
371 | } |
372 | |
373 | //===----------------------------------------------------------------------===// |
374 | // TestConstMemorySpaceAttr |
375 | //===----------------------------------------------------------------------===// |
376 | |
377 | bool TestConstMemorySpaceAttr::isValidLoad( |
378 | Type type, mlir::ptr::AtomicOrdering ordering, IntegerAttr alignment, |
379 | function_ref<InFlightDiagnostic()> emitError) const { |
380 | return true; |
381 | } |
382 | |
383 | bool TestConstMemorySpaceAttr::isValidStore( |
384 | Type type, mlir::ptr::AtomicOrdering ordering, IntegerAttr alignment, |
385 | function_ref<InFlightDiagnostic()> emitError) const { |
386 | if (emitError) |
387 | emitError() << "memory space is read-only"; |
388 | return false; |
389 | } |
390 | |
391 | bool TestConstMemorySpaceAttr::isValidAtomicOp( |
392 | mlir::ptr::AtomicBinOp binOp, Type type, mlir::ptr::AtomicOrdering ordering, |
393 | IntegerAttr alignment, function_ref<InFlightDiagnostic()> emitError) const { |
394 | if (emitError) |
395 | emitError() << "memory space is read-only"; |
396 | return false; |
397 | } |
398 | |
399 | bool TestConstMemorySpaceAttr::isValidAtomicXchg( |
400 | Type type, mlir::ptr::AtomicOrdering successOrdering, |
401 | mlir::ptr::AtomicOrdering failureOrdering, IntegerAttr alignment, |
402 | function_ref<InFlightDiagnostic()> emitError) const { |
403 | if (emitError) |
404 | emitError() << "memory space is read-only"; |
405 | return false; |
406 | } |
407 | |
408 | bool TestConstMemorySpaceAttr::isValidAddrSpaceCast( |
409 | Type tgt, Type src, function_ref<InFlightDiagnostic()> emitError) const { |
410 | if (emitError) |
411 | emitError() << "memory space doesn't allow addrspace casts"; |
412 | return false; |
413 | } |
414 | |
415 | bool TestConstMemorySpaceAttr::isValidPtrIntCast( |
416 | Type intLikeTy, Type ptrLikeTy, |
417 | function_ref<InFlightDiagnostic()> emitError) const { |
418 | if (emitError) |
419 | emitError() << "memory space doesn't allow int-ptr casts"; |
420 | return false; |
421 | } |
422 | |
423 | //===----------------------------------------------------------------------===// |
424 | // Tablegen Generated Definitions |
425 | //===----------------------------------------------------------------------===// |
426 | |
427 | #include "TestAttrInterfaces.cpp.inc" |
428 | #include "TestOpEnums.cpp.inc" |
429 | #define GET_ATTRDEF_CLASSES |
430 | #include "TestAttrDefs.cpp.inc" |
431 | |
432 | //===----------------------------------------------------------------------===// |
433 | // Dynamic Attributes |
434 | //===----------------------------------------------------------------------===// |
435 | |
436 | /// Define a singleton dynamic attribute. |
437 | static std::unique_ptr<DynamicAttrDefinition> |
438 | getDynamicSingletonAttr(TestDialect *testDialect) { |
439 | return DynamicAttrDefinition::get( |
440 | "dynamic_singleton", testDialect, |
441 | [](function_ref<InFlightDiagnostic()> emitError, |
442 | ArrayRef<Attribute> args) { |
443 | if (!args.empty()) { |
444 | emitError() << "expected 0 attribute arguments, but had " |
445 | << args.size(); |
446 | return failure(); |
447 | } |
448 | return success(); |
449 | }); |
450 | } |
451 | |
452 | /// Define a dynamic attribute representing a pair or attributes. |
453 | static std::unique_ptr<DynamicAttrDefinition> |
454 | getDynamicPairAttr(TestDialect *testDialect) { |
455 | return DynamicAttrDefinition::get( |
456 | "dynamic_pair", testDialect, |
457 | [](function_ref<InFlightDiagnostic()> emitError, |
458 | ArrayRef<Attribute> args) { |
459 | if (args.size() != 2) { |
460 | emitError() << "expected 2 attribute arguments, but had " |
461 | << args.size(); |
462 | return failure(); |
463 | } |
464 | return success(); |
465 | }); |
466 | } |
467 | |
468 | static std::unique_ptr<DynamicAttrDefinition> |
469 | getDynamicCustomAssemblyFormatAttr(TestDialect *testDialect) { |
470 | auto verifier = [](function_ref<InFlightDiagnostic()> emitError, |
471 | ArrayRef<Attribute> args) { |
472 | if (args.size() != 2) { |
473 | emitError() << "expected 2 attribute arguments, but had "<< args.size(); |
474 | return failure(); |
475 | } |
476 | return success(); |
477 | }; |
478 | |
479 | auto parser = [](AsmParser &parser, |
480 | llvm::SmallVectorImpl<Attribute> &parsedParams) { |
481 | Attribute leftAttr, rightAttr; |
482 | if (parser.parseLess() || parser.parseAttribute(result&: leftAttr) || |
483 | parser.parseColon() || parser.parseAttribute(result&: rightAttr) || |
484 | parser.parseGreater()) |
485 | return failure(); |
486 | parsedParams.push_back(Elt: leftAttr); |
487 | parsedParams.push_back(Elt: rightAttr); |
488 | return success(); |
489 | }; |
490 | |
491 | auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) { |
492 | printer << "<"<< params[0] << ":"<< params[1] << ">"; |
493 | }; |
494 | |
495 | return DynamicAttrDefinition::get("dynamic_custom_assembly_format", |
496 | testDialect, std::move(verifier), |
497 | std::move(parser), std::move(printer)); |
498 | } |
499 | |
500 | //===----------------------------------------------------------------------===// |
501 | // SlashAttr |
502 | //===----------------------------------------------------------------------===// |
503 | |
504 | Attribute SlashAttr::parse(AsmParser &parser, Type type) { |
505 | int lhs, rhs; |
506 | |
507 | if (parser.parseLess() || parser.parseInteger(lhs) || parser.parseSlash() || |
508 | parser.parseInteger(rhs) || parser.parseGreater()) |
509 | return Attribute(); |
510 | |
511 | return SlashAttr::get(parser.getContext(), lhs, rhs); |
512 | } |
513 | |
514 | void SlashAttr::print(AsmPrinter &printer) const { |
515 | printer << "<"<< getLhs() << " / "<< getRhs() << ">"; |
516 | } |
517 | |
518 | //===----------------------------------------------------------------------===// |
519 | // TestDialect |
520 | //===----------------------------------------------------------------------===// |
521 | |
522 | void TestDialect::registerAttributes() { |
523 | addAttributes< |
524 | #define GET_ATTRDEF_LIST |
525 | #include "TestAttrDefs.cpp.inc" |
526 | >(); |
527 | registerDynamicAttr(getDynamicSingletonAttr(this)); |
528 | registerDynamicAttr(getDynamicPairAttr(this)); |
529 | registerDynamicAttr(getDynamicCustomAssemblyFormatAttr(this)); |
530 | } |
531 |
Definitions
- parseIntArray
- printIntArray
- parseTrueFalse
- printTrueFalse
- CopyCount
- operator=
- counter
- operator==
- operator<<
- FieldParser
- parse
- hash_value
- parseConditionalAlias
- printConditionalAlias
- printCustomFloatAttr
- parseCustomFloatAttr
- printCustomStructAttr
- parseCustomStructAttr
- printCustomOptStructFieldAttr
- parseCustomOptStructFieldAttr
- getDynamicSingletonAttr
- getDynamicPairAttr
Improve your Profiling and Debugging skills
Find out more