1 | //===- DialectTest.cpp - Dialect unit tests -------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/IR/Dialect.h" |
10 | #include "mlir/IR/DialectInterface.h" |
11 | #include "mlir/Support/TypeID.h" |
12 | #include "gtest/gtest.h" |
13 | |
14 | using namespace mlir; |
15 | using namespace mlir::detail; |
16 | |
17 | namespace { |
18 | struct TestDialect : public Dialect { |
19 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialect) |
20 | |
21 | static StringRef getDialectNamespace() { return "test" ; }; |
22 | TestDialect(MLIRContext *context) |
23 | : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {} |
24 | }; |
25 | struct AnotherTestDialect : public Dialect { |
26 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnotherTestDialect) |
27 | |
28 | static StringRef getDialectNamespace() { return "test" ; }; |
29 | AnotherTestDialect(MLIRContext *context) |
30 | : Dialect(getDialectNamespace(), context, |
31 | TypeID::get<AnotherTestDialect>()) {} |
32 | }; |
33 | |
34 | TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) { |
35 | MLIRContext context; |
36 | |
37 | // Registering a dialect with the same namespace twice should result in a |
38 | // failure. |
39 | context.loadDialect<TestDialect>(); |
40 | ASSERT_DEATH(context.loadDialect<AnotherTestDialect>(), "" ); |
41 | } |
42 | |
43 | struct SecondTestDialect : public Dialect { |
44 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialect) |
45 | |
46 | static StringRef getDialectNamespace() { return "test2" ; } |
47 | SecondTestDialect(MLIRContext *context) |
48 | : Dialect(getDialectNamespace(), context, |
49 | TypeID::get<SecondTestDialect>()) {} |
50 | }; |
51 | |
52 | struct TestDialectInterfaceBase |
53 | : public DialectInterface::Base<TestDialectInterfaceBase> { |
54 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterfaceBase) |
55 | |
56 | TestDialectInterfaceBase(Dialect *dialect) : Base(dialect) {} |
57 | virtual int function() const { return 42; } |
58 | }; |
59 | |
60 | struct TestDialectInterface : public TestDialectInterfaceBase { |
61 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterface) |
62 | |
63 | using TestDialectInterfaceBase::TestDialectInterfaceBase; |
64 | int function() const final { return 56; } |
65 | }; |
66 | |
67 | struct SecondTestDialectInterface : public TestDialectInterfaceBase { |
68 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialectInterface) |
69 | |
70 | using TestDialectInterfaceBase::TestDialectInterfaceBase; |
71 | int function() const final { return 78; } |
72 | }; |
73 | |
74 | TEST(Dialect, DelayedInterfaceRegistration) { |
75 | DialectRegistry registry; |
76 | registry.insert<TestDialect, SecondTestDialect>(); |
77 | |
78 | // Delayed registration of an interface for TestDialect. |
79 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, TestDialect *dialect) { |
80 | dialect->addInterfaces<TestDialectInterface>(); |
81 | }); |
82 | |
83 | MLIRContext context(registry); |
84 | |
85 | // Load the TestDialect and check that the interface got registered for it. |
86 | Dialect *testDialect = context.getOrLoadDialect<TestDialect>(); |
87 | ASSERT_TRUE(testDialect != nullptr); |
88 | auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(Val: testDialect); |
89 | EXPECT_TRUE(testDialectInterface != nullptr); |
90 | |
91 | // Load the SecondTestDialect and check that the interface is not registered |
92 | // for it. |
93 | Dialect *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>(); |
94 | ASSERT_TRUE(secondTestDialect != nullptr); |
95 | auto *secondTestDialectInterface = |
96 | dyn_cast<SecondTestDialectInterface>(Val: secondTestDialect); |
97 | EXPECT_TRUE(secondTestDialectInterface == nullptr); |
98 | |
99 | // Use the same mechanism as for delayed registration but for an already |
100 | // loaded dialect and check that the interface is now registered. |
101 | DialectRegistry secondRegistry; |
102 | secondRegistry.insert<SecondTestDialect>(); |
103 | secondRegistry.addExtension( |
104 | extensionFn: +[](MLIRContext *ctx, SecondTestDialect *dialect) { |
105 | dialect->addInterfaces<SecondTestDialectInterface>(); |
106 | }); |
107 | context.appendDialectRegistry(registry: secondRegistry); |
108 | secondTestDialectInterface = |
109 | dyn_cast<SecondTestDialectInterface>(Val: secondTestDialect); |
110 | EXPECT_TRUE(secondTestDialectInterface != nullptr); |
111 | } |
112 | |
113 | TEST(Dialect, RepeatedDelayedRegistration) { |
114 | // Set up the delayed registration. |
115 | DialectRegistry registry; |
116 | registry.insert<TestDialect>(); |
117 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, TestDialect *dialect) { |
118 | dialect->addInterfaces<TestDialectInterface>(); |
119 | }); |
120 | MLIRContext context(registry); |
121 | |
122 | // Load the TestDialect and check that the interface got registered for it. |
123 | Dialect *testDialect = context.getOrLoadDialect<TestDialect>(); |
124 | ASSERT_TRUE(testDialect != nullptr); |
125 | auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(Val: testDialect); |
126 | EXPECT_TRUE(testDialectInterface != nullptr); |
127 | |
128 | // Try adding the same dialect interface again and check that we don't crash |
129 | // on repeated interface registration. |
130 | DialectRegistry secondRegistry; |
131 | secondRegistry.insert<TestDialect>(); |
132 | secondRegistry.addExtension(extensionFn: +[](MLIRContext *ctx, TestDialect *dialect) { |
133 | dialect->addInterfaces<TestDialectInterface>(); |
134 | }); |
135 | context.appendDialectRegistry(registry: secondRegistry); |
136 | testDialectInterface = dyn_cast<TestDialectInterfaceBase>(Val: testDialect); |
137 | EXPECT_TRUE(testDialectInterface != nullptr); |
138 | } |
139 | |
140 | namespace { |
141 | /// A dummy extension that increases a counter when being applied and |
142 | /// recursively adds additional extensions. |
143 | struct DummyExtension : DialectExtension<DummyExtension, TestDialect> { |
144 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DummyExtension) |
145 | |
146 | DummyExtension(int *counter, int numRecursive) |
147 | : DialectExtension(), counter(counter), numRecursive(numRecursive) {} |
148 | |
149 | void apply(MLIRContext *ctx, TestDialect *dialect) const final { |
150 | ++(*counter); |
151 | DialectRegistry nestedRegistry; |
152 | for (int i = 0; i < numRecursive; ++i) { |
153 | // Create unique TypeIDs for these recursive extensions so they don't get |
154 | // de-duplicated. |
155 | auto extension = |
156 | std::make_unique<DummyExtension>(args: counter, /*numRecursive=*/args: 0); |
157 | auto typeID = TypeID::getFromOpaquePointer(pointer: extension.get()); |
158 | nestedRegistry.addExtension(extensionID: typeID, extension: std::move(extension)); |
159 | } |
160 | // Adding additional extensions may trigger a reallocation of the |
161 | // `extensions` vector in the dialect registry. |
162 | ctx->appendDialectRegistry(registry: nestedRegistry); |
163 | } |
164 | |
165 | private: |
166 | int *counter; |
167 | int numRecursive; |
168 | }; |
169 | } // namespace |
170 | |
171 | TEST(Dialect, NestedDialectExtension) { |
172 | DialectRegistry registry; |
173 | registry.insert<TestDialect>(); |
174 | |
175 | // Add an extension that adds 100 more extensions. |
176 | int counter1 = 0; |
177 | registry.addExtension(extensionID: TypeID::get<DummyExtension>(), |
178 | extension: std::make_unique<DummyExtension>(args: &counter1, args: 100)); |
179 | // Add one more extension. This should not crash. |
180 | int counter2 = 0; |
181 | registry.addExtension(extensionID: TypeID::getFromOpaquePointer(pointer: &counter2), |
182 | extension: std::make_unique<DummyExtension>(args: &counter2, args: 0)); |
183 | |
184 | // Load dialect and apply extensions. |
185 | MLIRContext context(registry); |
186 | Dialect *testDialect = context.getOrLoadDialect<TestDialect>(); |
187 | ASSERT_TRUE(testDialect != nullptr); |
188 | |
189 | // Extensions are de-duplicated by typeID. Make sure that each expected |
190 | // extension was applied at least once. |
191 | EXPECT_GE(counter1, 101); |
192 | EXPECT_GE(counter2, 1); |
193 | } |
194 | |
195 | TEST(Dialect, SubsetWithExtensions) { |
196 | DialectRegistry registry1, registry2; |
197 | registry1.insert<TestDialect>(); |
198 | registry2.insert<TestDialect>(); |
199 | |
200 | // Validate that the registries are equivalent. |
201 | ASSERT_TRUE(registry1.isSubsetOf(registry2)); |
202 | ASSERT_TRUE(registry2.isSubsetOf(registry1)); |
203 | |
204 | // Add extensions to registry2. |
205 | int counter = 0; |
206 | registry2.addExtension(extensionID: TypeID::get<DummyExtension>(), |
207 | extension: std::make_unique<DummyExtension>(args: &counter, args: 0)); |
208 | |
209 | // Expect that (1) is a subset of (2) but not the other way around. |
210 | ASSERT_TRUE(registry1.isSubsetOf(registry2)); |
211 | ASSERT_FALSE(registry2.isSubsetOf(registry1)); |
212 | |
213 | // Add extensions to registry1. |
214 | registry1.addExtension(extensionID: TypeID::get<DummyExtension>(), |
215 | extension: std::make_unique<DummyExtension>(args: &counter, args: 0)); |
216 | |
217 | // Expect that (1) and (2) are equivalent. |
218 | ASSERT_TRUE(registry1.isSubsetOf(registry2)); |
219 | ASSERT_TRUE(registry2.isSubsetOf(registry1)); |
220 | |
221 | // Load dialect and apply extensions. |
222 | MLIRContext context(registry1); |
223 | context.getOrLoadDialect<TestDialect>(); |
224 | context.appendDialectRegistry(registry: registry2); |
225 | // Expect that the extension as only invoked once. |
226 | ASSERT_EQ(counter, 1); |
227 | } |
228 | |
229 | } // namespace |
230 | |