1//===- TestTypes.cpp - MLIR Test Dialect Types ------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains types defined by the TestDialect for testing various
10// features of MLIR.
11//
12//===----------------------------------------------------------------------===//
13
#include "TestTypes.h"
#include "TestDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/TypeSize.h"
#include <optional>
26
27using namespace mlir;
28using namespace test;
29
30// Custom parser for SignednessSemantics.
31static ParseResult
32parseSignedness(AsmParser &parser,
33 TestIntegerType::SignednessSemantics &result) {
34 StringRef signStr;
35 auto loc = parser.getCurrentLocation();
36 if (parser.parseKeyword(keyword: &signStr))
37 return failure();
38 if (signStr.equals_insensitive(RHS: "u") || signStr.equals_insensitive(RHS: "unsigned"))
39 result = TestIntegerType::SignednessSemantics::Unsigned;
40 else if (signStr.equals_insensitive(RHS: "s") ||
41 signStr.equals_insensitive(RHS: "signed"))
42 result = TestIntegerType::SignednessSemantics::Signed;
43 else if (signStr.equals_insensitive(RHS: "n") ||
44 signStr.equals_insensitive(RHS: "none"))
45 result = TestIntegerType::SignednessSemantics::Signless;
46 else
47 return parser.emitError(loc, message: "expected signed, unsigned, or none");
48 return success();
49}
50
51// Custom printer for SignednessSemantics.
52static void printSignedness(AsmPrinter &printer,
53 const TestIntegerType::SignednessSemantics &ss) {
54 switch (ss) {
55 case TestIntegerType::SignednessSemantics::Unsigned:
56 printer << "unsigned";
57 break;
58 case TestIntegerType::SignednessSemantics::Signed:
59 printer << "signed";
60 break;
61 case TestIntegerType::SignednessSemantics::Signless:
62 printer << "none";
63 break;
64 }
65}
66
// The functions don't need to be in the header file, but need to be in the test
// namespace. Declare them here, then define them immediately below. Separating
// the declaration and definition adheres to the LLVM coding standards.
70namespace test {
71// FieldInfo is used as part of a parameter, so equality comparison is
72// compulsory.
73static bool operator==(const FieldInfo &a, const FieldInfo &b);
74// FieldInfo is used as part of a parameter, so a hash will be computed.
75static llvm::hash_code hash_value(const FieldInfo &fi); // NOLINT
76} // namespace test
77
78// FieldInfo is used as part of a parameter, so equality comparison is
79// compulsory.
80static bool test::operator==(const FieldInfo &a, const FieldInfo &b) {
81 return a.name == b.name && a.type == b.type;
82}
83
84// FieldInfo is used as part of a parameter, so a hash will be computed.
85static llvm::hash_code test::hash_value(const FieldInfo &fi) { // NOLINT
86 return llvm::hash_combine(args: fi.name, args: fi.type);
87}
88
89//===----------------------------------------------------------------------===//
90// TestCustomType
91//===----------------------------------------------------------------------===//
92
93static LogicalResult parseCustomTypeA(AsmParser &parser, int &aResult) {
94 return parser.parseInteger(result&: aResult);
95}
96
97static void printCustomTypeA(AsmPrinter &printer, int a) { printer << a; }
98
99static LogicalResult parseCustomTypeB(AsmParser &parser, int a,
100 std::optional<int> &bResult) {
101 if (a < 0)
102 return success();
103 for (int i : llvm::seq(Begin: 0, End: a))
104 if (failed(result: parser.parseInteger(result&: i)))
105 return failure();
106 bResult.emplace(args: 0);
107 return parser.parseInteger(result&: *bResult);
108}
109
110static void printCustomTypeB(AsmPrinter &printer, int a, std::optional<int> b) {
111 if (a < 0)
112 return;
113 printer << ' ';
114 for (int i : llvm::seq(Begin: 0, End: a))
115 printer << i << ' ';
116 printer << *b;
117}
118
119static LogicalResult parseFooString(AsmParser &parser, std::string &foo) {
120 std::string result;
121 if (parser.parseString(string: &result))
122 return failure();
123 foo = std::move(result);
124 return success();
125}
126
127static void printFooString(AsmPrinter &printer, StringRef foo) {
128 printer << '"' << foo << '"';
129}
130
131static LogicalResult parseBarString(AsmParser &parser, StringRef foo) {
132 return parser.parseKeyword(keyword: foo);
133}
134
135static void printBarString(AsmPrinter &printer, StringRef foo) {
136 printer << foo;
137}
138//===----------------------------------------------------------------------===//
139// Tablegen Generated Definitions
140//===----------------------------------------------------------------------===//
141
142#include "TestTypeInterfaces.cpp.inc"
143#define GET_TYPEDEF_CLASSES
144#include "TestTypeDefs.cpp.inc"
145
146//===----------------------------------------------------------------------===//
147// CompoundAType
148//===----------------------------------------------------------------------===//
149
150Type CompoundAType::parse(AsmParser &parser) {
151 int widthOfSomething;
152 Type oneType;
153 SmallVector<int, 4> arrayOfInts;
154 if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
155 parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
156 parser.parseLSquare())
157 return Type();
158
159 int i;
160 while (!*parser.parseOptionalInteger(i)) {
161 arrayOfInts.push_back(i);
162 if (parser.parseOptionalComma())
163 break;
164 }
165
166 if (parser.parseRSquare() || parser.parseGreater())
167 return Type();
168
169 return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts);
170}
171void CompoundAType::print(AsmPrinter &printer) const {
172 printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", [";
173 auto intArray = getArrayOfInts();
174 llvm::interleaveComma(intArray, printer);
175 printer << "]>";
176}
177
178//===----------------------------------------------------------------------===//
179// TestIntegerType
180//===----------------------------------------------------------------------===//
181
182// Example type validity checker.
183LogicalResult
184TestIntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
185 unsigned width,
186 TestIntegerType::SignednessSemantics ss) {
187 if (width > 8)
188 return failure();
189 return success();
190}
191
192Type TestIntegerType::parse(AsmParser &parser) {
193 SignednessSemantics signedness;
194 int width;
195 if (parser.parseLess() || parseSignedness(parser, signedness) ||
196 parser.parseComma() || parser.parseInteger(width) ||
197 parser.parseGreater())
198 return Type();
199 Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
200 return getChecked(loc, loc.getContext(), width, signedness);
201}
202
203void TestIntegerType::print(AsmPrinter &p) const {
204 p << "<";
205 printSignedness(p, getSignedness());
206 p << ", " << getWidth() << ">";
207}
208
209//===----------------------------------------------------------------------===//
210// TestStructType
211//===----------------------------------------------------------------------===//
212
213Type StructType::parse(AsmParser &p) {
214 SmallVector<FieldInfo, 4> parameters;
215 if (p.parseLess())
216 return Type();
217 while (succeeded(p.parseOptionalLBrace())) {
218 Type type;
219 StringRef name;
220 if (p.parseKeyword(&name) || p.parseComma() || p.parseType(type) ||
221 p.parseRBrace())
222 return Type();
223 parameters.push_back(FieldInfo{name, type});
224 if (p.parseOptionalComma())
225 break;
226 }
227 if (p.parseGreater())
228 return Type();
229 return get(p.getContext(), parameters);
230}
231
232void StructType::print(AsmPrinter &p) const {
233 p << "<";
234 llvm::interleaveComma(getFields(), p, [&](const FieldInfo &field) {
235 p << "{" << field.name << "," << field.type << "}";
236 });
237 p << ">";
238}
239
240//===----------------------------------------------------------------------===//
241// TestType
242//===----------------------------------------------------------------------===//
243
244void TestType::printTypeC(Location loc) const {
245 emitRemark(loc) << *this << " - TestC";
246}
247
248//===----------------------------------------------------------------------===//
249// TestTypeWithLayout
250//===----------------------------------------------------------------------===//
251
252Type TestTypeWithLayoutType::parse(AsmParser &parser) {
253 unsigned val;
254 if (parser.parseLess() || parser.parseInteger(val) || parser.parseGreater())
255 return Type();
256 return TestTypeWithLayoutType::get(parser.getContext(), val);
257}
258
259void TestTypeWithLayoutType::print(AsmPrinter &printer) const {
260 printer << "<" << getKey() << ">";
261}
262
263llvm::TypeSize
264TestTypeWithLayoutType::getTypeSizeInBits(const DataLayout &dataLayout,
265 DataLayoutEntryListRef params) const {
266 return llvm::TypeSize::getFixed(extractKind(params, "size"));
267}
268
269uint64_t
270TestTypeWithLayoutType::getABIAlignment(const DataLayout &dataLayout,
271 DataLayoutEntryListRef params) const {
272 return extractKind(params, "alignment");
273}
274
275uint64_t TestTypeWithLayoutType::getPreferredAlignment(
276 const DataLayout &dataLayout, DataLayoutEntryListRef params) const {
277 return extractKind(params, "preferred");
278}
279
280std::optional<uint64_t>
281TestTypeWithLayoutType::getIndexBitwidth(const DataLayout &dataLayout,
282 DataLayoutEntryListRef params) const {
283 return extractKind(params, "index");
284}
285
286bool TestTypeWithLayoutType::areCompatible(
287 DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout) const {
288 unsigned old = extractKind(oldLayout, "alignment");
289 return old == 1 || extractKind(newLayout, "alignment") <= old;
290}
291
292LogicalResult
293TestTypeWithLayoutType::verifyEntries(DataLayoutEntryListRef params,
294 Location loc) const {
295 for (DataLayoutEntryInterface entry : params) {
296 // This is for testing purposes only, so assert well-formedness.
297 assert(entry.isTypeEntry() && "unexpected identifier entry");
298 assert(llvm::isa<TestTypeWithLayoutType>(entry.getKey().get<Type>()) &&
299 "wrong type passed in");
300 auto array = llvm::dyn_cast<ArrayAttr>(entry.getValue());
301 assert(array && array.getValue().size() == 2 &&
302 "expected array of two elements");
303 auto kind = llvm::dyn_cast<StringAttr>(array.getValue().front());
304 (void)kind;
305 assert(kind &&
306 (kind.getValue() == "size" || kind.getValue() == "alignment" ||
307 kind.getValue() == "preferred" || kind.getValue() == "index") &&
308 "unexpected kind");
309 assert(llvm::isa<IntegerAttr>(array.getValue().back()));
310 }
311 return success();
312}
313
314uint64_t TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params,
315 StringRef expectedKind) const {
316 for (DataLayoutEntryInterface entry : params) {
317 ArrayRef<Attribute> pair =
318 llvm::cast<ArrayAttr>(entry.getValue()).getValue();
319 StringRef kind = llvm::cast<StringAttr>(pair.front()).getValue();
320 if (kind == expectedKind)
321 return llvm::cast<IntegerAttr>(pair.back()).getValue().getZExtValue();
322 }
323 return 1;
324}
325
326//===----------------------------------------------------------------------===//
327// Dynamic Types
328//===----------------------------------------------------------------------===//
329
330/// Define a singleton dynamic type.
331static std::unique_ptr<DynamicTypeDefinition>
332getSingletonDynamicType(TestDialect *testDialect) {
333 return DynamicTypeDefinition::get(
334 "dynamic_singleton", testDialect,
335 [](function_ref<InFlightDiagnostic()> emitError,
336 ArrayRef<Attribute> args) {
337 if (!args.empty()) {
338 emitError() << "expected 0 type arguments, but had " << args.size();
339 return failure();
340 }
341 return success();
342 });
343}
344
345/// Define a dynamic type representing a pair.
346static std::unique_ptr<DynamicTypeDefinition>
347getPairDynamicType(TestDialect *testDialect) {
348 return DynamicTypeDefinition::get(
349 "dynamic_pair", testDialect,
350 [](function_ref<InFlightDiagnostic()> emitError,
351 ArrayRef<Attribute> args) {
352 if (args.size() != 2) {
353 emitError() << "expected 2 type arguments, but had " << args.size();
354 return failure();
355 }
356 return success();
357 });
358}
359
360static std::unique_ptr<DynamicTypeDefinition>
361getCustomAssemblyFormatDynamicType(TestDialect *testDialect) {
362 auto verifier = [](function_ref<InFlightDiagnostic()> emitError,
363 ArrayRef<Attribute> args) {
364 if (args.size() != 2) {
365 emitError() << "expected 2 type arguments, but had " << args.size();
366 return failure();
367 }
368 return success();
369 };
370
371 auto parser = [](AsmParser &parser,
372 llvm::SmallVectorImpl<Attribute> &parsedParams) {
373 Attribute leftAttr, rightAttr;
374 if (parser.parseLess() || parser.parseAttribute(result&: leftAttr) ||
375 parser.parseColon() || parser.parseAttribute(result&: rightAttr) ||
376 parser.parseGreater())
377 return failure();
378 parsedParams.push_back(Elt: leftAttr);
379 parsedParams.push_back(Elt: rightAttr);
380 return success();
381 };
382
383 auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) {
384 printer << "<" << params[0] << ":" << params[1] << ">";
385 };
386
387 return DynamicTypeDefinition::get("dynamic_custom_assembly_format",
388 testDialect, std::move(verifier),
389 std::move(parser), std::move(printer));
390}
391
392//===----------------------------------------------------------------------===//
393// TestDialect
394//===----------------------------------------------------------------------===//
395
namespace {

/// External model attaching LLVM::PointerElementTypeInterface to SimpleAType
/// (attached in TestDialect::registerTypes). It overrides nothing, so it
/// presumably relies on the interface's default method implementations.
struct PtrElementModel
    : public LLVM::PointerElementTypeInterface::ExternalModel<PtrElementModel,
                                                              SimpleAType> {};
} // namespace
402
/// Registers every type of the test dialect: the hand-written
/// TestRecursiveType, all ODS-generated types, an external
/// PointerElementTypeInterface model for SimpleAType, and three dynamically
/// defined types.
void TestDialect::registerTypes() {
  // The generated type list is spliced into the template argument list.
  addTypes<TestRecursiveType,
#define GET_TYPEDEF_LIST
#include "TestTypeDefs.cpp.inc"
           >();
  SimpleAType::attachInterface<PtrElementModel>(*getContext());

  registerDynamicType(getSingletonDynamicType(this));
  registerDynamicType(getPairDynamicType(this));
  registerDynamicType(getCustomAssemblyFormatDynamicType(this));
}
414
415Type TestDialect::parseType(DialectAsmParser &parser) const {
416 StringRef typeTag;
417 {
418 Type genType;
419 auto parseResult = generatedTypeParser(parser, &typeTag, genType);
420 if (parseResult.has_value())
421 return genType;
422 }
423
424 {
425 Type dynType;
426 auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType);
427 if (parseResult.has_value()) {
428 if (succeeded(parseResult.value()))
429 return dynType;
430 return Type();
431 }
432 }
433
434 if (typeTag != "test_rec") {
435 parser.emitError(parser.getNameLoc()) << "unknown type!";
436 return Type();
437 }
438
439 StringRef name;
440 if (parser.parseLess() || parser.parseKeyword(&name))
441 return Type();
442 auto rec = TestRecursiveType::get(parser.getContext(), name);
443
444 FailureOr<AsmParser::CyclicParseReset> cyclicParse =
445 parser.tryStartCyclicParse(rec);
446
447 // If this type already has been parsed above in the stack, expect just the
448 // name.
449 if (failed(cyclicParse)) {
450 if (failed(parser.parseGreater()))
451 return Type();
452 return rec;
453 }
454
455 // Otherwise, parse the body and update the type.
456 if (failed(parser.parseComma()))
457 return Type();
458 Type subtype = parseType(parser);
459 if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
460 return Type();
461
462 return rec;
463}
464
465void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
466 if (succeeded(generatedTypePrinter(type, printer)))
467 return;
468
469 if (succeeded(printIfDynamicType(type, printer)))
470 return;
471
472 auto rec = llvm::cast<TestRecursiveType>(type);
473
474 FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint =
475 printer.tryStartCyclicPrint(rec);
476
477 printer << "test_rec<" << rec.getName();
478 if (succeeded(cyclicPrint)) {
479 printer << ", ";
480 printType(rec.getBody(), printer);
481 }
482 printer << ">";
483}
484
485Type TestRecursiveAliasType::getBody() const { return getImpl()->body; }
486
487void TestRecursiveAliasType::setBody(Type type) { (void)Base::mutate(type); }
488
489StringRef TestRecursiveAliasType::getName() const { return getImpl()->name; }
490
491Type TestRecursiveAliasType::parse(AsmParser &parser) {
492 StringRef name;
493 if (parser.parseLess() || parser.parseKeyword(&name))
494 return Type();
495 auto rec = TestRecursiveAliasType::get(parser.getContext(), name);
496
497 FailureOr<AsmParser::CyclicParseReset> cyclicParse =
498 parser.tryStartCyclicParse(rec);
499
500 // If this type already has been parsed above in the stack, expect just the
501 // name.
502 if (failed(cyclicParse)) {
503 if (failed(parser.parseGreater()))
504 return Type();
505 return rec;
506 }
507
508 // Otherwise, parse the body and update the type.
509 if (failed(parser.parseComma()))
510 return Type();
511 Type subtype;
512 if (parser.parseType(subtype))
513 return nullptr;
514 if (!subtype || failed(parser.parseGreater()))
515 return Type();
516
517 rec.setBody(subtype);
518
519 return rec;
520}
521
522void TestRecursiveAliasType::print(AsmPrinter &printer) const {
523
524 FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint =
525 printer.tryStartCyclicPrint(*this);
526
527 printer << "<" << getName();
528 if (succeeded(cyclicPrint)) {
529 printer << ", ";
530 printer << getBody();
531 }
532 printer << ">";
533}
534

source code of mlir/test/lib/Dialect/Test/TestTypes.cpp