1//===- TestTypes.h - MLIR Test Dialect Types --------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains types defined by the TestDialect for testing various
10// features of MLIR.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef MLIR_TESTTYPES_H
15#define MLIR_TESTTYPES_H
16
17#include <optional>
18#include <tuple>
19
20#include "TestTraits.h"
21#include "mlir/IR/Diagnostics.h"
22#include "mlir/IR/Dialect.h"
23#include "mlir/IR/DialectImplementation.h"
24#include "mlir/IR/Operation.h"
25#include "mlir/IR/Types.h"
26#include "mlir/Interfaces/DataLayoutInterfaces.h"
27
28namespace test {
29class TestAttrWithFormatAttr;
30
31/// FieldInfo represents a field in the StructType data type. It is used as a
32/// parameter in TestTypeDefs.td.
33struct FieldInfo {
34 llvm::StringRef name;
35 mlir::Type type;
36
37 // Custom allocation called from generated constructor code
38 FieldInfo allocateInto(mlir::TypeStorageAllocator &alloc) const {
39 return FieldInfo{.name: alloc.copyInto(str: name), .type: type};
40 }
41};
42
/// A custom, user-defined type used as a test type parameter. It wraps a
/// single integer and is compared and hashed by that integer alone.
struct CustomParam {
  /// The wrapped integer payload.
  int value;

  /// Two parameters are equal iff their wrapped integers are equal.
  bool operator==(const CustomParam &other) const {
    return other.value == value;
  }

  /// Counterpart of operator==. C++17 does not synthesize `!=` from `==`, so
  /// it is spelled out to keep the type fully equality-comparable.
  bool operator!=(const CustomParam &other) const { return !(*this == other); }
};
51
52inline llvm::hash_code hash_value(const test::CustomParam &param) {
53 return llvm::hash_value(value: param.value);
54}
55
56} // namespace test
57
58namespace mlir {
59template <>
60struct FieldParser<test::CustomParam> {
61 static FailureOr<test::CustomParam> parse(AsmParser &parser) {
62 auto value = FieldParser<int>::parse(parser);
63 if (failed(value))
64 return failure();
65 return test::CustomParam{*value};
66 }
67};
68
69inline mlir::AsmPrinter &operator<<(mlir::AsmPrinter &printer,
70 test::CustomParam param) {
71 return printer << param.value;
72}
73
74/// Overload the attribute parameter parser for optional integers.
75template <>
76struct FieldParser<std::optional<int>> {
77 static FailureOr<std::optional<int>> parse(AsmParser &parser) {
78 std::optional<int> value;
79 value.emplace();
80 OptionalParseResult result = parser.parseOptionalInteger(result&: *value);
81 if (result.has_value()) {
82 if (succeeded(result: *result))
83 return value;
84 return failure();
85 }
86 value.reset();
87 return value;
88 }
89};
90} // namespace mlir
91
92#include "TestTypeInterfaces.h.inc"
93
94namespace test {
95
96/// Storage for simple named recursive types, where the type is identified by
97/// its name and can "contain" another type, including itself.
98struct TestRecursiveTypeStorage : public ::mlir::TypeStorage {
99 using KeyTy = ::llvm::StringRef;
100
101 explicit TestRecursiveTypeStorage(::llvm::StringRef key) : name(key) {}
102
103 bool operator==(const KeyTy &other) const { return name == other; }
104
105 static TestRecursiveTypeStorage *
106 construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &key) {
107 return new (allocator.allocate<TestRecursiveTypeStorage>())
108 TestRecursiveTypeStorage(allocator.copyInto(str: key));
109 }
110
111 ::mlir::LogicalResult mutate(::mlir::TypeStorageAllocator &allocator,
112 ::mlir::Type newBody) {
113 // Cannot set a different body than before.
114 if (body && body != newBody)
115 return ::mlir::failure();
116
117 body = newBody;
118 return ::mlir::success();
119 }
120
121 ::llvm::StringRef name;
122 ::mlir::Type body;
123};
124
125/// Simple recursive type identified by its name and pointing to another named
126/// type, potentially itself. This requires the body to be mutated separately
127/// from type creation.
128class TestRecursiveType
129 : public ::mlir::Type::TypeBase<TestRecursiveType, ::mlir::Type,
130 TestRecursiveTypeStorage,
131 ::mlir::TypeTrait::IsMutable> {
132public:
133 using Base::Base;
134
135 static constexpr ::mlir::StringLiteral name = "test.recursive";
136
137 static TestRecursiveType get(::mlir::MLIRContext *ctx,
138 ::llvm::StringRef name) {
139 return Base::get(ctx, name);
140 }
141
142 /// Body getter and setter.
143 ::mlir::LogicalResult setBody(Type body) { return Base::mutate(body); }
144 ::mlir::Type getBody() const { return getImpl()->body; }
145
146 /// Name/key getter.
147 ::llvm::StringRef getName() { return getImpl()->name; }
148};
149
150} // namespace test
151
152#define GET_TYPEDEF_CLASSES
153#include "TestTypeDefs.h.inc"
154
155#endif // MLIR_TESTTYPES_H
156

// Source: mlir/test/lib/Dialect/Test/TestTypes.h