1//===- TestTypes.cpp - MLIR Test Dialect Types ------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains types defined by the TestDialect for testing various
10// features of MLIR.
11//
12//===----------------------------------------------------------------------===//
13
14#include "TestTypes.h"
15#include "TestDialect.h"
16#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/DialectImplementation.h"
19#include "mlir/IR/ExtensibleDialect.h"
20#include "mlir/IR/Types.h"
21#include "llvm/ADT/Hashing.h"
22#include "llvm/ADT/SetVector.h"
23#include "llvm/ADT/TypeSwitch.h"
24#include "llvm/Support/TypeSize.h"
25#include <optional>
26
27using namespace mlir;
28using namespace test;
29
30// Custom parser for SignednessSemantics.
31static ParseResult
32parseSignedness(AsmParser &parser,
33 TestIntegerType::SignednessSemantics &result) {
34 StringRef signStr;
35 auto loc = parser.getCurrentLocation();
36 if (parser.parseKeyword(keyword: &signStr))
37 return failure();
38 if (signStr.equals_insensitive(RHS: "u") || signStr.equals_insensitive(RHS: "unsigned"))
39 result = TestIntegerType::SignednessSemantics::Unsigned;
40 else if (signStr.equals_insensitive(RHS: "s") ||
41 signStr.equals_insensitive(RHS: "signed"))
42 result = TestIntegerType::SignednessSemantics::Signed;
43 else if (signStr.equals_insensitive(RHS: "n") ||
44 signStr.equals_insensitive(RHS: "none"))
45 result = TestIntegerType::SignednessSemantics::Signless;
46 else
47 return parser.emitError(loc, message: "expected signed, unsigned, or none");
48 return success();
49}
50
51// Custom printer for SignednessSemantics.
52static void printSignedness(AsmPrinter &printer,
53 const TestIntegerType::SignednessSemantics &ss) {
54 switch (ss) {
55 case TestIntegerType::SignednessSemantics::Unsigned:
56 printer << "unsigned";
57 break;
58 case TestIntegerType::SignednessSemantics::Signed:
59 printer << "signed";
60 break;
61 case TestIntegerType::SignednessSemantics::Signless:
62 printer << "none";
63 break;
64 }
65}
66
67// The functions don't need to be in the header file, but need to be in the mlir
68// namespace. Declare them here, then define them immediately below. Separating
69// the declaration and definition adheres to the LLVM coding standards.
70namespace test {
71// FieldInfo is used as part of a parameter, so equality comparison is
72// compulsory.
73static bool operator==(const FieldInfo &a, const FieldInfo &b);
74// FieldInfo is used as part of a parameter, so a hash will be computed.
75static llvm::hash_code hash_value(const FieldInfo &fi); // NOLINT
76} // namespace test
77
78// FieldInfo is used as part of a parameter, so equality comparison is
79// compulsory.
80static bool test::operator==(const FieldInfo &a, const FieldInfo &b) {
81 return a.name == b.name && a.type == b.type;
82}
83
84// FieldInfo is used as part of a parameter, so a hash will be computed.
85static llvm::hash_code test::hash_value(const FieldInfo &fi) { // NOLINT
86 return llvm::hash_combine(args: fi.name, args: fi.type);
87}
88
89//===----------------------------------------------------------------------===//
90// TestCustomType
91//===----------------------------------------------------------------------===//
92
93static ParseResult parseCustomTypeA(AsmParser &parser, int &aResult) {
94 return parser.parseInteger(result&: aResult);
95}
96
97static void printCustomTypeA(AsmPrinter &printer, int a) { printer << a; }
98
99static ParseResult parseCustomTypeB(AsmParser &parser, int a,
100 std::optional<int> &bResult) {
101 if (a < 0)
102 return success();
103 for (int i : llvm::seq(Begin: 0, End: a))
104 if (failed(Result: parser.parseInteger(result&: i)))
105 return failure();
106 bResult.emplace(args: 0);
107 return parser.parseInteger(result&: *bResult);
108}
109
110static void printCustomTypeB(AsmPrinter &printer, int a, std::optional<int> b) {
111 if (a < 0)
112 return;
113 printer << ' ';
114 for (int i : llvm::seq(Begin: 0, End: a))
115 printer << i << ' ';
116 printer << *b;
117}
118
119static ParseResult parseFooString(AsmParser &parser, std::string &foo) {
120 std::string result;
121 if (parser.parseString(string: &result))
122 return failure();
123 foo = std::move(result);
124 return success();
125}
126
127static void printFooString(AsmPrinter &printer, StringRef foo) {
128 printer << '"' << foo << '"';
129}
130
131static ParseResult parseBarString(AsmParser &parser, StringRef foo) {
132 return parser.parseKeyword(keyword: foo);
133}
134
135static void printBarString(AsmPrinter &printer, StringRef foo) {
136 printer << foo;
137}
138//===----------------------------------------------------------------------===//
139// Tablegen Generated Definitions
140//===----------------------------------------------------------------------===//
141
142#include "TestTypeInterfaces.cpp.inc"
143#define GET_TYPEDEF_CLASSES
144#include "TestTypeDefs.cpp.inc"
145
146//===----------------------------------------------------------------------===//
147// CompoundAType
148//===----------------------------------------------------------------------===//
149
150Type CompoundAType::parse(AsmParser &parser) {
151 int widthOfSomething;
152 Type oneType;
153 SmallVector<int, 4> arrayOfInts;
154 if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
155 parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
156 parser.parseLSquare())
157 return Type();
158
159 int i;
160 while (!*parser.parseOptionalInteger(i)) {
161 arrayOfInts.push_back(i);
162 if (parser.parseOptionalComma())
163 break;
164 }
165
166 if (parser.parseRSquare() || parser.parseGreater())
167 return Type();
168
169 return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts);
170}
171void CompoundAType::print(AsmPrinter &printer) const {
172 printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", [";
173 auto intArray = getArrayOfInts();
174 llvm::interleaveComma(intArray, printer);
175 printer << "]>";
176}
177
178//===----------------------------------------------------------------------===//
179// TestIntegerType
180//===----------------------------------------------------------------------===//
181
182// Example type validity checker.
183LogicalResult
184TestIntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
185 unsigned width,
186 TestIntegerType::SignednessSemantics ss) {
187 if (width > 8)
188 return failure();
189 return success();
190}
191
192Type TestIntegerType::parse(AsmParser &parser) {
193 SignednessSemantics signedness;
194 int width;
195 if (parser.parseLess() || parseSignedness(parser, signedness) ||
196 parser.parseComma() || parser.parseInteger(width) ||
197 parser.parseGreater())
198 return Type();
199 Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
200 return getChecked(loc, loc.getContext(), width, signedness);
201}
202
203void TestIntegerType::print(AsmPrinter &p) const {
204 p << "<";
205 printSignedness(p, getSignedness());
206 p << ", " << getWidth() << ">";
207}
208
209//===----------------------------------------------------------------------===//
210// TestStructType
211//===----------------------------------------------------------------------===//
212
213Type StructType::parse(AsmParser &p) {
214 SmallVector<FieldInfo, 4> parameters;
215 if (p.parseLess())
216 return Type();
217 while (succeeded(p.parseOptionalLBrace())) {
218 Type type;
219 StringRef name;
220 if (p.parseKeyword(&name) || p.parseComma() || p.parseType(type) ||
221 p.parseRBrace())
222 return Type();
223 parameters.push_back(FieldInfo{name, type});
224 if (p.parseOptionalComma())
225 break;
226 }
227 if (p.parseGreater())
228 return Type();
229 return get(p.getContext(), parameters);
230}
231
232void StructType::print(AsmPrinter &p) const {
233 p << "<";
234 llvm::interleaveComma(getFields(), p, [&](const FieldInfo &field) {
235 p << "{" << field.name << "," << field.type << "}";
236 });
237 p << ">";
238}
239
240//===----------------------------------------------------------------------===//
241// TestType
242//===----------------------------------------------------------------------===//
243
244void TestType::printTypeC(Location loc) const {
245 emitRemark(loc) << *this << " - TestC";
246}
247
248//===----------------------------------------------------------------------===//
249// TestTypeWithLayout
250//===----------------------------------------------------------------------===//
251
252Type TestTypeWithLayoutType::parse(AsmParser &parser) {
253 unsigned val;
254 if (parser.parseLess() || parser.parseInteger(val) || parser.parseGreater())
255 return Type();
256 return TestTypeWithLayoutType::get(parser.getContext(), val);
257}
258
259void TestTypeWithLayoutType::print(AsmPrinter &printer) const {
260 printer << "<" << getKey() << ">";
261}
262
263llvm::TypeSize
264TestTypeWithLayoutType::getTypeSizeInBits(const DataLayout &dataLayout,
265 DataLayoutEntryListRef params) const {
266 return llvm::TypeSize::getFixed(extractKind(params, "size"));
267}
268
269uint64_t
270TestTypeWithLayoutType::getABIAlignment(const DataLayout &dataLayout,
271 DataLayoutEntryListRef params) const {
272 return extractKind(params, "alignment");
273}
274
275uint64_t TestTypeWithLayoutType::getPreferredAlignment(
276 const DataLayout &dataLayout, DataLayoutEntryListRef params) const {
277 return extractKind(params, "preferred");
278}
279
280std::optional<uint64_t>
281TestTypeWithLayoutType::getIndexBitwidth(const DataLayout &dataLayout,
282 DataLayoutEntryListRef params) const {
283 return extractKind(params, "index");
284}
285
286bool TestTypeWithLayoutType::areCompatible(
287 DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout,
288 DataLayoutSpecInterface newSpec,
289 const DataLayoutIdentifiedEntryMap &map) const {
290 unsigned old = extractKind(oldLayout, "alignment");
291 return old == 1 || extractKind(newLayout, "alignment") <= old;
292}
293
/// Checks the data layout entries attached to TestTypeWithLayoutType.
///
/// Each entry is expected to be:
/// - a type entry keyed by a TestTypeWithLayoutType,
/// - whose value is a two-element ArrayAttr,
/// - holding a StringAttr kind ("size", "alignment", "preferred" or "index")
///   followed by an IntegerAttr.
///
/// These expectations are checked only with assert(), so this always returns
/// success(). When assertions are compiled out, malformed entries are not
/// diagnosed. `loc` is unused.
LogicalResult
TestTypeWithLayoutType::verifyEntries(DataLayoutEntryListRef params,
                                      Location loc) const {
  for (DataLayoutEntryInterface entry : params) {
    // This is for testing purposes only, so assert well-formedness.
    assert(entry.isTypeEntry() && "unexpected identifier entry");
    assert(
        llvm::isa<TestTypeWithLayoutType>(llvm::cast<Type>(entry.getKey())) &&
        "wrong type passed in");
    auto array = llvm::dyn_cast<ArrayAttr>(entry.getValue());
    assert(array && array.getValue().size() == 2 &&
           "expected array of two elements");
    // NOTE(review): this line runs even when assertions are disabled, so a
    // non-ArrayAttr value (null `array`) would presumably crash here in such
    // builds.
    auto kind = llvm::dyn_cast<StringAttr>(array.getValue().front());
    // `kind` is only read inside assert(); silence unused-variable warnings
    // when assertions are compiled out.
    (void)kind;
    assert(kind &&
           (kind.getValue() == "size" || kind.getValue() == "alignment" ||
            kind.getValue() == "preferred" || kind.getValue() == "index") &&
           "unexpected kind");
    assert(llvm::isa<IntegerAttr>(array.getValue().back()));
  }
  return success();
}
316
317uint64_t TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params,
318 StringRef expectedKind) const {
319 for (DataLayoutEntryInterface entry : params) {
320 ArrayRef<Attribute> pair =
321 llvm::cast<ArrayAttr>(entry.getValue()).getValue();
322 StringRef kind = llvm::cast<StringAttr>(pair.front()).getValue();
323 if (kind == expectedKind)
324 return llvm::cast<IntegerAttr>(pair.back()).getValue().getZExtValue();
325 }
326 return 1;
327}
328
329//===----------------------------------------------------------------------===//
330// Dynamic Types
331//===----------------------------------------------------------------------===//
332
333/// Define a singleton dynamic type.
334static std::unique_ptr<DynamicTypeDefinition>
335getSingletonDynamicType(TestDialect *testDialect) {
336 return DynamicTypeDefinition::get(
337 "dynamic_singleton", testDialect,
338 [](function_ref<InFlightDiagnostic()> emitError,
339 ArrayRef<Attribute> args) {
340 if (!args.empty()) {
341 emitError() << "expected 0 type arguments, but had " << args.size();
342 return failure();
343 }
344 return success();
345 });
346}
347
348/// Define a dynamic type representing a pair.
349static std::unique_ptr<DynamicTypeDefinition>
350getPairDynamicType(TestDialect *testDialect) {
351 return DynamicTypeDefinition::get(
352 "dynamic_pair", testDialect,
353 [](function_ref<InFlightDiagnostic()> emitError,
354 ArrayRef<Attribute> args) {
355 if (args.size() != 2) {
356 emitError() << "expected 2 type arguments, but had " << args.size();
357 return failure();
358 }
359 return success();
360 });
361}
362
363static std::unique_ptr<DynamicTypeDefinition>
364getCustomAssemblyFormatDynamicType(TestDialect *testDialect) {
365 auto verifier = [](function_ref<InFlightDiagnostic()> emitError,
366 ArrayRef<Attribute> args) {
367 if (args.size() != 2) {
368 emitError() << "expected 2 type arguments, but had " << args.size();
369 return failure();
370 }
371 return success();
372 };
373
374 auto parser = [](AsmParser &parser,
375 llvm::SmallVectorImpl<Attribute> &parsedParams) {
376 Attribute leftAttr, rightAttr;
377 if (parser.parseLess() || parser.parseAttribute(result&: leftAttr) ||
378 parser.parseColon() || parser.parseAttribute(result&: rightAttr) ||
379 parser.parseGreater())
380 return failure();
381 parsedParams.push_back(Elt: leftAttr);
382 parsedParams.push_back(Elt: rightAttr);
383 return success();
384 };
385
386 auto printer = [](AsmPrinter &printer, ArrayRef<Attribute> params) {
387 printer << "<" << params[0] << ":" << params[1] << ">";
388 };
389
390 return DynamicTypeDefinition::get("dynamic_custom_assembly_format",
391 testDialect, std::move(verifier),
392 std::move(parser), std::move(printer));
393}
394
395//===----------------------------------------------------------------------===//
396// TestDialect
397//===----------------------------------------------------------------------===//
398
399namespace {
400
401struct PtrElementModel
402 : public LLVM::PointerElementTypeInterface::ExternalModel<PtrElementModel,
403 SimpleAType> {};
404} // namespace
405
/// Registers the test dialect's types:
/// - TestRecursiveType plus every type listed by the ODS-generated
///   TestTypeDefs.cpp.inc,
/// - the PtrElementModel external interface on SimpleAType,
/// - the dynamic types built by getSingletonDynamicType, getPairDynamicType
///   and getCustomAssemblyFormatDynamicType.
void TestDialect::registerTypes() {
  // The generated include expands inside the template argument list,
  // continuing the comma-separated list of type classes.
  addTypes<TestRecursiveType,
#define GET_TYPEDEF_LIST
#include "TestTypeDefs.cpp.inc"
           >();
  SimpleAType::attachInterface<PtrElementModel>(*getContext());

  registerDynamicType(getSingletonDynamicType(this));
  registerDynamicType(getPairDynamicType(this));
  registerDynamicType(getCustomAssemblyFormatDynamicType(this));
}
417
/// Parses a test dialect type. Three sources are tried in order:
/// 1. the ODS-generated parsers;
/// 2. the registered dynamic types;
/// 3. the hand-written `test_rec<name, body>` / `test_rec<name>` syntax for
///    TestRecursiveType.
///
/// Any other mnemonic is reported as "unknown type!". Returns a null Type on
/// failure.
Type TestDialect::parseType(DialectAsmParser &parser) const {
  // Filled in by generatedTypeParser and reused by the later steps.
  StringRef typeTag;
  {
    Type genType;
    auto parseResult = generatedTypeParser(parser, &typeTag, genType);
    // A value is present when the tag named a generated type; genType is
    // returned as-is (presumably null if that parse failed).
    if (parseResult.has_value())
      return genType;
  }

  {
    Type dynType;
    auto parseResult = parseOptionalDynamicType(typeTag, parser, dynType);
    if (parseResult.has_value()) {
      if (succeeded(parseResult.value()))
        return dynType;
      return Type();
    }
  }

  if (typeTag != "test_rec") {
    parser.emitError(parser.getNameLoc()) << "unknown type!";
    return Type();
  }

  StringRef name;
  if (parser.parseLess() || parser.parseKeyword(&name))
    return Type();
  auto rec = TestRecursiveType::get(parser.getContext(), name);

  // Held until the function returns, so `rec` stays marked as in-progress
  // while its body is parsed below.
  FailureOr<AsmParser::CyclicParseReset> cyclicParse =
      parser.tryStartCyclicParse(rec);

  // If this type already has been parsed above in the stack, expect just the
  // name.
  if (failed(cyclicParse)) {
    if (failed(parser.parseGreater()))
      return Type();
    return rec;
  }

  // Otherwise, parse the body and update the type.
  if (failed(parser.parseComma()))
    return Type();
  // The body is parsed by recursing into this dialect's parseType rather than
  // parser.parseType(). It is therefore presumably written in this dialect's
  // short form, without the `!test.` prefix.
  Type subtype = parseType(parser);
  if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
    return Type();

  return rec;
}
467
/// Prints a test dialect type. The ODS-generated printer is tried first, then
/// the dynamic-type printer.
///
/// Anything else must be a TestRecursiveType (llvm::cast asserts this). It is
/// printed as `test_rec<name, body>`, or as `test_rec<name>` when the same
/// type is already being printed further up the stack.
void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
  if (succeeded(generatedTypePrinter(type, printer)))
    return;

  if (succeeded(printIfDynamicType(type, printer)))
    return;

  auto rec = llvm::cast<TestRecursiveType>(type);

  // Held until the function returns, so nested occurrences of `rec` inside its
  // body are printed by name only.
  FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint =
      printer.tryStartCyclicPrint(rec);

  printer << "test_rec<" << rec.getName();
  if (succeeded(cyclicPrint)) {
    printer << ", ";
    // Recurse into this dialect's printer, mirroring parseType, which parses
    // the body with the dialect parser.
    printType(rec.getBody(), printer);
  }
  printer << ">";
}
487
488Type TestRecursiveAliasType::getBody() const { return getImpl()->body; }
489
490void TestRecursiveAliasType::setBody(Type type) { (void)Base::mutate(type); }
491
492StringRef TestRecursiveAliasType::getName() const { return getImpl()->name; }
493
494Type TestRecursiveAliasType::parse(AsmParser &parser) {
495 StringRef name;
496 if (parser.parseLess() || parser.parseKeyword(&name))
497 return Type();
498 auto rec = TestRecursiveAliasType::get(parser.getContext(), name);
499
500 FailureOr<AsmParser::CyclicParseReset> cyclicParse =
501 parser.tryStartCyclicParse(rec);
502
503 // If this type already has been parsed above in the stack, expect just the
504 // name.
505 if (failed(cyclicParse)) {
506 if (failed(parser.parseGreater()))
507 return Type();
508 return rec;
509 }
510
511 // Otherwise, parse the body and update the type.
512 if (failed(parser.parseComma()))
513 return Type();
514 Type subtype;
515 if (parser.parseType(subtype))
516 return nullptr;
517 if (!subtype || failed(parser.parseGreater()))
518 return Type();
519
520 rec.setBody(subtype);
521
522 return rec;
523}
524
525void TestRecursiveAliasType::print(AsmPrinter &printer) const {
526
527 FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint =
528 printer.tryStartCyclicPrint(*this);
529
530 printer << "<" << getName();
531 if (succeeded(cyclicPrint)) {
532 printer << ", ";
533 printer << getBody();
534 }
535 printer << ">";
536}
537
538void TestTypeOpAsmTypeInterfaceType::getAsmName(
539 OpAsmSetNameFn setNameFn) const {
540 setNameFn("op_asm_type_interface");
541}
542
543::mlir::OpAsmDialectInterface::AliasResult
544TestTypeOpAsmTypeInterfaceType::getAlias(::llvm::raw_ostream &os) const {
545 os << "op_asm_type_interface_type";
546 return ::mlir::OpAsmDialectInterface::AliasResult::FinalAlias;
547}
548

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/test/lib/Dialect/Test/TestTypes.cpp