1 | //===- DialectTest.cpp - Dialect unit tests -------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/IR/Dialect.h" |
10 | #include "mlir/IR/DialectInterface.h" |
11 | #include "gtest/gtest.h" |
12 | |
13 | using namespace mlir; |
14 | using namespace mlir::detail; |
15 | |
16 | namespace { |
17 | struct TestDialect : public Dialect { |
18 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialect) |
19 | |
20 | static StringRef getDialectNamespace() { return "test" ; }; |
21 | TestDialect(MLIRContext *context) |
22 | : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {} |
23 | }; |
24 | struct AnotherTestDialect : public Dialect { |
25 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnotherTestDialect) |
26 | |
27 | static StringRef getDialectNamespace() { return "test" ; }; |
28 | AnotherTestDialect(MLIRContext *context) |
29 | : Dialect(getDialectNamespace(), context, |
30 | TypeID::get<AnotherTestDialect>()) {} |
31 | }; |
32 | |
33 | TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) { |
34 | MLIRContext context; |
35 | |
36 | // Registering a dialect with the same namespace twice should result in a |
37 | // failure. |
38 | context.loadDialect<TestDialect>(); |
39 | ASSERT_DEATH(context.loadDialect<AnotherTestDialect>(), "" ); |
40 | } |
41 | |
42 | struct SecondTestDialect : public Dialect { |
43 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialect) |
44 | |
45 | static StringRef getDialectNamespace() { return "test2" ; } |
46 | SecondTestDialect(MLIRContext *context) |
47 | : Dialect(getDialectNamespace(), context, |
48 | TypeID::get<SecondTestDialect>()) {} |
49 | }; |
50 | |
51 | struct TestDialectInterfaceBase |
52 | : public DialectInterface::Base<TestDialectInterfaceBase> { |
53 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterfaceBase) |
54 | |
55 | TestDialectInterfaceBase(Dialect *dialect) : Base(dialect) {} |
56 | virtual int function() const { return 42; } |
57 | }; |
58 | |
59 | struct TestDialectInterface : public TestDialectInterfaceBase { |
60 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterface) |
61 | |
62 | using TestDialectInterfaceBase::TestDialectInterfaceBase; |
63 | int function() const final { return 56; } |
64 | }; |
65 | |
66 | struct SecondTestDialectInterface : public TestDialectInterfaceBase { |
67 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialectInterface) |
68 | |
69 | using TestDialectInterfaceBase::TestDialectInterfaceBase; |
70 | int function() const final { return 78; } |
71 | }; |
72 | |
73 | TEST(Dialect, DelayedInterfaceRegistration) { |
74 | DialectRegistry registry; |
75 | registry.insert<TestDialect, SecondTestDialect>(); |
76 | |
77 | // Delayed registration of an interface for TestDialect. |
78 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, TestDialect *dialect) { |
79 | dialect->addInterfaces<TestDialectInterface>(); |
80 | }); |
81 | |
82 | MLIRContext context(registry); |
83 | |
84 | // Load the TestDialect and check that the interface got registered for it. |
85 | Dialect *testDialect = context.getOrLoadDialect<TestDialect>(); |
86 | ASSERT_TRUE(testDialect != nullptr); |
87 | auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(Val: testDialect); |
88 | EXPECT_TRUE(testDialectInterface != nullptr); |
89 | |
90 | // Load the SecondTestDialect and check that the interface is not registered |
91 | // for it. |
92 | Dialect *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>(); |
93 | ASSERT_TRUE(secondTestDialect != nullptr); |
94 | auto *secondTestDialectInterface = |
95 | dyn_cast<SecondTestDialectInterface>(Val: secondTestDialect); |
96 | EXPECT_TRUE(secondTestDialectInterface == nullptr); |
97 | |
98 | // Use the same mechanism as for delayed registration but for an already |
99 | // loaded dialect and check that the interface is now registered. |
100 | DialectRegistry secondRegistry; |
101 | secondRegistry.insert<SecondTestDialect>(); |
102 | secondRegistry.addExtension( |
103 | extensionFn: +[](MLIRContext *ctx, SecondTestDialect *dialect) { |
104 | dialect->addInterfaces<SecondTestDialectInterface>(); |
105 | }); |
106 | context.appendDialectRegistry(registry: secondRegistry); |
107 | secondTestDialectInterface = |
108 | dyn_cast<SecondTestDialectInterface>(Val: secondTestDialect); |
109 | EXPECT_TRUE(secondTestDialectInterface != nullptr); |
110 | } |
111 | |
112 | TEST(Dialect, RepeatedDelayedRegistration) { |
113 | // Set up the delayed registration. |
114 | DialectRegistry registry; |
115 | registry.insert<TestDialect>(); |
116 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, TestDialect *dialect) { |
117 | dialect->addInterfaces<TestDialectInterface>(); |
118 | }); |
119 | MLIRContext context(registry); |
120 | |
121 | // Load the TestDialect and check that the interface got registered for it. |
122 | Dialect *testDialect = context.getOrLoadDialect<TestDialect>(); |
123 | ASSERT_TRUE(testDialect != nullptr); |
124 | auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(Val: testDialect); |
125 | EXPECT_TRUE(testDialectInterface != nullptr); |
126 | |
127 | // Try adding the same dialect interface again and check that we don't crash |
128 | // on repeated interface registration. |
129 | DialectRegistry secondRegistry; |
130 | secondRegistry.insert<TestDialect>(); |
131 | secondRegistry.addExtension(extensionFn: +[](MLIRContext *ctx, TestDialect *dialect) { |
132 | dialect->addInterfaces<TestDialectInterface>(); |
133 | }); |
134 | context.appendDialectRegistry(registry: secondRegistry); |
135 | testDialectInterface = dyn_cast<TestDialectInterfaceBase>(Val: testDialect); |
136 | EXPECT_TRUE(testDialectInterface != nullptr); |
137 | } |
138 | |
139 | namespace { |
140 | /// A dummy extension that increases a counter when being applied and |
141 | /// recursively adds additional extensions. |
142 | struct DummyExtension : DialectExtension<DummyExtension, TestDialect> { |
143 | DummyExtension(int *counter, int numRecursive) |
144 | : DialectExtension(), counter(counter), numRecursive(numRecursive) {} |
145 | |
146 | void apply(MLIRContext *ctx, TestDialect *dialect) const final { |
147 | ++(*counter); |
148 | DialectRegistry nestedRegistry; |
149 | for (int i = 0; i < numRecursive; ++i) |
150 | nestedRegistry.addExtension( |
151 | extension: std::make_unique<DummyExtension>(args: counter, /*numRecursive=*/args: 0)); |
152 | // Adding additional extensions may trigger a reallocation of the |
153 | // `extensions` vector in the dialect registry. |
154 | ctx->appendDialectRegistry(registry: nestedRegistry); |
155 | } |
156 | |
157 | private: |
158 | int *counter; |
159 | int numRecursive; |
160 | }; |
161 | } // namespace |
162 | |
163 | TEST(Dialect, NestedDialectExtension) { |
164 | DialectRegistry registry; |
165 | registry.insert<TestDialect>(); |
166 | |
167 | // Add an extension that adds 100 more extensions. |
168 | int counter1 = 0; |
169 | registry.addExtension(extension: std::make_unique<DummyExtension>(args: &counter1, args: 100)); |
170 | // Add one more extension. This should not crash. |
171 | int counter2 = 0; |
172 | registry.addExtension(extension: std::make_unique<DummyExtension>(args: &counter2, args: 0)); |
173 | |
174 | // Load dialect and apply extensions. |
175 | MLIRContext context(registry); |
176 | Dialect *testDialect = context.getOrLoadDialect<TestDialect>(); |
177 | ASSERT_TRUE(testDialect != nullptr); |
178 | |
179 | // Extensions may be applied multiple times. Make sure that each expected |
180 | // extension was applied at least once. |
181 | EXPECT_GE(counter1, 101); |
182 | EXPECT_GE(counter2, 1); |
183 | } |
184 | |
185 | } // namespace |
186 | |