1//===- TestTypes.cpp - MLIR Test Dialect Types ------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains types defined by the TestDialect for testing various
10// features of MLIR.
11//
12//===----------------------------------------------------------------------===//
13
14#include "TestTypes.h"
15#include "TestDialect.h"
16#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/DialectImplementation.h"
19#include "mlir/IR/ExtensibleDialect.h"
20#include "mlir/IR/Types.h"
21#include "llvm/ADT/Hashing.h"
22#include "llvm/ADT/SetVector.h"
23#include "llvm/ADT/TypeSwitch.h"
24#include "llvm/Support/TypeSize.h"
25#include <optional>
26
27using namespace mlir;
28using namespace test;
29
30// Custom parser for SignednessSemantics.
31static ParseResult
32parseSignedness(AsmParser &parser,
33 TestIntegerType::SignednessSemantics &result) {
34 StringRef signStr;
35 auto loc = parser.getCurrentLocation();
36 if (parser.parseKeyword(keyword: &signStr))
37 return failure();
38 if (signStr.equals_insensitive(RHS: "u") || signStr.equals_insensitive(RHS: "unsigned"))
39 result = TestIntegerType::SignednessSemantics::Unsigned;
40 else if (signStr.equals_insensitive(RHS: "s") ||
41 signStr.equals_insensitive(RHS: "signed"))
42 result = TestIntegerType::SignednessSemantics::Signed;
43 else if (signStr.equals_insensitive(RHS: "n") ||
44 signStr.equals_insensitive(RHS: "none"))
45 result = TestIntegerType::SignednessSemantics::Signless;
46 else
47 return parser.emitError(loc, message: "expected signed, unsigned, or none");
48 return success();
49}
50
51// Custom printer for SignednessSemantics.
52static void printSignedness(AsmPrinter &printer,
53 const TestIntegerType::SignednessSemantics &ss) {
54 switch (ss) {
55 case TestIntegerType::SignednessSemantics::Unsigned:
56 printer << "unsigned";
57 break;
58 case TestIntegerType::SignednessSemantics::Signed:
59 printer << "signed";
60 break;
61 case TestIntegerType::SignednessSemantics::Signless:
62 printer << "none";
63 break;
64 }
65}
66
67// The functions don't need to be in the header file, but need to be in the mlir
68// namespace. Declare them here, then define them immediately below. Separating
69// the declaration and definition adheres to the LLVM coding standards.
70namespace test {
71// FieldInfo is used as part of a parameter, so equality comparison is
72// compulsory.
73static bool operator==(const FieldInfo &a, const FieldInfo &b);
74// FieldInfo is used as part of a parameter, so a hash will be computed.
75static llvm::hash_code hash_value(const FieldInfo &fi); // NOLINT
76} // namespace test
77
78// FieldInfo is used as part of a parameter, so equality comparison is
79// compulsory.
80static bool test::operator==(const FieldInfo &a, const FieldInfo &b) {
81 return a.name == b.name && a.type == b.type;
82}
83
84// FieldInfo is used as part of a parameter, so a hash will be computed.
85static llvm::hash_code test::hash_value(const FieldInfo &fi) { // NOLINT
86 return llvm::hash_combine(args: fi.name, args: fi.type);
87}
88
89//===----------------------------------------------------------------------===//
90// TestCustomType
91//===----------------------------------------------------------------------===//
92
93static ParseResult parseCustomTypeA(AsmParser &parser, int &aResult) {
94 return parser.parseInteger(result&: aResult);
95}
96
97static void printCustomTypeA(AsmPrinter &printer, int a) { printer << a; }
98
99static ParseResult parseCustomTypeB(AsmParser &parser, int a,
100 std::optional<int> &bResult) {
101 if (a < 0)
102 return success();
103 for (int i : llvm::seq(Begin: 0, End: a))
104 if (failed(Result: parser.parseInteger(result&: i)))
105 return failure();
106 bResult.emplace(args: 0);
107 return parser.parseInteger(result&: *bResult);
108}
109
110static void printCustomTypeB(AsmPrinter &printer, int a, std::optional<int> b) {
111 if (a < 0)
112 return;
113 printer << ' ';
114 for (int i : llvm::seq(Begin: 0, End: a))
115 printer << i << ' ';
116 printer << *b;
117}
118
119static ParseResult parseFooString(AsmParser &parser, std::string &foo) {
120 std::string result;
121 if (parser.parseString(string: &result))
122 return failure();
123 foo = std::move(result);
124 return success();
125}
126
127static void printFooString(AsmPrinter &printer, StringRef foo) {
128 printer << '"' << foo << '"';
129}
130
131static ParseResult parseBarString(AsmParser &parser, StringRef foo) {
132 return parser.parseKeyword(keyword: foo);
133}
134
135static void printBarString(AsmPrinter &printer, StringRef foo) {
136 printer << foo;
137}
138//===----------------------------------------------------------------------===//
139// Tablegen Generated Definitions
140//===----------------------------------------------------------------------===//
141
142#include "TestTypeInterfaces.cpp.inc"
143#define GET_TYPEDEF_CLASSES
144#include "TestTypeDefs.cpp.inc"
145
146//===----------------------------------------------------------------------===//
147// CompoundAType
148//===----------------------------------------------------------------------===//
149
150Type CompoundAType::parse(AsmParser &parser) {
151 int widthOfSomething;
152 Type oneType;
153 SmallVector<int, 4> arrayOfInts;
154 if (parser.parseLess() || parser.parseInteger(result&: widthOfSomething) ||
155 parser.parseComma() || parser.parseType(result&: oneType) || parser.parseComma() ||
156 parser.parseLSquare())
157 return Type();
158
159 int i;
160 while (!*parser.parseOptionalInteger(result&: i)) {
161 arrayOfInts.push_back(Elt: i);
162 if (parser.parseOptionalComma())
163 break;
164 }
165
166 if (parser.parseRSquare() || parser.parseGreater())
167 return Type();
168
169 return get(context: parser.getContext(), widthOfSomething, oneType, arrayOfInts);
170}
171void CompoundAType::print(AsmPrinter &printer) const {
172 printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", [";
173 auto intArray = getArrayOfInts();
174 llvm::interleaveComma(c: intArray, os&: printer);
175 printer << "]>";
176}
177
178//===----------------------------------------------------------------------===//
179// TestIntegerType
180//===----------------------------------------------------------------------===//
181
182// Example type validity checker.
183LogicalResult
184TestIntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
185 unsigned width,
186 TestIntegerType::SignednessSemantics ss) {
187 if (width > 8)
188 return failure();
189 return success();
190}
191
192Type TestIntegerType::parse(AsmParser &parser) {
193 SignednessSemantics signedness;
194 int width;
195 if (parser.parseLess() || parseSignedness(parser, result&: signedness) ||
196 parser.parseComma() || parser.parseInteger(result&: width) ||
197 parser.parseGreater())
198 return Type();
199 Location loc = parser.getEncodedSourceLoc(loc: parser.getNameLoc());
200 return getChecked(loc, args: loc.getContext(), args&: width, args&: signedness);
201}
202
203void TestIntegerType::print(AsmPrinter &p) const {
204 p << "<";
205 printSignedness(printer&: p, ss: getSignedness());
206 p << ", " << getWidth() << ">";
207}
208
209//===----------------------------------------------------------------------===//
210// TestStructType
211//===----------------------------------------------------------------------===//
212
213Type StructType::parse(AsmParser &p) {
214 SmallVector<FieldInfo, 4> parameters;
215 if (p.parseLess())
216 return Type();
217 while (succeeded(Result: p.parseOptionalLBrace())) {
218 Type type;
219 StringRef name;
220 if (p.parseKeyword(keyword: &name) || p.parseComma() || p.parseType(result&: type) ||
221 p.parseRBrace())
222 return Type();
223 parameters.push_back(Elt: FieldInfo{.name: name, .type: type});
224 if (p.parseOptionalComma())
225 break;
226 }
227 if (p.parseGreater())
228 return Type();
229 return get(context: p.getContext(), fields: parameters);
230}
231
232void StructType::print(AsmPrinter &p) const {
233 p << "<";
234 llvm::interleaveComma(c: getFields(), os&: p, each_fn: [&](const FieldInfo &field) {
235 p << "{" << field.name << "," << field.type << "}";
236 });
237 p << ">";
238}
239
240//===----------------------------------------------------------------------===//
241// TestType
242//===----------------------------------------------------------------------===//
243
244void TestType::printTypeC(Location loc) const {
245 emitRemark(loc) << *this << " - TestC";
246}
247
248//===----------------------------------------------------------------------===//
249// TestTypeWithLayout
250//===----------------------------------------------------------------------===//
251
252Type TestTypeWithLayoutType::parse(AsmParser &parser) {
253 unsigned val;
254 if (parser.parseLess() || parser.parseInteger(result&: val) || parser.parseGreater())
255 return Type();
256 return TestTypeWithLayoutType::get(context: parser.getContext(), key: val);
257}
258
259void TestTypeWithLayoutType::print(AsmPrinter &printer) const {
260 printer << "<" << getKey() << ">";
261}
262
263llvm::TypeSize
264TestTypeWithLayoutType::getTypeSizeInBits(const DataLayout &dataLayout,
265 DataLayoutEntryListRef params) const {
266 return llvm::TypeSize::getFixed(ExactSize: extractKind(params, expectedKind: "size"));
267}
268
269uint64_t
270TestTypeWithLayoutType::getABIAlignment(const DataLayout &dataLayout,
271 DataLayoutEntryListRef params) const {
272 return extractKind(params, expectedKind: "alignment");
273}
274
275uint64_t TestTypeWithLayoutType::getPreferredAlignment(
276 const DataLayout &dataLayout, DataLayoutEntryListRef params) const {
277 return extractKind(params, expectedKind: "preferred");
278}
279
280std::optional<uint64_t>
281TestTypeWithLayoutType::getIndexBitwidth(const DataLayout &dataLayout,
282 DataLayoutEntryListRef params) const {
283 return extractKind(params, expectedKind: "index");
284}
285
286bool TestTypeWithLayoutType::areCompatible(
287 DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout,
288 DataLayoutSpecInterface newSpec,
289 const DataLayoutIdentifiedEntryMap &map) const {
290 unsigned old = extractKind(params: oldLayout, expectedKind: "alignment");
291 return old == 1 || extractKind(params: newLayout, expectedKind: "alignment") <= old;
292}
293
294LogicalResult
295TestTypeWithLayoutType::verifyEntries(DataLayoutEntryListRef params,
296 Location loc) const {
297 for (DataLayoutEntryInterface entry : params) {
298 // This is for testing purposes only, so assert well-formedness.
299 assert(entry.isTypeEntry() && "unexpected identifier entry");
300 assert(
301 llvm::isa<TestTypeWithLayoutType>(llvm::cast<Type>(entry.getKey())) &&
302 "wrong type passed in");
303 auto array = llvm::dyn_cast<ArrayAttr>(Val: entry.getValue());
304 assert(array && array.getValue().size() == 2 &&
305 "expected array of two elements");
306 auto kind = llvm::dyn_cast<StringAttr>(Val: array.getValue().front());
307 (void)kind;
308 assert(kind &&
309 (kind.getValue() == "size" || kind.getValue() == "alignment" ||
310 kind.getValue() == "preferred" || kind.getValue() == "index") &&
311 "unexpected kind");
312 assert(llvm::isa<IntegerAttr>(array.getValue().back()));
313 }
314 return success();
315}
316
317uint64_t TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params,
318 StringRef expectedKind) const {
319 for (DataLayoutEntryInterface entry : params) {
320 ArrayRef<Attribute> pair =
321 llvm::cast<ArrayAttr>(Val: entry.getValue()).getValue();
322 StringRef kind = llvm::cast<StringAttr>(Val: pair.front()).getValue();
323 if (kind == expectedKind)
324 return llvm::cast<IntegerAttr>(Val: pair.back()).getValue().getZExtValue();
325 }
326 return 1;
327}
328
329//===----------------------------------------------------------------------===//
330// Dynamic Types
331//===----------------------------------------------------------------------===//
332
333/// Define a singleton dynamic type.
334static std::unique_ptr<DynamicTypeDefinition>
335getSingletonDynamicType(TestDialect *testDialect) {
336 return DynamicTypeDefinition::get(
337 name: "dynamic_singleton", dialect: testDialect,
338 verifier: [](function_ref<InFlightDiagnostic()> emitError,
339 ArrayRef<Attribute> args) {
340 if (!args.empty()) {
341 emitError() << "expected 0 type arguments, but had " << args.size();
342 return failure();
343 }
344 return success();
345 });
346}
347
348/// Define a dynamic type representing a pair.
349static std::unique_ptr<DynamicTypeDefinition>
350getPairDynamicType(TestDialect *testDialect) {
351 return DynamicTypeDefinition::get(
352 name: "dynamic_pair", dialect: testDialect,
353 verifier: [](function_ref<InFlightDiagnostic()> emitError,
354 ArrayRef<Attribute> args) {
355 if (args.size() != 2) {
356 emitError() << "expected 2 type arguments, but had " << args.size();
357 return failure();
358 }
359 return success();
360 });
361}
362
363static std::unique_ptr<DynamicTypeDefinition>
364getCustomAssemblyFormatDynamicType(TestDialect *testDialect) {
365 auto verifier = [](function_ref<InFlightDiagnostic()> emitError,
366 ArrayRef<Attribute> args) {
367 if (args.size() != 2) {
368 emitError() << "expected 2 type arguments, but had " << args.size();
369 return failure();
370 }
371 return success();
372 };
373
374 auto parser = [](AsmParser &parser,
375 llvm::SmallVectorImpl<Attribute> &parsedParams) {
376 Attribute leftAttr, rightAttr;
377 if (parser.parseLess() || parser.parseAttribute(result&: leftAttr) ||
378 parser.parseColon() || parser.parseAttribute(result&: rightAttr) ||
379 parser.parseGreater())
380 return failure();
381 parsedParams.push_back(Elt: leftAttr);
382 parsedParams.push_back(Elt: rightAttr);
383 return success();
384 };
385
386 auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) {
387 printer << "<" << params[0] << ":" << params[1] << ">";
388 };
389
390 return DynamicTypeDefinition::get(name: "dynamic_custom_assembly_format",
391 dialect: testDialect, verifier: std::move(verifier),
392 parser: std::move(parser), printer: std::move(printer));
393}
394
395test::detail::TestCustomStorageCtorTypeStorage *
396test::detail::TestCustomStorageCtorTypeStorage::construct(
397 mlir::StorageUniquer::StorageAllocator &, std::tuple<int> &&) {
398 // Note: this tests linker error ("undefined symbol"), the actual
399 // implementation is not important.
400 return nullptr;
401}
402
403//===----------------------------------------------------------------------===//
404// TestDialect
405//===----------------------------------------------------------------------===//
406
407namespace {
408
409struct PtrElementModel
410 : public LLVM::PointerElementTypeInterface::ExternalModel<PtrElementModel,
411 SimpleAType> {};
412} // namespace
413
414void TestDialect::registerTypes() {
415 addTypes<TestRecursiveType,
416#define GET_TYPEDEF_LIST
417#include "TestTypeDefs.cpp.inc"
418 >();
419 SimpleAType::attachInterface<PtrElementModel>(context&: *getContext());
420
421 registerDynamicType(type: getSingletonDynamicType(testDialect: this));
422 registerDynamicType(type: getPairDynamicType(testDialect: this));
423 registerDynamicType(type: getCustomAssemblyFormatDynamicType(testDialect: this));
424}
425
426Type TestDialect::parseType(DialectAsmParser &parser) const {
427 StringRef typeTag;
428 {
429 Type genType;
430 auto parseResult = generatedTypeParser(parser, mnemonic: &typeTag, value&: genType);
431 if (parseResult.has_value())
432 return genType;
433 }
434
435 {
436 Type dynType;
437 auto parseResult = parseOptionalDynamicType(typeName: typeTag, parser, resultType&: dynType);
438 if (parseResult.has_value()) {
439 if (succeeded(Result: parseResult.value()))
440 return dynType;
441 return Type();
442 }
443 }
444
445 if (typeTag != "test_rec") {
446 parser.emitError(loc: parser.getNameLoc()) << "unknown type!";
447 return Type();
448 }
449
450 StringRef name;
451 if (parser.parseLess() || parser.parseKeyword(keyword: &name))
452 return Type();
453 auto rec = TestRecursiveType::get(ctx: parser.getContext(), name);
454
455 FailureOr<AsmParser::CyclicParseReset> cyclicParse =
456 parser.tryStartCyclicParse(attrOrType: rec);
457
458 // If this type already has been parsed above in the stack, expect just the
459 // name.
460 if (failed(Result: cyclicParse)) {
461 if (failed(Result: parser.parseGreater()))
462 return Type();
463 return rec;
464 }
465
466 // Otherwise, parse the body and update the type.
467 if (failed(Result: parser.parseComma()))
468 return Type();
469 Type subtype = parseType(parser);
470 if (!subtype || failed(Result: parser.parseGreater()) || failed(Result: rec.setBody(subtype)))
471 return Type();
472
473 return rec;
474}
475
476void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
477 if (succeeded(Result: generatedTypePrinter(def: type, printer)))
478 return;
479
480 if (succeeded(Result: printIfDynamicType(type, printer)))
481 return;
482
483 auto rec = llvm::cast<TestRecursiveType>(Val&: type);
484
485 FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint =
486 printer.tryStartCyclicPrint(attrOrType: rec);
487
488 printer << "test_rec<" << rec.getName();
489 if (succeeded(Result: cyclicPrint)) {
490 printer << ", ";
491 printType(type: rec.getBody(), printer);
492 }
493 printer << ">";
494}
495
496Type TestRecursiveAliasType::getBody() const { return getImpl()->body; }
497
498void TestRecursiveAliasType::setBody(Type type) { (void)Base::mutate(args&: type); }
499
500StringRef TestRecursiveAliasType::getName() const { return getImpl()->name; }
501
502Type TestRecursiveAliasType::parse(AsmParser &parser) {
503 StringRef name;
504 if (parser.parseLess() || parser.parseKeyword(keyword: &name))
505 return Type();
506 auto rec = TestRecursiveAliasType::get(context: parser.getContext(), name);
507
508 FailureOr<AsmParser::CyclicParseReset> cyclicParse =
509 parser.tryStartCyclicParse(attrOrType: rec);
510
511 // If this type already has been parsed above in the stack, expect just the
512 // name.
513 if (failed(Result: cyclicParse)) {
514 if (failed(Result: parser.parseGreater()))
515 return Type();
516 return rec;
517 }
518
519 // Otherwise, parse the body and update the type.
520 if (failed(Result: parser.parseComma()))
521 return Type();
522 Type subtype;
523 if (parser.parseType(result&: subtype))
524 return nullptr;
525 if (!subtype || failed(Result: parser.parseGreater()))
526 return Type();
527
528 rec.setBody(subtype);
529
530 return rec;
531}
532
533void TestRecursiveAliasType::print(AsmPrinter &printer) const {
534
535 FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint =
536 printer.tryStartCyclicPrint(attrOrType: *this);
537
538 printer << "<" << getName();
539 if (succeeded(Result: cyclicPrint)) {
540 printer << ", ";
541 printer << getBody();
542 }
543 printer << ">";
544}
545
546void TestTypeOpAsmTypeInterfaceType::getAsmName(
547 OpAsmSetNameFn setNameFn) const {
548 setNameFn("op_asm_type_interface");
549}
550
551::mlir::OpAsmDialectInterface::AliasResult
552TestTypeOpAsmTypeInterfaceType::getAlias(::llvm::raw_ostream &os) const {
553 os << "op_asm_type_interface_type";
554 return ::mlir::OpAsmDialectInterface::AliasResult::FinalAlias;
555}
556
557::mlir::FailureOr<::mlir::bufferization::BufferLikeType>
558TestTensorType::getBufferType(
559 const ::mlir::bufferization::BufferizationOptions &,
560 ::llvm::function_ref<::mlir::InFlightDiagnostic()>) {
561 return cast<bufferization::BufferLikeType>(
562 Val: TestMemrefType::get(context: getContext(), shape: getShape(), elementType: getElementType(), memSpace: nullptr));
563}
564
565::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType(
566 ::mlir::bufferization::BufferLikeType bufferType,
567 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {
568 auto testMemref = dyn_cast<TestMemrefType>(Val&: bufferType);
569 if (!testMemref)
570 return emitError() << "expected TestMemrefType";
571
572 const bool valid = getShape() == testMemref.getShape() &&
573 getElementType() == testMemref.getElementType();
574 return mlir::success(IsSuccess: valid);
575}
576

source code of mlir/test/lib/Dialect/Test/TestTypes.cpp