1//===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains attributes defined by the TestDialect for testing various
10// features of MLIR.
11//
12//===----------------------------------------------------------------------===//
13
14#include "TestAttributes.h"
15#include "TestDialect.h"
16#include "mlir/IR/Builders.h"
17#include "mlir/IR/DialectImplementation.h"
18#include "mlir/IR/ExtensibleDialect.h"
19#include "mlir/IR/Types.h"
20#include "mlir/Support/LogicalResult.h"
21#include "llvm/ADT/Hashing.h"
22#include "llvm/ADT/StringExtras.h"
23#include "llvm/ADT/TypeSwitch.h"
24#include "llvm/ADT/bit.h"
25#include "llvm/Support/ErrorHandling.h"
26#include "llvm/Support/raw_ostream.h"
27
28using namespace mlir;
29using namespace test;
30
31//===----------------------------------------------------------------------===//
32// CompoundAAttr
33//===----------------------------------------------------------------------===//
34
35Attribute CompoundAAttr::parse(AsmParser &parser, Type type) {
36 int widthOfSomething;
37 Type oneType;
38 SmallVector<int, 4> arrayOfInts;
39 if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
40 parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
41 parser.parseLSquare())
42 return Attribute();
43
44 int intVal;
45 while (!*parser.parseOptionalInteger(intVal)) {
46 arrayOfInts.push_back(intVal);
47 if (parser.parseOptionalComma())
48 break;
49 }
50
51 if (parser.parseRSquare() || parser.parseGreater())
52 return Attribute();
53 return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts);
54}
55
56void CompoundAAttr::print(AsmPrinter &printer) const {
57 printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", [";
58 llvm::interleaveComma(getArrayOfInts(), printer);
59 printer << "]>";
60}
61
62//===----------------------------------------------------------------------===//
// TestI64ElementsAttr
64//===----------------------------------------------------------------------===//
65
66Attribute TestI64ElementsAttr::parse(AsmParser &parser, Type type) {
67 SmallVector<uint64_t> elements;
68 if (parser.parseLess() || parser.parseLSquare())
69 return Attribute();
70 uint64_t intVal;
71 while (succeeded(*parser.parseOptionalInteger(intVal))) {
72 elements.push_back(intVal);
73 if (parser.parseOptionalComma())
74 break;
75 }
76
77 if (parser.parseRSquare() || parser.parseGreater())
78 return Attribute();
79 return parser.getChecked<TestI64ElementsAttr>(
80 parser.getContext(), llvm::cast<ShapedType>(type), elements);
81}
82
83void TestI64ElementsAttr::print(AsmPrinter &printer) const {
84 printer << "<[";
85 llvm::interleaveComma(getElements(), printer);
86 printer << "]>";
87}
88
89LogicalResult
90TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
91 ShapedType type, ArrayRef<uint64_t> elements) {
92 if (type.getNumElements() != static_cast<int64_t>(elements.size())) {
93 return emitError()
94 << "number of elements does not match the provided shape type, got: "
95 << elements.size() << ", but expected: " << type.getNumElements();
96 }
97 if (type.getRank() != 1 || !type.getElementType().isSignlessInteger(64))
98 return emitError() << "expected single rank 64-bit shape type, but got: "
99 << type;
100 return success();
101}
102
103LogicalResult TestAttrWithFormatAttr::verify(
104 function_ref<InFlightDiagnostic()> emitError, int64_t one, std::string two,
105 IntegerAttr three, ArrayRef<int> four, uint64_t five, ArrayRef<int> six,
106 ArrayRef<AttrWithTypeBuilderAttr> arrayOfAttrs) {
107 if (four.size() != static_cast<unsigned>(one))
108 return emitError() << "expected 'one' to equal 'four.size()'";
109 return success();
110}
111
112//===----------------------------------------------------------------------===//
113// Utility Functions for Generated Attributes
114//===----------------------------------------------------------------------===//
115
116static FailureOr<SmallVector<int>> parseIntArray(AsmParser &parser) {
117 SmallVector<int> ints;
118 if (parser.parseLSquare() || parser.parseCommaSeparatedList(parseElementFn: [&]() {
119 ints.push_back(Elt: 0);
120 return parser.parseInteger(result&: ints.back());
121 }) ||
122 parser.parseRSquare())
123 return failure();
124 return ints;
125}
126
127static void printIntArray(AsmPrinter &printer, ArrayRef<int> ints) {
128 printer << '[';
129 llvm::interleaveComma(c: ints, os&: printer);
130 printer << ']';
131}
132
133//===----------------------------------------------------------------------===//
134// TestSubElementsAccessAttr
135//===----------------------------------------------------------------------===//
136
137Attribute TestSubElementsAccessAttr::parse(::mlir::AsmParser &parser,
138 ::mlir::Type type) {
139 Attribute first, second, third;
140 if (parser.parseLess() || parser.parseAttribute(first) ||
141 parser.parseComma() || parser.parseAttribute(second) ||
142 parser.parseComma() || parser.parseAttribute(third) ||
143 parser.parseGreater()) {
144 return {};
145 }
146 return get(parser.getContext(), first, second, third);
147}
148
149void TestSubElementsAccessAttr::print(::mlir::AsmPrinter &printer) const {
150 printer << "<" << getFirst() << ", " << getSecond() << ", " << getThird()
151 << ">";
152}
153
154//===----------------------------------------------------------------------===//
155// TestExtern1DI64ElementsAttr
156//===----------------------------------------------------------------------===//
157
158ArrayRef<uint64_t> TestExtern1DI64ElementsAttr::getElements() const {
159 if (auto *blob = getHandle().getBlob())
160 return blob->getDataAs<uint64_t>();
161 return std::nullopt;
162}
163
164//===----------------------------------------------------------------------===//
165// TestCustomAnchorAttr
166//===----------------------------------------------------------------------===//
167
168static ParseResult parseTrueFalse(AsmParser &p, std::optional<int> &result) {
169 bool b;
170 if (p.parseInteger(result&: b))
171 return failure();
172 result = b;
173 return success();
174}
175
176static void printTrueFalse(AsmPrinter &p, std::optional<int> result) {
177 p << (*result ? "true" : "false");
178}
179
180//===----------------------------------------------------------------------===//
181// CopyCountAttr Implementation
182//===----------------------------------------------------------------------===//
183
184CopyCount::CopyCount(const CopyCount &rhs) : value(rhs.value) {
185 CopyCount::counter++;
186}
187
188CopyCount &CopyCount::operator=(const CopyCount &rhs) {
189 CopyCount::counter++;
190 value = rhs.value;
191 return *this;
192}
193
194int CopyCount::counter;
195
196static bool operator==(const test::CopyCount &lhs, const test::CopyCount &rhs) {
197 return lhs.value == rhs.value;
198}
199
200llvm::raw_ostream &test::operator<<(llvm::raw_ostream &os,
201 const test::CopyCount &value) {
202 return os << value.value;
203}
204
205template <>
206struct mlir::FieldParser<test::CopyCount> {
207 static FailureOr<test::CopyCount> parse(AsmParser &parser) {
208 std::string value;
209 if (parser.parseKeyword(keyword: value))
210 return failure();
211 return test::CopyCount(value);
212 }
213};
214namespace test {
215llvm::hash_code hash_value(const test::CopyCount &copyCount) {
216 return llvm::hash_value(arg: copyCount.value);
217}
218} // namespace test
219
220//===----------------------------------------------------------------------===//
221// TestConditionalAliasAttr
222//===----------------------------------------------------------------------===//
223
224/// Attempt to parse the conditionally-aliased string attribute as a keyword or
225/// string, else try to parse an alias.
226static ParseResult parseConditionalAlias(AsmParser &p, StringAttr &value) {
227 std::string str;
228 if (succeeded(result: p.parseOptionalKeywordOrString(result: &str))) {
229 value = StringAttr::get(p.getContext(), str);
230 return success();
231 }
232 return p.parseAttribute(result&: value);
233}
234
235/// Print the string attribute as an alias if it has one, otherwise print it as
236/// a keyword if possible.
237static void printConditionalAlias(AsmPrinter &p, StringAttr value) {
238 if (succeeded(p.printAlias(value)))
239 return;
240 p.printKeywordOrString(keyword: value);
241}
242
243//===----------------------------------------------------------------------===//
244// Tablegen Generated Definitions
245//===----------------------------------------------------------------------===//
246
247#include "TestAttrInterfaces.cpp.inc"
248#include "TestOpEnums.cpp.inc"
249#define GET_ATTRDEF_CLASSES
250#include "TestAttrDefs.cpp.inc"
251
252//===----------------------------------------------------------------------===//
253// Dynamic Attributes
254//===----------------------------------------------------------------------===//
255
256/// Define a singleton dynamic attribute.
257static std::unique_ptr<DynamicAttrDefinition>
258getDynamicSingletonAttr(TestDialect *testDialect) {
259 return DynamicAttrDefinition::get(
260 "dynamic_singleton", testDialect,
261 [](function_ref<InFlightDiagnostic()> emitError,
262 ArrayRef<Attribute> args) {
263 if (!args.empty()) {
264 emitError() << "expected 0 attribute arguments, but had "
265 << args.size();
266 return failure();
267 }
268 return success();
269 });
270}
271
272/// Define a dynamic attribute representing a pair or attributes.
273static std::unique_ptr<DynamicAttrDefinition>
274getDynamicPairAttr(TestDialect *testDialect) {
275 return DynamicAttrDefinition::get(
276 "dynamic_pair", testDialect,
277 [](function_ref<InFlightDiagnostic()> emitError,
278 ArrayRef<Attribute> args) {
279 if (args.size() != 2) {
280 emitError() << "expected 2 attribute arguments, but had "
281 << args.size();
282 return failure();
283 }
284 return success();
285 });
286}
287
288static std::unique_ptr<DynamicAttrDefinition>
289getDynamicCustomAssemblyFormatAttr(TestDialect *testDialect) {
290 auto verifier = [](function_ref<InFlightDiagnostic()> emitError,
291 ArrayRef<Attribute> args) {
292 if (args.size() != 2) {
293 emitError() << "expected 2 attribute arguments, but had " << args.size();
294 return failure();
295 }
296 return success();
297 };
298
299 auto parser = [](AsmParser &parser,
300 llvm::SmallVectorImpl<Attribute> &parsedParams) {
301 Attribute leftAttr, rightAttr;
302 if (parser.parseLess() || parser.parseAttribute(result&: leftAttr) ||
303 parser.parseColon() || parser.parseAttribute(result&: rightAttr) ||
304 parser.parseGreater())
305 return failure();
306 parsedParams.push_back(Elt: leftAttr);
307 parsedParams.push_back(Elt: rightAttr);
308 return success();
309 };
310
311 auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) {
312 printer << "<" << params[0] << ":" << params[1] << ">";
313 };
314
315 return DynamicAttrDefinition::get("dynamic_custom_assembly_format",
316 testDialect, std::move(verifier),
317 std::move(parser), std::move(printer));
318}
319
320//===----------------------------------------------------------------------===//
321// TestDialect
322//===----------------------------------------------------------------------===//
323
324void TestDialect::registerAttributes() {
325 addAttributes<
326#define GET_ATTRDEF_LIST
327#include "TestAttrDefs.cpp.inc"
328 >();
329 registerDynamicAttr(getDynamicSingletonAttr(this));
330 registerDynamicAttr(getDynamicPairAttr(this));
331 registerDynamicAttr(getDynamicCustomAssemblyFormatAttr(this));
332}
333

source code of mlir/test/lib/Dialect/Test/TestAttributes.cpp