1 | //===- TestTypes.cpp - MLIR Test Dialect Types ------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file contains types defined by the TestDialect for testing various |
10 | // features of MLIR. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "TestTypes.h" |
15 | #include "TestDialect.h" |
16 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
17 | #include "mlir/IR/Builders.h" |
18 | #include "mlir/IR/DialectImplementation.h" |
19 | #include "mlir/IR/ExtensibleDialect.h" |
20 | #include "mlir/IR/Types.h" |
21 | #include "llvm/ADT/Hashing.h" |
22 | #include "llvm/ADT/SetVector.h" |
23 | #include "llvm/ADT/TypeSwitch.h" |
24 | #include "llvm/Support/TypeSize.h" |
25 | #include <optional> |
26 | |
27 | using namespace mlir; |
28 | using namespace test; |
29 | |
30 | // Custom parser for SignednessSemantics. |
31 | static ParseResult |
32 | parseSignedness(AsmParser &parser, |
33 | TestIntegerType::SignednessSemantics &result) { |
34 | StringRef signStr; |
35 | auto loc = parser.getCurrentLocation(); |
36 | if (parser.parseKeyword(keyword: &signStr)) |
37 | return failure(); |
38 | if (signStr.equals_insensitive(RHS: "u" ) || signStr.equals_insensitive(RHS: "unsigned" )) |
39 | result = TestIntegerType::SignednessSemantics::Unsigned; |
40 | else if (signStr.equals_insensitive(RHS: "s" ) || |
41 | signStr.equals_insensitive(RHS: "signed" )) |
42 | result = TestIntegerType::SignednessSemantics::Signed; |
43 | else if (signStr.equals_insensitive(RHS: "n" ) || |
44 | signStr.equals_insensitive(RHS: "none" )) |
45 | result = TestIntegerType::SignednessSemantics::Signless; |
46 | else |
47 | return parser.emitError(loc, message: "expected signed, unsigned, or none" ); |
48 | return success(); |
49 | } |
50 | |
51 | // Custom printer for SignednessSemantics. |
52 | static void printSignedness(AsmPrinter &printer, |
53 | const TestIntegerType::SignednessSemantics &ss) { |
54 | switch (ss) { |
55 | case TestIntegerType::SignednessSemantics::Unsigned: |
56 | printer << "unsigned" ; |
57 | break; |
58 | case TestIntegerType::SignednessSemantics::Signed: |
59 | printer << "signed" ; |
60 | break; |
61 | case TestIntegerType::SignednessSemantics::Signless: |
62 | printer << "none" ; |
63 | break; |
64 | } |
65 | } |
66 | |
67 | // The functions don't need to be in the header file, but need to be in the mlir |
68 | // namespace. Declare them here, then define them immediately below. Separating |
69 | // the declaration and definition adheres to the LLVM coding standards. |
70 | namespace test { |
71 | // FieldInfo is used as part of a parameter, so equality comparison is |
72 | // compulsory. |
73 | static bool operator==(const FieldInfo &a, const FieldInfo &b); |
74 | // FieldInfo is used as part of a parameter, so a hash will be computed. |
75 | static llvm::hash_code hash_value(const FieldInfo &fi); // NOLINT |
76 | } // namespace test |
77 | |
78 | // FieldInfo is used as part of a parameter, so equality comparison is |
79 | // compulsory. |
80 | static bool test::operator==(const FieldInfo &a, const FieldInfo &b) { |
81 | return a.name == b.name && a.type == b.type; |
82 | } |
83 | |
84 | // FieldInfo is used as part of a parameter, so a hash will be computed. |
85 | static llvm::hash_code test::hash_value(const FieldInfo &fi) { // NOLINT |
86 | return llvm::hash_combine(args: fi.name, args: fi.type); |
87 | } |
88 | |
89 | //===----------------------------------------------------------------------===// |
90 | // TestCustomType |
91 | //===----------------------------------------------------------------------===// |
92 | |
93 | static LogicalResult parseCustomTypeA(AsmParser &parser, int &aResult) { |
94 | return parser.parseInteger(result&: aResult); |
95 | } |
96 | |
97 | static void printCustomTypeA(AsmPrinter &printer, int a) { printer << a; } |
98 | |
99 | static LogicalResult parseCustomTypeB(AsmParser &parser, int a, |
100 | std::optional<int> &bResult) { |
101 | if (a < 0) |
102 | return success(); |
103 | for (int i : llvm::seq(Begin: 0, End: a)) |
104 | if (failed(result: parser.parseInteger(result&: i))) |
105 | return failure(); |
106 | bResult.emplace(args: 0); |
107 | return parser.parseInteger(result&: *bResult); |
108 | } |
109 | |
110 | static void printCustomTypeB(AsmPrinter &printer, int a, std::optional<int> b) { |
111 | if (a < 0) |
112 | return; |
113 | printer << ' '; |
114 | for (int i : llvm::seq(Begin: 0, End: a)) |
115 | printer << i << ' '; |
116 | printer << *b; |
117 | } |
118 | |
119 | static LogicalResult parseFooString(AsmParser &parser, std::string &foo) { |
120 | std::string result; |
121 | if (parser.parseString(string: &result)) |
122 | return failure(); |
123 | foo = std::move(result); |
124 | return success(); |
125 | } |
126 | |
127 | static void printFooString(AsmPrinter &printer, StringRef foo) { |
128 | printer << '"' << foo << '"'; |
129 | } |
130 | |
131 | static LogicalResult parseBarString(AsmParser &parser, StringRef foo) { |
132 | return parser.parseKeyword(keyword: foo); |
133 | } |
134 | |
135 | static void printBarString(AsmPrinter &printer, StringRef foo) { |
136 | printer << foo; |
137 | } |
138 | //===----------------------------------------------------------------------===// |
139 | // Tablegen Generated Definitions |
140 | //===----------------------------------------------------------------------===// |
141 | |
142 | #include "TestTypeInterfaces.cpp.inc" |
143 | #define GET_TYPEDEF_CLASSES |
144 | #include "TestTypeDefs.cpp.inc" |
145 | |
146 | //===----------------------------------------------------------------------===// |
147 | // CompoundAType |
148 | //===----------------------------------------------------------------------===// |
149 | |
150 | Type CompoundAType::parse(AsmParser &parser) { |
151 | int widthOfSomething; |
152 | Type oneType; |
153 | SmallVector<int, 4> arrayOfInts; |
154 | if (parser.parseLess() || parser.parseInteger(widthOfSomething) || |
155 | parser.parseComma() || parser.parseType(oneType) || parser.parseComma() || |
156 | parser.parseLSquare()) |
157 | return Type(); |
158 | |
159 | int i; |
160 | while (!*parser.parseOptionalInteger(i)) { |
161 | arrayOfInts.push_back(i); |
162 | if (parser.parseOptionalComma()) |
163 | break; |
164 | } |
165 | |
166 | if (parser.parseRSquare() || parser.parseGreater()) |
167 | return Type(); |
168 | |
169 | return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts); |
170 | } |
171 | void CompoundAType::print(AsmPrinter &printer) const { |
172 | printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", [" ; |
173 | auto intArray = getArrayOfInts(); |
174 | llvm::interleaveComma(intArray, printer); |
175 | printer << "]>" ; |
176 | } |
177 | |
178 | //===----------------------------------------------------------------------===// |
179 | // TestIntegerType |
180 | //===----------------------------------------------------------------------===// |
181 | |
182 | // Example type validity checker. |
183 | LogicalResult |
184 | TestIntegerType::verify(function_ref<InFlightDiagnostic()> emitError, |
185 | unsigned width, |
186 | TestIntegerType::SignednessSemantics ss) { |
187 | if (width > 8) |
188 | return failure(); |
189 | return success(); |
190 | } |
191 | |
192 | Type TestIntegerType::parse(AsmParser &parser) { |
193 | SignednessSemantics signedness; |
194 | int width; |
195 | if (parser.parseLess() || parseSignedness(parser, signedness) || |
196 | parser.parseComma() || parser.parseInteger(width) || |
197 | parser.parseGreater()) |
198 | return Type(); |
199 | Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); |
200 | return getChecked(loc, loc.getContext(), width, signedness); |
201 | } |
202 | |
203 | void TestIntegerType::print(AsmPrinter &p) const { |
204 | p << "<" ; |
205 | printSignedness(p, getSignedness()); |
206 | p << ", " << getWidth() << ">" ; |
207 | } |
208 | |
209 | //===----------------------------------------------------------------------===// |
210 | // TestStructType |
211 | //===----------------------------------------------------------------------===// |
212 | |
213 | Type StructType::parse(AsmParser &p) { |
214 | SmallVector<FieldInfo, 4> parameters; |
215 | if (p.parseLess()) |
216 | return Type(); |
217 | while (succeeded(p.parseOptionalLBrace())) { |
218 | Type type; |
219 | StringRef name; |
220 | if (p.parseKeyword(&name) || p.parseComma() || p.parseType(type) || |
221 | p.parseRBrace()) |
222 | return Type(); |
223 | parameters.push_back(FieldInfo{name, type}); |
224 | if (p.parseOptionalComma()) |
225 | break; |
226 | } |
227 | if (p.parseGreater()) |
228 | return Type(); |
229 | return get(p.getContext(), parameters); |
230 | } |
231 | |
232 | void StructType::print(AsmPrinter &p) const { |
233 | p << "<" ; |
234 | llvm::interleaveComma(getFields(), p, [&](const FieldInfo &field) { |
235 | p << "{" << field.name << "," << field.type << "}" ; |
236 | }); |
237 | p << ">" ; |
238 | } |
239 | |
240 | //===----------------------------------------------------------------------===// |
241 | // TestType |
242 | //===----------------------------------------------------------------------===// |
243 | |
244 | void TestType::printTypeC(Location loc) const { |
245 | emitRemark(loc) << *this << " - TestC" ; |
246 | } |
247 | |
248 | //===----------------------------------------------------------------------===// |
249 | // TestTypeWithLayout |
250 | //===----------------------------------------------------------------------===// |
251 | |
252 | Type TestTypeWithLayoutType::parse(AsmParser &parser) { |
253 | unsigned val; |
254 | if (parser.parseLess() || parser.parseInteger(val) || parser.parseGreater()) |
255 | return Type(); |
256 | return TestTypeWithLayoutType::get(parser.getContext(), val); |
257 | } |
258 | |
259 | void TestTypeWithLayoutType::print(AsmPrinter &printer) const { |
260 | printer << "<" << getKey() << ">" ; |
261 | } |
262 | |
263 | llvm::TypeSize |
264 | TestTypeWithLayoutType::getTypeSizeInBits(const DataLayout &dataLayout, |
265 | DataLayoutEntryListRef params) const { |
266 | return llvm::TypeSize::getFixed(extractKind(params, "size" )); |
267 | } |
268 | |
269 | uint64_t |
270 | TestTypeWithLayoutType::getABIAlignment(const DataLayout &dataLayout, |
271 | DataLayoutEntryListRef params) const { |
272 | return extractKind(params, "alignment" ); |
273 | } |
274 | |
275 | uint64_t TestTypeWithLayoutType::getPreferredAlignment( |
276 | const DataLayout &dataLayout, DataLayoutEntryListRef params) const { |
277 | return extractKind(params, "preferred" ); |
278 | } |
279 | |
280 | std::optional<uint64_t> |
281 | TestTypeWithLayoutType::getIndexBitwidth(const DataLayout &dataLayout, |
282 | DataLayoutEntryListRef params) const { |
283 | return extractKind(params, "index" ); |
284 | } |
285 | |
286 | bool TestTypeWithLayoutType::areCompatible( |
287 | DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout) const { |
288 | unsigned old = extractKind(oldLayout, "alignment" ); |
289 | return old == 1 || extractKind(newLayout, "alignment" ) <= old; |
290 | } |
291 | |
292 | LogicalResult |
293 | TestTypeWithLayoutType::verifyEntries(DataLayoutEntryListRef params, |
294 | Location loc) const { |
295 | for (DataLayoutEntryInterface entry : params) { |
296 | // This is for testing purposes only, so assert well-formedness. |
297 | assert(entry.isTypeEntry() && "unexpected identifier entry" ); |
298 | assert(llvm::isa<TestTypeWithLayoutType>(entry.getKey().get<Type>()) && |
299 | "wrong type passed in" ); |
300 | auto array = llvm::dyn_cast<ArrayAttr>(entry.getValue()); |
301 | assert(array && array.getValue().size() == 2 && |
302 | "expected array of two elements" ); |
303 | auto kind = llvm::dyn_cast<StringAttr>(array.getValue().front()); |
304 | (void)kind; |
305 | assert(kind && |
306 | (kind.getValue() == "size" || kind.getValue() == "alignment" || |
307 | kind.getValue() == "preferred" || kind.getValue() == "index" ) && |
308 | "unexpected kind" ); |
309 | assert(llvm::isa<IntegerAttr>(array.getValue().back())); |
310 | } |
311 | return success(); |
312 | } |
313 | |
314 | uint64_t TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params, |
315 | StringRef expectedKind) const { |
316 | for (DataLayoutEntryInterface entry : params) { |
317 | ArrayRef<Attribute> pair = |
318 | llvm::cast<ArrayAttr>(entry.getValue()).getValue(); |
319 | StringRef kind = llvm::cast<StringAttr>(pair.front()).getValue(); |
320 | if (kind == expectedKind) |
321 | return llvm::cast<IntegerAttr>(pair.back()).getValue().getZExtValue(); |
322 | } |
323 | return 1; |
324 | } |
325 | |
326 | //===----------------------------------------------------------------------===// |
327 | // Dynamic Types |
328 | //===----------------------------------------------------------------------===// |
329 | |
330 | /// Define a singleton dynamic type. |
331 | static std::unique_ptr<DynamicTypeDefinition> |
332 | getSingletonDynamicType(TestDialect *testDialect) { |
333 | return DynamicTypeDefinition::get( |
334 | "dynamic_singleton" , testDialect, |
335 | [](function_ref<InFlightDiagnostic()> emitError, |
336 | ArrayRef<Attribute> args) { |
337 | if (!args.empty()) { |
338 | emitError() << "expected 0 type arguments, but had " << args.size(); |
339 | return failure(); |
340 | } |
341 | return success(); |
342 | }); |
343 | } |
344 | |
345 | /// Define a dynamic type representing a pair. |
346 | static std::unique_ptr<DynamicTypeDefinition> |
347 | getPairDynamicType(TestDialect *testDialect) { |
348 | return DynamicTypeDefinition::get( |
349 | "dynamic_pair" , testDialect, |
350 | [](function_ref<InFlightDiagnostic()> emitError, |
351 | ArrayRef<Attribute> args) { |
352 | if (args.size() != 2) { |
353 | emitError() << "expected 2 type arguments, but had " << args.size(); |
354 | return failure(); |
355 | } |
356 | return success(); |
357 | }); |
358 | } |
359 | |
360 | static std::unique_ptr<DynamicTypeDefinition> |
361 | getCustomAssemblyFormatDynamicType(TestDialect *testDialect) { |
362 | auto verifier = [](function_ref<InFlightDiagnostic()> emitError, |
363 | ArrayRef<Attribute> args) { |
364 | if (args.size() != 2) { |
365 | emitError() << "expected 2 type arguments, but had " << args.size(); |
366 | return failure(); |
367 | } |
368 | return success(); |
369 | }; |
370 | |
371 | auto parser = [](AsmParser &parser, |
372 | llvm::SmallVectorImpl<Attribute> &parsedParams) { |
373 | Attribute leftAttr, rightAttr; |
374 | if (parser.parseLess() || parser.parseAttribute(result&: leftAttr) || |
375 | parser.parseColon() || parser.parseAttribute(result&: rightAttr) || |
376 | parser.parseGreater()) |
377 | return failure(); |
378 | parsedParams.push_back(Elt: leftAttr); |
379 | parsedParams.push_back(Elt: rightAttr); |
380 | return success(); |
381 | }; |
382 | |
383 | auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) { |
384 | printer << "<" << params[0] << ":" << params[1] << ">" ; |
385 | }; |
386 | |
387 | return DynamicTypeDefinition::get("dynamic_custom_assembly_format" , |
388 | testDialect, std::move(verifier), |
389 | std::move(parser), std::move(printer)); |
390 | } |
391 | |
392 | //===----------------------------------------------------------------------===// |
393 | // TestDialect |
394 | //===----------------------------------------------------------------------===// |
395 | |
396 | namespace { |
397 | |
398 | struct PtrElementModel |
399 | : public LLVM::PointerElementTypeInterface::ExternalModel<PtrElementModel, |
400 | SimpleAType> {}; |
401 | } // namespace |
402 | |
403 | void TestDialect::registerTypes() { |
404 | addTypes<TestRecursiveType, |
405 | #define GET_TYPEDEF_LIST |
406 | #include "TestTypeDefs.cpp.inc" |
407 | >(); |
408 | SimpleAType::attachInterface<PtrElementModel>(*getContext()); |
409 | |
410 | registerDynamicType(getSingletonDynamicType(this)); |
411 | registerDynamicType(getPairDynamicType(this)); |
412 | registerDynamicType(getCustomAssemblyFormatDynamicType(this)); |
413 | } |
414 | |
415 | Type TestDialect::parseType(DialectAsmParser &parser) const { |
416 | StringRef typeTag; |
417 | { |
418 | Type genType; |
419 | auto parseResult = generatedTypeParser(parser, &typeTag, genType); |
420 | if (parseResult.has_value()) |
421 | return genType; |
422 | } |
423 | |
424 | { |
425 | Type dynType; |
426 | auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType); |
427 | if (parseResult.has_value()) { |
428 | if (succeeded(parseResult.value())) |
429 | return dynType; |
430 | return Type(); |
431 | } |
432 | } |
433 | |
434 | if (typeTag != "test_rec" ) { |
435 | parser.emitError(parser.getNameLoc()) << "unknown type!" ; |
436 | return Type(); |
437 | } |
438 | |
439 | StringRef name; |
440 | if (parser.parseLess() || parser.parseKeyword(&name)) |
441 | return Type(); |
442 | auto rec = TestRecursiveType::get(parser.getContext(), name); |
443 | |
444 | FailureOr<AsmParser::CyclicParseReset> cyclicParse = |
445 | parser.tryStartCyclicParse(rec); |
446 | |
447 | // If this type already has been parsed above in the stack, expect just the |
448 | // name. |
449 | if (failed(cyclicParse)) { |
450 | if (failed(parser.parseGreater())) |
451 | return Type(); |
452 | return rec; |
453 | } |
454 | |
455 | // Otherwise, parse the body and update the type. |
456 | if (failed(parser.parseComma())) |
457 | return Type(); |
458 | Type subtype = parseType(parser); |
459 | if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype))) |
460 | return Type(); |
461 | |
462 | return rec; |
463 | } |
464 | |
465 | void TestDialect::printType(Type type, DialectAsmPrinter &printer) const { |
466 | if (succeeded(generatedTypePrinter(type, printer))) |
467 | return; |
468 | |
469 | if (succeeded(printIfDynamicType(type, printer))) |
470 | return; |
471 | |
472 | auto rec = llvm::cast<TestRecursiveType>(type); |
473 | |
474 | FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint = |
475 | printer.tryStartCyclicPrint(rec); |
476 | |
477 | printer << "test_rec<" << rec.getName(); |
478 | if (succeeded(cyclicPrint)) { |
479 | printer << ", " ; |
480 | printType(rec.getBody(), printer); |
481 | } |
482 | printer << ">" ; |
483 | } |
484 | |
485 | Type TestRecursiveAliasType::getBody() const { return getImpl()->body; } |
486 | |
487 | void TestRecursiveAliasType::setBody(Type type) { (void)Base::mutate(type); } |
488 | |
489 | StringRef TestRecursiveAliasType::getName() const { return getImpl()->name; } |
490 | |
491 | Type TestRecursiveAliasType::parse(AsmParser &parser) { |
492 | StringRef name; |
493 | if (parser.parseLess() || parser.parseKeyword(&name)) |
494 | return Type(); |
495 | auto rec = TestRecursiveAliasType::get(parser.getContext(), name); |
496 | |
497 | FailureOr<AsmParser::CyclicParseReset> cyclicParse = |
498 | parser.tryStartCyclicParse(rec); |
499 | |
500 | // If this type already has been parsed above in the stack, expect just the |
501 | // name. |
502 | if (failed(cyclicParse)) { |
503 | if (failed(parser.parseGreater())) |
504 | return Type(); |
505 | return rec; |
506 | } |
507 | |
508 | // Otherwise, parse the body and update the type. |
509 | if (failed(parser.parseComma())) |
510 | return Type(); |
511 | Type subtype; |
512 | if (parser.parseType(subtype)) |
513 | return nullptr; |
514 | if (!subtype || failed(parser.parseGreater())) |
515 | return Type(); |
516 | |
517 | rec.setBody(subtype); |
518 | |
519 | return rec; |
520 | } |
521 | |
522 | void TestRecursiveAliasType::print(AsmPrinter &printer) const { |
523 | |
524 | FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint = |
525 | printer.tryStartCyclicPrint(*this); |
526 | |
527 | printer << "<" << getName(); |
528 | if (succeeded(cyclicPrint)) { |
529 | printer << ", " ; |
530 | printer << getBody(); |
531 | } |
532 | printer << ">" ; |
533 | } |
534 | |