1//===- TestTypes.h - MLIR Test Dialect Types --------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains types defined by the TestDialect for testing various
10// features of MLIR.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef MLIR_TESTTYPES_H
15#define MLIR_TESTTYPES_H
16
17#include <optional>
18#include <tuple>
19
20#include "TestTraits.h"
21#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
22#include "mlir/IR/Diagnostics.h"
23#include "mlir/IR/Dialect.h"
24#include "mlir/IR/DialectImplementation.h"
25#include "mlir/IR/Operation.h"
26#include "mlir/IR/Types.h"
27#include "mlir/Interfaces/DataLayoutInterfaces.h"
28
29namespace test {
30class TestAttrWithFormatAttr;
31
32/// FieldInfo represents a field in the StructType data type. It is used as a
33/// parameter in TestTypeDefs.td.
34struct FieldInfo {
35 llvm::StringRef name;
36 mlir::Type type;
37
38 // Custom allocation called from generated constructor code
39 FieldInfo allocateInto(mlir::TypeStorageAllocator &alloc) const {
40 return FieldInfo{.name: alloc.copyInto(str: name), .type: type};
41 }
42};
43
/// A user-defined parameter type used by the test dialect: a thin wrapper
/// around a single integer.
struct CustomParam {
  int value;

  /// Two parameters are equal exactly when their wrapped integers are equal.
  bool operator==(const CustomParam &rhs) const { return value == rhs.value; }
};
52
53inline llvm::hash_code hash_value(const test::CustomParam &param) {
54 return llvm::hash_value(value: param.value);
55}
56
57} // namespace test
58
59namespace mlir {
60template <>
61struct FieldParser<test::CustomParam> {
62 static FailureOr<test::CustomParam> parse(AsmParser &parser) {
63 auto value = FieldParser<int>::parse(parser);
64 if (failed(Result: value))
65 return failure();
66 return test::CustomParam{.value: *value};
67 }
68};
69
70inline mlir::AsmPrinter &operator<<(mlir::AsmPrinter &printer,
71 test::CustomParam param) {
72 return printer << param.value;
73}
74
75/// Overload the attribute parameter parser for optional integers.
76template <>
77struct FieldParser<std::optional<int>> {
78 static FailureOr<std::optional<int>> parse(AsmParser &parser) {
79 std::optional<int> value;
80 value.emplace();
81 OptionalParseResult result = parser.parseOptionalInteger(result&: *value);
82 if (result.has_value()) {
83 if (succeeded(Result: *result))
84 return value;
85 return failure();
86 }
87 value.reset();
88 return value;
89 }
90};
91} // namespace mlir
92
93#include "TestTypeInterfaces.h.inc"
94
95namespace test {
96
97/// Storage for simple named recursive types, where the type is identified by
98/// its name and can "contain" another type, including itself.
99struct TestRecursiveTypeStorage : public ::mlir::TypeStorage {
100 using KeyTy = ::llvm::StringRef;
101
102 explicit TestRecursiveTypeStorage(::llvm::StringRef key) : name(key) {}
103
104 bool operator==(const KeyTy &other) const { return name == other; }
105
106 static TestRecursiveTypeStorage *
107 construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &key) {
108 return new (allocator.allocate<TestRecursiveTypeStorage>())
109 TestRecursiveTypeStorage(allocator.copyInto(str: key));
110 }
111
112 ::llvm::LogicalResult mutate(::mlir::TypeStorageAllocator &allocator,
113 ::mlir::Type newBody) {
114 // Cannot set a different body than before.
115 if (body && body != newBody)
116 return ::mlir::failure();
117
118 body = newBody;
119 return ::mlir::success();
120 }
121
122 ::llvm::StringRef name;
123 ::mlir::Type body;
124};
125
126/// Simple recursive type identified by its name and pointing to another named
127/// type, potentially itself. This requires the body to be mutated separately
128/// from type creation.
129class TestRecursiveType
130 : public ::mlir::Type::TypeBase<TestRecursiveType, ::mlir::Type,
131 TestRecursiveTypeStorage,
132 ::mlir::TypeTrait::IsMutable> {
133public:
134 using Base::Base;
135
136 static constexpr ::mlir::StringLiteral name = "test.recursive";
137
138 static TestRecursiveType get(::mlir::MLIRContext *ctx,
139 ::llvm::StringRef name) {
140 return Base::get(ctx, name);
141 }
142
143 /// Body getter and setter.
144 ::llvm::LogicalResult setBody(Type body) { return Base::mutate(body); }
145 ::mlir::Type getBody() const { return getImpl()->body; }
146
147 /// Name/key getter.
148 ::llvm::StringRef getName() { return getImpl()->name; }
149};
150
151} // namespace test
152
153#define GET_TYPEDEF_CLASSES
154#include "TestTypeDefs.h.inc"
155
156#endif // MLIR_TESTTYPES_H
157

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/test/lib/Dialect/Test/TestTypes.h