1//===- DialectTest.cpp - Dialect unit tests -------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/Dialect.h"
10#include "mlir/IR/DialectInterface.h"
11#include "mlir/Support/TypeID.h"
12#include "gtest/gtest.h"
13
14using namespace mlir;
15using namespace mlir::detail;
16
17namespace {
18struct TestDialect : public Dialect {
19 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialect)
20
21 static StringRef getDialectNamespace() { return "test"; };
22 TestDialect(MLIRContext *context)
23 : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {}
24};
25struct AnotherTestDialect : public Dialect {
26 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnotherTestDialect)
27
28 static StringRef getDialectNamespace() { return "test"; };
29 AnotherTestDialect(MLIRContext *context)
30 : Dialect(getDialectNamespace(), context,
31 TypeID::get<AnotherTestDialect>()) {}
32};
33
34TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) {
35 MLIRContext context;
36
37 // Registering a dialect with the same namespace twice should result in a
38 // failure.
39 context.loadDialect<TestDialect>();
40 ASSERT_DEATH(context.loadDialect<AnotherTestDialect>(), "");
41}
42
43struct SecondTestDialect : public Dialect {
44 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialect)
45
46 static StringRef getDialectNamespace() { return "test2"; }
47 SecondTestDialect(MLIRContext *context)
48 : Dialect(getDialectNamespace(), context,
49 TypeID::get<SecondTestDialect>()) {}
50};
51
52struct TestDialectInterfaceBase
53 : public DialectInterface::Base<TestDialectInterfaceBase> {
54 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterfaceBase)
55
56 TestDialectInterfaceBase(Dialect *dialect) : Base(dialect) {}
57 virtual int function() const { return 42; }
58};
59
60struct TestDialectInterface : public TestDialectInterfaceBase {
61 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterface)
62
63 using TestDialectInterfaceBase::TestDialectInterfaceBase;
64 int function() const final { return 56; }
65};
66
67struct SecondTestDialectInterface : public TestDialectInterfaceBase {
68 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialectInterface)
69
70 using TestDialectInterfaceBase::TestDialectInterfaceBase;
71 int function() const final { return 78; }
72};
73
74TEST(Dialect, DelayedInterfaceRegistration) {
75 DialectRegistry registry;
76 registry.insert<TestDialect, SecondTestDialect>();
77
78 // Delayed registration of an interface for TestDialect.
79 registry.addExtension(extensionFn: +[](MLIRContext *ctx, TestDialect *dialect) {
80 dialect->addInterfaces<TestDialectInterface>();
81 });
82
83 MLIRContext context(registry);
84
85 // Load the TestDialect and check that the interface got registered for it.
86 Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
87 ASSERT_TRUE(testDialect != nullptr);
88 auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(Val: testDialect);
89 EXPECT_TRUE(testDialectInterface != nullptr);
90
91 // Load the SecondTestDialect and check that the interface is not registered
92 // for it.
93 Dialect *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
94 ASSERT_TRUE(secondTestDialect != nullptr);
95 auto *secondTestDialectInterface =
96 dyn_cast<SecondTestDialectInterface>(Val: secondTestDialect);
97 EXPECT_TRUE(secondTestDialectInterface == nullptr);
98
99 // Use the same mechanism as for delayed registration but for an already
100 // loaded dialect and check that the interface is now registered.
101 DialectRegistry secondRegistry;
102 secondRegistry.insert<SecondTestDialect>();
103 secondRegistry.addExtension(
104 extensionFn: +[](MLIRContext *ctx, SecondTestDialect *dialect) {
105 dialect->addInterfaces<SecondTestDialectInterface>();
106 });
107 context.appendDialectRegistry(registry: secondRegistry);
108 secondTestDialectInterface =
109 dyn_cast<SecondTestDialectInterface>(Val: secondTestDialect);
110 EXPECT_TRUE(secondTestDialectInterface != nullptr);
111}
112
113TEST(Dialect, RepeatedDelayedRegistration) {
114 // Set up the delayed registration.
115 DialectRegistry registry;
116 registry.insert<TestDialect>();
117 registry.addExtension(extensionFn: +[](MLIRContext *ctx, TestDialect *dialect) {
118 dialect->addInterfaces<TestDialectInterface>();
119 });
120 MLIRContext context(registry);
121
122 // Load the TestDialect and check that the interface got registered for it.
123 Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
124 ASSERT_TRUE(testDialect != nullptr);
125 auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(Val: testDialect);
126 EXPECT_TRUE(testDialectInterface != nullptr);
127
128 // Try adding the same dialect interface again and check that we don't crash
129 // on repeated interface registration.
130 DialectRegistry secondRegistry;
131 secondRegistry.insert<TestDialect>();
132 secondRegistry.addExtension(extensionFn: +[](MLIRContext *ctx, TestDialect *dialect) {
133 dialect->addInterfaces<TestDialectInterface>();
134 });
135 context.appendDialectRegistry(registry: secondRegistry);
136 testDialectInterface = dyn_cast<TestDialectInterfaceBase>(Val: testDialect);
137 EXPECT_TRUE(testDialectInterface != nullptr);
138}
139
140namespace {
141/// A dummy extension that increases a counter when being applied and
142/// recursively adds additional extensions.
143struct DummyExtension : DialectExtension<DummyExtension, TestDialect> {
144 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DummyExtension)
145
146 DummyExtension(int *counter, int numRecursive)
147 : DialectExtension(), counter(counter), numRecursive(numRecursive) {}
148
149 void apply(MLIRContext *ctx, TestDialect *dialect) const final {
150 ++(*counter);
151 DialectRegistry nestedRegistry;
152 for (int i = 0; i < numRecursive; ++i) {
153 // Create unique TypeIDs for these recursive extensions so they don't get
154 // de-duplicated.
155 auto extension =
156 std::make_unique<DummyExtension>(args: counter, /*numRecursive=*/args: 0);
157 auto typeID = TypeID::getFromOpaquePointer(pointer: extension.get());
158 nestedRegistry.addExtension(extensionID: typeID, extension: std::move(extension));
159 }
160 // Adding additional extensions may trigger a reallocation of the
161 // `extensions` vector in the dialect registry.
162 ctx->appendDialectRegistry(registry: nestedRegistry);
163 }
164
165private:
166 int *counter;
167 int numRecursive;
168};
169} // namespace
170
171TEST(Dialect, NestedDialectExtension) {
172 DialectRegistry registry;
173 registry.insert<TestDialect>();
174
175 // Add an extension that adds 100 more extensions.
176 int counter1 = 0;
177 registry.addExtension(extensionID: TypeID::get<DummyExtension>(),
178 extension: std::make_unique<DummyExtension>(args: &counter1, args: 100));
179 // Add one more extension. This should not crash.
180 int counter2 = 0;
181 registry.addExtension(extensionID: TypeID::getFromOpaquePointer(pointer: &counter2),
182 extension: std::make_unique<DummyExtension>(args: &counter2, args: 0));
183
184 // Load dialect and apply extensions.
185 MLIRContext context(registry);
186 Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
187 ASSERT_TRUE(testDialect != nullptr);
188
189 // Extensions are de-duplicated by typeID. Make sure that each expected
190 // extension was applied at least once.
191 EXPECT_GE(counter1, 101);
192 EXPECT_GE(counter2, 1);
193}
194
195TEST(Dialect, SubsetWithExtensions) {
196 DialectRegistry registry1, registry2;
197 registry1.insert<TestDialect>();
198 registry2.insert<TestDialect>();
199
200 // Validate that the registries are equivalent.
201 ASSERT_TRUE(registry1.isSubsetOf(registry2));
202 ASSERT_TRUE(registry2.isSubsetOf(registry1));
203
204 // Add extensions to registry2.
205 int counter = 0;
206 registry2.addExtension(extensionID: TypeID::get<DummyExtension>(),
207 extension: std::make_unique<DummyExtension>(args: &counter, args: 0));
208
209 // Expect that (1) is a subset of (2) but not the other way around.
210 ASSERT_TRUE(registry1.isSubsetOf(registry2));
211 ASSERT_FALSE(registry2.isSubsetOf(registry1));
212
213 // Add extensions to registry1.
214 registry1.addExtension(extensionID: TypeID::get<DummyExtension>(),
215 extension: std::make_unique<DummyExtension>(args: &counter, args: 0));
216
217 // Expect that (1) and (2) are equivalent.
218 ASSERT_TRUE(registry1.isSubsetOf(registry2));
219 ASSERT_TRUE(registry2.isSubsetOf(registry1));
220
221 // Load dialect and apply extensions.
222 MLIRContext context(registry1);
223 context.getOrLoadDialect<TestDialect>();
224 context.appendDialectRegistry(registry: registry2);
225 // Expect that the extension as only invoked once.
226 ASSERT_EQ(counter, 1);
227}
228
229} // namespace
230

// Source: mlir/unittests/IR/DialectTest.cpp