1 | //===- TestTypes.cpp - MLIR Test Dialect Types ------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file contains types defined by the TestDialect for testing various |
10 | // features of MLIR. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "TestTypes.h" |
15 | #include "TestDialect.h" |
16 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
17 | #include "mlir/IR/Builders.h" |
18 | #include "mlir/IR/DialectImplementation.h" |
19 | #include "mlir/IR/ExtensibleDialect.h" |
20 | #include "mlir/IR/Types.h" |
21 | #include "llvm/ADT/Hashing.h" |
22 | #include "llvm/ADT/SetVector.h" |
23 | #include "llvm/ADT/TypeSwitch.h" |
24 | #include "llvm/Support/TypeSize.h" |
25 | #include <optional> |
26 | |
27 | using namespace mlir; |
28 | using namespace test; |
29 | |
30 | // Custom parser for SignednessSemantics. |
31 | static ParseResult |
32 | parseSignedness(AsmParser &parser, |
33 | TestIntegerType::SignednessSemantics &result) { |
34 | StringRef signStr; |
35 | auto loc = parser.getCurrentLocation(); |
36 | if (parser.parseKeyword(keyword: &signStr)) |
37 | return failure(); |
38 | if (signStr.equals_insensitive(RHS: "u" ) || signStr.equals_insensitive(RHS: "unsigned" )) |
39 | result = TestIntegerType::SignednessSemantics::Unsigned; |
40 | else if (signStr.equals_insensitive(RHS: "s" ) || |
41 | signStr.equals_insensitive(RHS: "signed" )) |
42 | result = TestIntegerType::SignednessSemantics::Signed; |
43 | else if (signStr.equals_insensitive(RHS: "n" ) || |
44 | signStr.equals_insensitive(RHS: "none" )) |
45 | result = TestIntegerType::SignednessSemantics::Signless; |
46 | else |
47 | return parser.emitError(loc, message: "expected signed, unsigned, or none" ); |
48 | return success(); |
49 | } |
50 | |
51 | // Custom printer for SignednessSemantics. |
52 | static void printSignedness(AsmPrinter &printer, |
53 | const TestIntegerType::SignednessSemantics &ss) { |
54 | switch (ss) { |
55 | case TestIntegerType::SignednessSemantics::Unsigned: |
56 | printer << "unsigned" ; |
57 | break; |
58 | case TestIntegerType::SignednessSemantics::Signed: |
59 | printer << "signed" ; |
60 | break; |
61 | case TestIntegerType::SignednessSemantics::Signless: |
62 | printer << "none" ; |
63 | break; |
64 | } |
65 | } |
66 | |
// The functions don't need to be in the header file, but they do need to be in
// the `test` namespace (where they are declared below) so that
// argument-dependent lookup on FieldInfo finds them. Declare them here, then
// define them immediately below. Separating the declaration and definition
// adheres to the LLVM coding standards.
70 | namespace test { |
71 | // FieldInfo is used as part of a parameter, so equality comparison is |
72 | // compulsory. |
73 | static bool operator==(const FieldInfo &a, const FieldInfo &b); |
74 | // FieldInfo is used as part of a parameter, so a hash will be computed. |
75 | static llvm::hash_code hash_value(const FieldInfo &fi); // NOLINT |
76 | } // namespace test |
77 | |
78 | // FieldInfo is used as part of a parameter, so equality comparison is |
79 | // compulsory. |
80 | static bool test::operator==(const FieldInfo &a, const FieldInfo &b) { |
81 | return a.name == b.name && a.type == b.type; |
82 | } |
83 | |
84 | // FieldInfo is used as part of a parameter, so a hash will be computed. |
85 | static llvm::hash_code test::hash_value(const FieldInfo &fi) { // NOLINT |
86 | return llvm::hash_combine(args: fi.name, args: fi.type); |
87 | } |
88 | |
89 | //===----------------------------------------------------------------------===// |
90 | // TestCustomType |
91 | //===----------------------------------------------------------------------===// |
92 | |
93 | static ParseResult parseCustomTypeA(AsmParser &parser, int &aResult) { |
94 | return parser.parseInteger(result&: aResult); |
95 | } |
96 | |
97 | static void printCustomTypeA(AsmPrinter &printer, int a) { printer << a; } |
98 | |
99 | static ParseResult parseCustomTypeB(AsmParser &parser, int a, |
100 | std::optional<int> &bResult) { |
101 | if (a < 0) |
102 | return success(); |
103 | for (int i : llvm::seq(Begin: 0, End: a)) |
104 | if (failed(Result: parser.parseInteger(result&: i))) |
105 | return failure(); |
106 | bResult.emplace(args: 0); |
107 | return parser.parseInteger(result&: *bResult); |
108 | } |
109 | |
110 | static void printCustomTypeB(AsmPrinter &printer, int a, std::optional<int> b) { |
111 | if (a < 0) |
112 | return; |
113 | printer << ' '; |
114 | for (int i : llvm::seq(Begin: 0, End: a)) |
115 | printer << i << ' '; |
116 | printer << *b; |
117 | } |
118 | |
119 | static ParseResult parseFooString(AsmParser &parser, std::string &foo) { |
120 | std::string result; |
121 | if (parser.parseString(string: &result)) |
122 | return failure(); |
123 | foo = std::move(result); |
124 | return success(); |
125 | } |
126 | |
127 | static void printFooString(AsmPrinter &printer, StringRef foo) { |
128 | printer << '"' << foo << '"'; |
129 | } |
130 | |
131 | static ParseResult parseBarString(AsmParser &parser, StringRef foo) { |
132 | return parser.parseKeyword(keyword: foo); |
133 | } |
134 | |
135 | static void printBarString(AsmPrinter &printer, StringRef foo) { |
136 | printer << foo; |
137 | } |
138 | //===----------------------------------------------------------------------===// |
139 | // Tablegen Generated Definitions |
140 | //===----------------------------------------------------------------------===// |
141 | |
142 | #include "TestTypeInterfaces.cpp.inc" |
143 | #define GET_TYPEDEF_CLASSES |
144 | #include "TestTypeDefs.cpp.inc" |
145 | |
146 | //===----------------------------------------------------------------------===// |
147 | // CompoundAType |
148 | //===----------------------------------------------------------------------===// |
149 | |
150 | Type CompoundAType::parse(AsmParser &parser) { |
151 | int widthOfSomething; |
152 | Type oneType; |
153 | SmallVector<int, 4> arrayOfInts; |
154 | if (parser.parseLess() || parser.parseInteger(widthOfSomething) || |
155 | parser.parseComma() || parser.parseType(oneType) || parser.parseComma() || |
156 | parser.parseLSquare()) |
157 | return Type(); |
158 | |
159 | int i; |
160 | while (!*parser.parseOptionalInteger(i)) { |
161 | arrayOfInts.push_back(i); |
162 | if (parser.parseOptionalComma()) |
163 | break; |
164 | } |
165 | |
166 | if (parser.parseRSquare() || parser.parseGreater()) |
167 | return Type(); |
168 | |
169 | return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts); |
170 | } |
171 | void CompoundAType::print(AsmPrinter &printer) const { |
172 | printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", [" ; |
173 | auto intArray = getArrayOfInts(); |
174 | llvm::interleaveComma(intArray, printer); |
175 | printer << "]>" ; |
176 | } |
177 | |
178 | //===----------------------------------------------------------------------===// |
179 | // TestIntegerType |
180 | //===----------------------------------------------------------------------===// |
181 | |
182 | // Example type validity checker. |
183 | LogicalResult |
184 | TestIntegerType::verify(function_ref<InFlightDiagnostic()> emitError, |
185 | unsigned width, |
186 | TestIntegerType::SignednessSemantics ss) { |
187 | if (width > 8) |
188 | return failure(); |
189 | return success(); |
190 | } |
191 | |
192 | Type TestIntegerType::parse(AsmParser &parser) { |
193 | SignednessSemantics signedness; |
194 | int width; |
195 | if (parser.parseLess() || parseSignedness(parser, signedness) || |
196 | parser.parseComma() || parser.parseInteger(width) || |
197 | parser.parseGreater()) |
198 | return Type(); |
199 | Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); |
200 | return getChecked(loc, loc.getContext(), width, signedness); |
201 | } |
202 | |
203 | void TestIntegerType::print(AsmPrinter &p) const { |
204 | p << "<" ; |
205 | printSignedness(p, getSignedness()); |
206 | p << ", " << getWidth() << ">" ; |
207 | } |
208 | |
209 | //===----------------------------------------------------------------------===// |
210 | // TestStructType |
211 | //===----------------------------------------------------------------------===// |
212 | |
213 | Type StructType::parse(AsmParser &p) { |
214 | SmallVector<FieldInfo, 4> parameters; |
215 | if (p.parseLess()) |
216 | return Type(); |
217 | while (succeeded(p.parseOptionalLBrace())) { |
218 | Type type; |
219 | StringRef name; |
220 | if (p.parseKeyword(&name) || p.parseComma() || p.parseType(type) || |
221 | p.parseRBrace()) |
222 | return Type(); |
223 | parameters.push_back(FieldInfo{name, type}); |
224 | if (p.parseOptionalComma()) |
225 | break; |
226 | } |
227 | if (p.parseGreater()) |
228 | return Type(); |
229 | return get(p.getContext(), parameters); |
230 | } |
231 | |
232 | void StructType::print(AsmPrinter &p) const { |
233 | p << "<" ; |
234 | llvm::interleaveComma(getFields(), p, [&](const FieldInfo &field) { |
235 | p << "{" << field.name << "," << field.type << "}" ; |
236 | }); |
237 | p << ">" ; |
238 | } |
239 | |
240 | //===----------------------------------------------------------------------===// |
241 | // TestType |
242 | //===----------------------------------------------------------------------===// |
243 | |
244 | void TestType::printTypeC(Location loc) const { |
245 | emitRemark(loc) << *this << " - TestC" ; |
246 | } |
247 | |
248 | //===----------------------------------------------------------------------===// |
249 | // TestTypeWithLayout |
250 | //===----------------------------------------------------------------------===// |
251 | |
252 | Type TestTypeWithLayoutType::parse(AsmParser &parser) { |
253 | unsigned val; |
254 | if (parser.parseLess() || parser.parseInteger(val) || parser.parseGreater()) |
255 | return Type(); |
256 | return TestTypeWithLayoutType::get(parser.getContext(), val); |
257 | } |
258 | |
259 | void TestTypeWithLayoutType::print(AsmPrinter &printer) const { |
260 | printer << "<" << getKey() << ">" ; |
261 | } |
262 | |
263 | llvm::TypeSize |
264 | TestTypeWithLayoutType::getTypeSizeInBits(const DataLayout &dataLayout, |
265 | DataLayoutEntryListRef params) const { |
266 | return llvm::TypeSize::getFixed(extractKind(params, "size" )); |
267 | } |
268 | |
269 | uint64_t |
270 | TestTypeWithLayoutType::getABIAlignment(const DataLayout &dataLayout, |
271 | DataLayoutEntryListRef params) const { |
272 | return extractKind(params, "alignment" ); |
273 | } |
274 | |
275 | uint64_t TestTypeWithLayoutType::getPreferredAlignment( |
276 | const DataLayout &dataLayout, DataLayoutEntryListRef params) const { |
277 | return extractKind(params, "preferred" ); |
278 | } |
279 | |
280 | std::optional<uint64_t> |
281 | TestTypeWithLayoutType::getIndexBitwidth(const DataLayout &dataLayout, |
282 | DataLayoutEntryListRef params) const { |
283 | return extractKind(params, "index" ); |
284 | } |
285 | |
286 | bool TestTypeWithLayoutType::areCompatible( |
287 | DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout, |
288 | DataLayoutSpecInterface newSpec, |
289 | const DataLayoutIdentifiedEntryMap &map) const { |
290 | unsigned old = extractKind(oldLayout, "alignment" ); |
291 | return old == 1 || extractKind(newLayout, "alignment" ) <= old; |
292 | } |
293 | |
294 | LogicalResult |
295 | TestTypeWithLayoutType::verifyEntries(DataLayoutEntryListRef params, |
296 | Location loc) const { |
297 | for (DataLayoutEntryInterface entry : params) { |
298 | // This is for testing purposes only, so assert well-formedness. |
299 | assert(entry.isTypeEntry() && "unexpected identifier entry" ); |
300 | assert( |
301 | llvm::isa<TestTypeWithLayoutType>(llvm::cast<Type>(entry.getKey())) && |
302 | "wrong type passed in" ); |
303 | auto array = llvm::dyn_cast<ArrayAttr>(entry.getValue()); |
304 | assert(array && array.getValue().size() == 2 && |
305 | "expected array of two elements" ); |
306 | auto kind = llvm::dyn_cast<StringAttr>(array.getValue().front()); |
307 | (void)kind; |
308 | assert(kind && |
309 | (kind.getValue() == "size" || kind.getValue() == "alignment" || |
310 | kind.getValue() == "preferred" || kind.getValue() == "index" ) && |
311 | "unexpected kind" ); |
312 | assert(llvm::isa<IntegerAttr>(array.getValue().back())); |
313 | } |
314 | return success(); |
315 | } |
316 | |
317 | uint64_t TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params, |
318 | StringRef expectedKind) const { |
319 | for (DataLayoutEntryInterface entry : params) { |
320 | ArrayRef<Attribute> pair = |
321 | llvm::cast<ArrayAttr>(entry.getValue()).getValue(); |
322 | StringRef kind = llvm::cast<StringAttr>(pair.front()).getValue(); |
323 | if (kind == expectedKind) |
324 | return llvm::cast<IntegerAttr>(pair.back()).getValue().getZExtValue(); |
325 | } |
326 | return 1; |
327 | } |
328 | |
329 | //===----------------------------------------------------------------------===// |
330 | // Dynamic Types |
331 | //===----------------------------------------------------------------------===// |
332 | |
333 | /// Define a singleton dynamic type. |
334 | static std::unique_ptr<DynamicTypeDefinition> |
335 | getSingletonDynamicType(TestDialect *testDialect) { |
336 | return DynamicTypeDefinition::get( |
337 | "dynamic_singleton" , testDialect, |
338 | [](function_ref<InFlightDiagnostic()> emitError, |
339 | ArrayRef<Attribute> args) { |
340 | if (!args.empty()) { |
341 | emitError() << "expected 0 type arguments, but had " << args.size(); |
342 | return failure(); |
343 | } |
344 | return success(); |
345 | }); |
346 | } |
347 | |
348 | /// Define a dynamic type representing a pair. |
349 | static std::unique_ptr<DynamicTypeDefinition> |
350 | getPairDynamicType(TestDialect *testDialect) { |
351 | return DynamicTypeDefinition::get( |
352 | "dynamic_pair" , testDialect, |
353 | [](function_ref<InFlightDiagnostic()> emitError, |
354 | ArrayRef<Attribute> args) { |
355 | if (args.size() != 2) { |
356 | emitError() << "expected 2 type arguments, but had " << args.size(); |
357 | return failure(); |
358 | } |
359 | return success(); |
360 | }); |
361 | } |
362 | |
363 | static std::unique_ptr<DynamicTypeDefinition> |
364 | getCustomAssemblyFormatDynamicType(TestDialect *testDialect) { |
365 | auto verifier = [](function_ref<InFlightDiagnostic()> emitError, |
366 | ArrayRef<Attribute> args) { |
367 | if (args.size() != 2) { |
368 | emitError() << "expected 2 type arguments, but had " << args.size(); |
369 | return failure(); |
370 | } |
371 | return success(); |
372 | }; |
373 | |
374 | auto parser = [](AsmParser &parser, |
375 | llvm::SmallVectorImpl<Attribute> &parsedParams) { |
376 | Attribute leftAttr, rightAttr; |
377 | if (parser.parseLess() || parser.parseAttribute(result&: leftAttr) || |
378 | parser.parseColon() || parser.parseAttribute(result&: rightAttr) || |
379 | parser.parseGreater()) |
380 | return failure(); |
381 | parsedParams.push_back(Elt: leftAttr); |
382 | parsedParams.push_back(Elt: rightAttr); |
383 | return success(); |
384 | }; |
385 | |
386 | auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) { |
387 | printer << "<" << params[0] << ":" << params[1] << ">" ; |
388 | }; |
389 | |
390 | return DynamicTypeDefinition::get("dynamic_custom_assembly_format" , |
391 | testDialect, std::move(verifier), |
392 | std::move(parser), std::move(printer)); |
393 | } |
394 | |
395 | //===----------------------------------------------------------------------===// |
396 | // TestDialect |
397 | //===----------------------------------------------------------------------===// |
398 | |
namespace {

/// External model implementing LLVM::PointerElementTypeInterface for
/// SimpleAType. The body is empty, so it relies on the interface's default
/// method implementations (presumably all of them have defaults; TODO confirm
/// against the interface definition).
struct PtrElementModel
    : public LLVM::PointerElementTypeInterface::ExternalModel<PtrElementModel,
                                                              SimpleAType> {};
} // namespace
405 | |
/// Registers the test dialect's types: the hand-written TestRecursiveType plus
/// every type listed by the generated TestTypeDefs.cpp.inc. It also attaches
/// the PtrElementModel external interface to SimpleAType and registers the
/// three dynamic types defined above.
void TestDialect::registerTypes() {
  addTypes<TestRecursiveType,
#define GET_TYPEDEF_LIST
#include "TestTypeDefs.cpp.inc"
           >();
  // Attach LLVM::PointerElementTypeInterface to SimpleAType from outside its
  // definition.
  SimpleAType::attachInterface<PtrElementModel>(*getContext());

  registerDynamicType(getSingletonDynamicType(this));
  registerDynamicType(getPairDynamicType(this));
  registerDynamicType(getCustomAssemblyFormatDynamicType(this));
}
417 | |
418 | Type TestDialect::parseType(DialectAsmParser &parser) const { |
419 | StringRef typeTag; |
420 | { |
421 | Type genType; |
422 | auto parseResult = generatedTypeParser(parser, &typeTag, genType); |
423 | if (parseResult.has_value()) |
424 | return genType; |
425 | } |
426 | |
427 | { |
428 | Type dynType; |
429 | auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType); |
430 | if (parseResult.has_value()) { |
431 | if (succeeded(parseResult.value())) |
432 | return dynType; |
433 | return Type(); |
434 | } |
435 | } |
436 | |
437 | if (typeTag != "test_rec" ) { |
438 | parser.emitError(parser.getNameLoc()) << "unknown type!" ; |
439 | return Type(); |
440 | } |
441 | |
442 | StringRef name; |
443 | if (parser.parseLess() || parser.parseKeyword(&name)) |
444 | return Type(); |
445 | auto rec = TestRecursiveType::get(parser.getContext(), name); |
446 | |
447 | FailureOr<AsmParser::CyclicParseReset> cyclicParse = |
448 | parser.tryStartCyclicParse(rec); |
449 | |
450 | // If this type already has been parsed above in the stack, expect just the |
451 | // name. |
452 | if (failed(cyclicParse)) { |
453 | if (failed(parser.parseGreater())) |
454 | return Type(); |
455 | return rec; |
456 | } |
457 | |
458 | // Otherwise, parse the body and update the type. |
459 | if (failed(parser.parseComma())) |
460 | return Type(); |
461 | Type subtype = parseType(parser); |
462 | if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype))) |
463 | return Type(); |
464 | |
465 | return rec; |
466 | } |
467 | |
/// Prints a test dialect type. Generated and dynamic types use their own
/// printers. Anything else must be a TestRecursiveType, printed as
/// `test_rec<name, body>`, or as `test_rec<name>` when the same type is
/// already being printed further up the stack.
void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
  if (succeeded(generatedTypePrinter(type, printer)))
    return;

  if (succeeded(printIfDynamicType(type, printer)))
    return;

  // Every remaining type is expected to be a recursive type; llvm::cast
  // asserts otherwise.
  auto rec = llvm::cast<TestRecursiveType>(type);

  // Fails when `rec` is already being printed higher up the stack. On
  // success, the reset object is kept alive until the body has been printed
  // (presumably it ends the cyclic-print tracking when destroyed).
  FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint =
      printer.tryStartCyclicPrint(rec);

  printer << "test_rec<" << rec.getName();
  if (succeeded(cyclicPrint)) {
    printer << ", ";
    // Recurse through the dialect printer so the body uses the short form.
    printType(rec.getBody(), printer);
  }
  printer << ">";
}
487 | |
488 | Type TestRecursiveAliasType::getBody() const { return getImpl()->body; } |
489 | |
490 | void TestRecursiveAliasType::setBody(Type type) { (void)Base::mutate(type); } |
491 | |
492 | StringRef TestRecursiveAliasType::getName() const { return getImpl()->name; } |
493 | |
494 | Type TestRecursiveAliasType::parse(AsmParser &parser) { |
495 | StringRef name; |
496 | if (parser.parseLess() || parser.parseKeyword(&name)) |
497 | return Type(); |
498 | auto rec = TestRecursiveAliasType::get(parser.getContext(), name); |
499 | |
500 | FailureOr<AsmParser::CyclicParseReset> cyclicParse = |
501 | parser.tryStartCyclicParse(rec); |
502 | |
503 | // If this type already has been parsed above in the stack, expect just the |
504 | // name. |
505 | if (failed(cyclicParse)) { |
506 | if (failed(parser.parseGreater())) |
507 | return Type(); |
508 | return rec; |
509 | } |
510 | |
511 | // Otherwise, parse the body and update the type. |
512 | if (failed(parser.parseComma())) |
513 | return Type(); |
514 | Type subtype; |
515 | if (parser.parseType(subtype)) |
516 | return nullptr; |
517 | if (!subtype || failed(parser.parseGreater())) |
518 | return Type(); |
519 | |
520 | rec.setBody(subtype); |
521 | |
522 | return rec; |
523 | } |
524 | |
525 | void TestRecursiveAliasType::print(AsmPrinter &printer) const { |
526 | |
527 | FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint = |
528 | printer.tryStartCyclicPrint(*this); |
529 | |
530 | printer << "<" << getName(); |
531 | if (succeeded(cyclicPrint)) { |
532 | printer << ", " ; |
533 | printer << getBody(); |
534 | } |
535 | printer << ">" ; |
536 | } |
537 | |
538 | void TestTypeOpAsmTypeInterfaceType::getAsmName( |
539 | OpAsmSetNameFn setNameFn) const { |
540 | setNameFn("op_asm_type_interface" ); |
541 | } |
542 | |
543 | ::mlir::OpAsmDialectInterface::AliasResult |
544 | TestTypeOpAsmTypeInterfaceType::getAlias(::llvm::raw_ostream &os) const { |
545 | os << "op_asm_type_interface_type" ; |
546 | return ::mlir::OpAsmDialectInterface::AliasResult::FinalAlias; |
547 | } |
548 | |