1//===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains attributes defined by the TestDialect for testing various
10// features of MLIR.
11//
12//===----------------------------------------------------------------------===//
13
14#include "TestAttributes.h"
15#include "TestDialect.h"
16#include "TestTypes.h"
17#include "mlir/IR/Attributes.h"
18#include "mlir/IR/Builders.h"
19#include "mlir/IR/DialectImplementation.h"
20#include "mlir/IR/ExtensibleDialect.h"
21#include "mlir/IR/OpImplementation.h"
22#include "mlir/IR/Types.h"
23#include "llvm/ADT/APFloat.h"
24#include "llvm/ADT/Hashing.h"
25#include "llvm/ADT/StringExtras.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/ADT/bit.h"
28#include "llvm/Support/ErrorHandling.h"
29#include "llvm/Support/raw_ostream.h"
30
31using namespace mlir;
32using namespace test;
33
34//===----------------------------------------------------------------------===//
35// CompoundAAttr
36//===----------------------------------------------------------------------===//
37
38Attribute CompoundAAttr::parse(AsmParser &parser, Type type) {
39 int widthOfSomething;
40 Type oneType;
41 SmallVector<int, 4> arrayOfInts;
42 if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
43 parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
44 parser.parseLSquare())
45 return Attribute();
46
47 int intVal;
48 while (!*parser.parseOptionalInteger(intVal)) {
49 arrayOfInts.push_back(intVal);
50 if (parser.parseOptionalComma())
51 break;
52 }
53
54 if (parser.parseRSquare() || parser.parseGreater())
55 return Attribute();
56 return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts);
57}
58
59void CompoundAAttr::print(AsmPrinter &printer) const {
60 printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", [";
61 llvm::interleaveComma(getArrayOfInts(), printer);
62 printer << "]>";
63}
64
65//===----------------------------------------------------------------------===//
// TestDecimalShapeAttr
67//===----------------------------------------------------------------------===//
68
69Attribute TestDecimalShapeAttr::parse(AsmParser &parser, Type type) {
70 if (parser.parseLess()) {
71 return Attribute();
72 }
73 SmallVector<int64_t> shape;
74 if (parser.parseOptionalGreater()) {
75 auto parseDecimal = [&]() {
76 shape.emplace_back();
77 auto parseResult = parser.parseOptionalDecimalInteger(shape.back());
78 if (!parseResult.has_value() || failed(*parseResult)) {
79 parser.emitError(parser.getCurrentLocation()) << "expected an integer";
80 return failure();
81 }
82 return success();
83 };
84 if (failed(parseDecimal())) {
85 return Attribute();
86 }
87 while (failed(parser.parseOptionalGreater())) {
88 if (failed(parser.parseXInDimensionList()) || failed(parseDecimal())) {
89 return Attribute();
90 }
91 }
92 }
93 return get(parser.getContext(), shape);
94}
95
96void TestDecimalShapeAttr::print(AsmPrinter &printer) const {
97 printer << "<";
98 llvm::interleave(getShape(), printer, "x");
99 printer << ">";
100}
101
102Attribute TestI64ElementsAttr::parse(AsmParser &parser, Type type) {
103 SmallVector<uint64_t> elements;
104 if (parser.parseLess() || parser.parseLSquare())
105 return Attribute();
106 uint64_t intVal;
107 while (succeeded(*parser.parseOptionalInteger(intVal))) {
108 elements.push_back(intVal);
109 if (parser.parseOptionalComma())
110 break;
111 }
112
113 if (parser.parseRSquare() || parser.parseGreater())
114 return Attribute();
115 return parser.getChecked<TestI64ElementsAttr>(
116 parser.getContext(), llvm::cast<ShapedType>(type), elements);
117}
118
119void TestI64ElementsAttr::print(AsmPrinter &printer) const {
120 printer << "<[";
121 llvm::interleaveComma(getElements(), printer);
122 printer << "]>";
123}
124
125LogicalResult
126TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
127 ShapedType type, ArrayRef<uint64_t> elements) {
128 if (type.getNumElements() != static_cast<int64_t>(elements.size())) {
129 return emitError()
130 << "number of elements does not match the provided shape type, got: "
131 << elements.size() << ", but expected: " << type.getNumElements();
132 }
133 if (type.getRank() != 1 || !type.getElementType().isSignlessInteger(64))
134 return emitError() << "expected single rank 64-bit shape type, but got: "
135 << type;
136 return success();
137}
138
139LogicalResult TestAttrWithFormatAttr::verify(
140 function_ref<InFlightDiagnostic()> emitError, int64_t one, std::string two,
141 IntegerAttr three, ArrayRef<int> four, uint64_t five, ArrayRef<int> six,
142 ArrayRef<AttrWithTypeBuilderAttr> arrayOfAttrs) {
143 if (four.size() != static_cast<unsigned>(one))
144 return emitError() << "expected 'one' to equal 'four.size()'";
145 return success();
146}
147
148//===----------------------------------------------------------------------===//
149// Utility Functions for Generated Attributes
150//===----------------------------------------------------------------------===//
151
152static FailureOr<SmallVector<int>> parseIntArray(AsmParser &parser) {
153 SmallVector<int> ints;
154 if (parser.parseLSquare() || parser.parseCommaSeparatedList(parseElementFn: [&]() {
155 ints.push_back(Elt: 0);
156 return parser.parseInteger(result&: ints.back());
157 }) ||
158 parser.parseRSquare())
159 return failure();
160 return ints;
161}
162
163static void printIntArray(AsmPrinter &printer, ArrayRef<int> ints) {
164 printer << '[';
165 llvm::interleaveComma(c: ints, os&: printer);
166 printer << ']';
167}
168
169//===----------------------------------------------------------------------===//
170// TestSubElementsAccessAttr
171//===----------------------------------------------------------------------===//
172
173Attribute TestSubElementsAccessAttr::parse(::mlir::AsmParser &parser,
174 ::mlir::Type type) {
175 Attribute first, second, third;
176 if (parser.parseLess() || parser.parseAttribute(first) ||
177 parser.parseComma() || parser.parseAttribute(second) ||
178 parser.parseComma() || parser.parseAttribute(third) ||
179 parser.parseGreater()) {
180 return {};
181 }
182 return get(parser.getContext(), first, second, third);
183}
184
185void TestSubElementsAccessAttr::print(::mlir::AsmPrinter &printer) const {
186 printer << "<" << getFirst() << ", " << getSecond() << ", " << getThird()
187 << ">";
188}
189
190//===----------------------------------------------------------------------===//
191// TestExtern1DI64ElementsAttr
192//===----------------------------------------------------------------------===//
193
194ArrayRef<uint64_t> TestExtern1DI64ElementsAttr::getElements() const {
195 if (auto *blob = getHandle().getBlob())
196 return blob->getDataAs<uint64_t>();
197 return std::nullopt;
198}
199
200//===----------------------------------------------------------------------===//
201// TestCustomAnchorAttr
202//===----------------------------------------------------------------------===//
203
204static ParseResult parseTrueFalse(AsmParser &p, std::optional<int> &result) {
205 bool b;
206 if (p.parseInteger(result&: b))
207 return failure();
208 result = b;
209 return success();
210}
211
212static void printTrueFalse(AsmPrinter &p, std::optional<int> result) {
213 p << (*result ? "true" : "false");
214}
215
216//===----------------------------------------------------------------------===//
217// CopyCountAttr Implementation
218//===----------------------------------------------------------------------===//
219
220CopyCount::CopyCount(const CopyCount &rhs) : value(rhs.value) {
221 CopyCount::counter++;
222}
223
224CopyCount &CopyCount::operator=(const CopyCount &rhs) {
225 CopyCount::counter++;
226 value = rhs.value;
227 return *this;
228}
229
230int CopyCount::counter;
231
232static bool operator==(const test::CopyCount &lhs, const test::CopyCount &rhs) {
233 return lhs.value == rhs.value;
234}
235
236llvm::raw_ostream &test::operator<<(llvm::raw_ostream &os,
237 const test::CopyCount &value) {
238 return os << value.value;
239}
240
241template <>
242struct mlir::FieldParser<test::CopyCount> {
243 static FailureOr<test::CopyCount> parse(AsmParser &parser) {
244 std::string value;
245 if (parser.parseKeyword(keyword: value))
246 return failure();
247 return test::CopyCount(value);
248 }
249};
250namespace test {
251llvm::hash_code hash_value(const test::CopyCount &copyCount) {
252 return llvm::hash_value(arg: copyCount.value);
253}
254} // namespace test
255
256//===----------------------------------------------------------------------===//
257// TestConditionalAliasAttr
258//===----------------------------------------------------------------------===//
259
260/// Attempt to parse the conditionally-aliased string attribute as a keyword or
261/// string, else try to parse an alias.
262static ParseResult parseConditionalAlias(AsmParser &p, StringAttr &value) {
263 std::string str;
264 if (succeeded(Result: p.parseOptionalKeywordOrString(result: &str))) {
265 value = StringAttr::get(p.getContext(), str);
266 return success();
267 }
268 return p.parseAttribute(result&: value);
269}
270
271/// Print the string attribute as an alias if it has one, otherwise print it as
272/// a keyword if possible.
273static void printConditionalAlias(AsmPrinter &p, StringAttr value) {
274 if (succeeded(p.printAlias(value)))
275 return;
276 p.printKeywordOrString(keyword: value);
277}
278
279//===----------------------------------------------------------------------===//
280// Custom Float Attribute
281//===----------------------------------------------------------------------===//
282
283static void printCustomFloatAttr(AsmPrinter &p, StringAttr typeStrAttr,
284 APFloat value) {
285 p << typeStrAttr << " : " << value;
286}
287
288static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr,
289 FailureOr<APFloat> &value) {
290
291 std::string str;
292 if (p.parseString(string: &str))
293 return failure();
294
295 typeStrAttr = StringAttr::get(p.getContext(), str);
296
297 if (p.parseColon())
298 return failure();
299
300 const llvm::fltSemantics *semantics;
301 if (str == "float")
302 semantics = &llvm::APFloat::IEEEsingle();
303 else if (str == "double")
304 semantics = &llvm::APFloat::IEEEdouble();
305 else if (str == "fp80")
306 semantics = &llvm::APFloat::x87DoubleExtended();
307 else
308 return p.emitError(loc: p.getCurrentLocation(), message: "unknown float type, expected "
309 "'float', 'double' or 'fp80'");
310
311 APFloat parsedValue(0.0);
312 if (p.parseFloat(semantics: *semantics, result&: parsedValue))
313 return failure();
314
315 value.emplace(args&: parsedValue);
316 return success();
317}
318
319//===----------------------------------------------------------------------===//
320// TestCustomStructAttr
321//===----------------------------------------------------------------------===//
322
323static void printCustomStructAttr(AsmPrinter &p, int64_t value) {
324 if (ShapedType::isDynamic(value)) {
325 p << "?";
326 } else {
327 p.printStrippedAttrOrType(attrOrType: value);
328 }
329}
330
331static ParseResult parseCustomStructAttr(AsmParser &p, int64_t &value) {
332 if (succeeded(Result: p.parseOptionalQuestion())) {
333 value = ShapedType::kDynamic;
334 return success();
335 }
336 return p.parseInteger(result&: value);
337}
338
339static void printCustomOptStructFieldAttr(AsmPrinter &p, ArrayAttr attr) {
340 if (attr && attr.size() == 1 && isa<IntegerAttr>(attr[0])) {
341 p << cast<IntegerAttr>(attr[0]).getInt();
342 } else {
343 p.printStrippedAttrOrType(attr);
344 }
345}
346
347static ParseResult parseCustomOptStructFieldAttr(AsmParser &p,
348 ArrayAttr &attr) {
349 int64_t value;
350 OptionalParseResult result = p.parseOptionalInteger(result&: value);
351 if (result.has_value()) {
352 if (failed(Result: result.value()))
353 return failure();
354 attr = ArrayAttr::get(
355 p.getContext(),
356 {IntegerAttr::get(IntegerType::get(p.getContext(), 64), value)});
357 return success();
358 }
359 return p.parseAttribute(result&: attr);
360}
361
362//===----------------------------------------------------------------------===//
363// TestOpAsmAttrInterfaceAttr
364//===----------------------------------------------------------------------===//
365
366::mlir::OpAsmDialectInterface::AliasResult
367TestOpAsmAttrInterfaceAttr::getAlias(::llvm::raw_ostream &os) const {
368 os << "op_asm_attr_interface_";
369 os << getValue().getValue();
370 return ::mlir::OpAsmDialectInterface::AliasResult::FinalAlias;
371}
372
373//===----------------------------------------------------------------------===//
374// TestConstMemorySpaceAttr
375//===----------------------------------------------------------------------===//
376
377bool TestConstMemorySpaceAttr::isValidLoad(
378 Type type, mlir::ptr::AtomicOrdering ordering, IntegerAttr alignment,
379 function_ref<InFlightDiagnostic()> emitError) const {
380 return true;
381}
382
383bool TestConstMemorySpaceAttr::isValidStore(
384 Type type, mlir::ptr::AtomicOrdering ordering, IntegerAttr alignment,
385 function_ref<InFlightDiagnostic()> emitError) const {
386 if (emitError)
387 emitError() << "memory space is read-only";
388 return false;
389}
390
391bool TestConstMemorySpaceAttr::isValidAtomicOp(
392 mlir::ptr::AtomicBinOp binOp, Type type, mlir::ptr::AtomicOrdering ordering,
393 IntegerAttr alignment, function_ref<InFlightDiagnostic()> emitError) const {
394 if (emitError)
395 emitError() << "memory space is read-only";
396 return false;
397}
398
399bool TestConstMemorySpaceAttr::isValidAtomicXchg(
400 Type type, mlir::ptr::AtomicOrdering successOrdering,
401 mlir::ptr::AtomicOrdering failureOrdering, IntegerAttr alignment,
402 function_ref<InFlightDiagnostic()> emitError) const {
403 if (emitError)
404 emitError() << "memory space is read-only";
405 return false;
406}
407
408bool TestConstMemorySpaceAttr::isValidAddrSpaceCast(
409 Type tgt, Type src, function_ref<InFlightDiagnostic()> emitError) const {
410 if (emitError)
411 emitError() << "memory space doesn't allow addrspace casts";
412 return false;
413}
414
415bool TestConstMemorySpaceAttr::isValidPtrIntCast(
416 Type intLikeTy, Type ptrLikeTy,
417 function_ref<InFlightDiagnostic()> emitError) const {
418 if (emitError)
419 emitError() << "memory space doesn't allow int-ptr casts";
420 return false;
421}
422
423//===----------------------------------------------------------------------===//
424// Tablegen Generated Definitions
425//===----------------------------------------------------------------------===//
426
427#include "TestAttrInterfaces.cpp.inc"
428#include "TestOpEnums.cpp.inc"
429#define GET_ATTRDEF_CLASSES
430#include "TestAttrDefs.cpp.inc"
431
432//===----------------------------------------------------------------------===//
433// Dynamic Attributes
434//===----------------------------------------------------------------------===//
435
436/// Define a singleton dynamic attribute.
437static std::unique_ptr<DynamicAttrDefinition>
438getDynamicSingletonAttr(TestDialect *testDialect) {
439 return DynamicAttrDefinition::get(
440 "dynamic_singleton", testDialect,
441 [](function_ref<InFlightDiagnostic()> emitError,
442 ArrayRef<Attribute> args) {
443 if (!args.empty()) {
444 emitError() << "expected 0 attribute arguments, but had "
445 << args.size();
446 return failure();
447 }
448 return success();
449 });
450}
451
452/// Define a dynamic attribute representing a pair or attributes.
453static std::unique_ptr<DynamicAttrDefinition>
454getDynamicPairAttr(TestDialect *testDialect) {
455 return DynamicAttrDefinition::get(
456 "dynamic_pair", testDialect,
457 [](function_ref<InFlightDiagnostic()> emitError,
458 ArrayRef<Attribute> args) {
459 if (args.size() != 2) {
460 emitError() << "expected 2 attribute arguments, but had "
461 << args.size();
462 return failure();
463 }
464 return success();
465 });
466}
467
468static std::unique_ptr<DynamicAttrDefinition>
469getDynamicCustomAssemblyFormatAttr(TestDialect *testDialect) {
470 auto verifier = [](function_ref<InFlightDiagnostic()> emitError,
471 ArrayRef<Attribute> args) {
472 if (args.size() != 2) {
473 emitError() << "expected 2 attribute arguments, but had " << args.size();
474 return failure();
475 }
476 return success();
477 };
478
479 auto parser = [](AsmParser &parser,
480 llvm::SmallVectorImpl<Attribute> &parsedParams) {
481 Attribute leftAttr, rightAttr;
482 if (parser.parseLess() || parser.parseAttribute(result&: leftAttr) ||
483 parser.parseColon() || parser.parseAttribute(result&: rightAttr) ||
484 parser.parseGreater())
485 return failure();
486 parsedParams.push_back(Elt: leftAttr);
487 parsedParams.push_back(Elt: rightAttr);
488 return success();
489 };
490
491 auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) {
492 printer << "<" << params[0] << ":" << params[1] << ">";
493 };
494
495 return DynamicAttrDefinition::get("dynamic_custom_assembly_format",
496 testDialect, std::move(verifier),
497 std::move(parser), std::move(printer));
498}
499
500//===----------------------------------------------------------------------===//
501// SlashAttr
502//===----------------------------------------------------------------------===//
503
504Attribute SlashAttr::parse(AsmParser &parser, Type type) {
505 int lhs, rhs;
506
507 if (parser.parseLess() || parser.parseInteger(lhs) || parser.parseSlash() ||
508 parser.parseInteger(rhs) || parser.parseGreater())
509 return Attribute();
510
511 return SlashAttr::get(parser.getContext(), lhs, rhs);
512}
513
514void SlashAttr::print(AsmPrinter &printer) const {
515 printer << "<" << getLhs() << " / " << getRhs() << ">";
516}
517
518//===----------------------------------------------------------------------===//
519// TestDialect
520//===----------------------------------------------------------------------===//
521
/// Registers every attribute of the test dialect. First come the
/// tablegen-generated attribute classes listed under `GET_ATTRDEF_LIST` in
/// TestAttrDefs.cpp.inc, then the three dynamically defined attributes.
void TestDialect::registerAttributes() {
  // The included list expands to the template arguments of addAttributes<>.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "TestAttrDefs.cpp.inc"
      >();
  // Dynamic (non-tablegen) attribute definitions.
  registerDynamicAttr(getDynamicSingletonAttr(this));
  registerDynamicAttr(getDynamicPairAttr(this));
  registerDynamicAttr(getDynamicCustomAssemblyFormatAttr(this));
}
531

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/test/lib/Dialect/Test/TestAttributes.cpp