//===- DialectTest.cpp - Dialect unit tests -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
| 8 | |
| 9 | #include "mlir/IR/Dialect.h" |
| 10 | #include "mlir/IR/DialectInterface.h" |
| 11 | #include "mlir/Support/TypeID.h" |
| 12 | #include "gtest/gtest.h" |
| 13 | |
| 14 | using namespace mlir; |
| 15 | using namespace mlir::detail; |
| 16 | |
| 17 | namespace { |
| 18 | struct TestDialect : public Dialect { |
| 19 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialect) |
| 20 | |
| 21 | static StringRef getDialectNamespace() { return "test" ; }; |
| 22 | TestDialect(MLIRContext *context) |
| 23 | : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {} |
| 24 | }; |
| 25 | struct AnotherTestDialect : public Dialect { |
| 26 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnotherTestDialect) |
| 27 | |
| 28 | static StringRef getDialectNamespace() { return "test" ; }; |
| 29 | AnotherTestDialect(MLIRContext *context) |
| 30 | : Dialect(getDialectNamespace(), context, |
| 31 | TypeID::get<AnotherTestDialect>()) {} |
| 32 | }; |
| 33 | |
| 34 | TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) { |
| 35 | MLIRContext context; |
| 36 | |
| 37 | // Registering a dialect with the same namespace twice should result in a |
| 38 | // failure. |
| 39 | context.loadDialect<TestDialect>(); |
| 40 | ASSERT_DEATH(context.loadDialect<AnotherTestDialect>(), "" ); |
| 41 | } |
| 42 | |
| 43 | struct SecondTestDialect : public Dialect { |
| 44 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialect) |
| 45 | |
| 46 | static StringRef getDialectNamespace() { return "test2" ; } |
| 47 | SecondTestDialect(MLIRContext *context) |
| 48 | : Dialect(getDialectNamespace(), context, |
| 49 | TypeID::get<SecondTestDialect>()) {} |
| 50 | }; |
| 51 | |
| 52 | struct TestDialectInterfaceBase |
| 53 | : public DialectInterface::Base<TestDialectInterfaceBase> { |
| 54 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterfaceBase) |
| 55 | |
| 56 | TestDialectInterfaceBase(Dialect *dialect) : Base(dialect) {} |
| 57 | virtual int function() const { return 42; } |
| 58 | }; |
| 59 | |
| 60 | struct TestDialectInterface : public TestDialectInterfaceBase { |
| 61 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterface) |
| 62 | |
| 63 | using TestDialectInterfaceBase::TestDialectInterfaceBase; |
| 64 | int function() const final { return 56; } |
| 65 | }; |
| 66 | |
| 67 | struct SecondTestDialectInterface : public TestDialectInterfaceBase { |
| 68 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialectInterface) |
| 69 | |
| 70 | using TestDialectInterfaceBase::TestDialectInterfaceBase; |
| 71 | int function() const final { return 78; } |
| 72 | }; |
| 73 | |
| 74 | TEST(Dialect, DelayedInterfaceRegistration) { |
| 75 | DialectRegistry registry; |
| 76 | registry.insert<TestDialect, SecondTestDialect>(); |
| 77 | |
| 78 | // Delayed registration of an interface for TestDialect. |
| 79 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, TestDialect *dialect) { |
| 80 | dialect->addInterfaces<TestDialectInterface>(); |
| 81 | }); |
| 82 | |
| 83 | MLIRContext context(registry); |
| 84 | |
| 85 | // Load the TestDialect and check that the interface got registered for it. |
| 86 | Dialect *testDialect = context.getOrLoadDialect<TestDialect>(); |
| 87 | ASSERT_TRUE(testDialect != nullptr); |
| 88 | auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(Val: testDialect); |
| 89 | EXPECT_TRUE(testDialectInterface != nullptr); |
| 90 | |
| 91 | // Load the SecondTestDialect and check that the interface is not registered |
| 92 | // for it. |
| 93 | Dialect *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>(); |
| 94 | ASSERT_TRUE(secondTestDialect != nullptr); |
| 95 | auto *secondTestDialectInterface = |
| 96 | dyn_cast<SecondTestDialectInterface>(Val: secondTestDialect); |
| 97 | EXPECT_TRUE(secondTestDialectInterface == nullptr); |
| 98 | |
| 99 | // Use the same mechanism as for delayed registration but for an already |
| 100 | // loaded dialect and check that the interface is now registered. |
| 101 | DialectRegistry secondRegistry; |
| 102 | secondRegistry.insert<SecondTestDialect>(); |
| 103 | secondRegistry.addExtension( |
| 104 | extensionFn: +[](MLIRContext *ctx, SecondTestDialect *dialect) { |
| 105 | dialect->addInterfaces<SecondTestDialectInterface>(); |
| 106 | }); |
| 107 | context.appendDialectRegistry(registry: secondRegistry); |
| 108 | secondTestDialectInterface = |
| 109 | dyn_cast<SecondTestDialectInterface>(Val: secondTestDialect); |
| 110 | EXPECT_TRUE(secondTestDialectInterface != nullptr); |
| 111 | } |
| 112 | |
| 113 | TEST(Dialect, RepeatedDelayedRegistration) { |
| 114 | // Set up the delayed registration. |
| 115 | DialectRegistry registry; |
| 116 | registry.insert<TestDialect>(); |
| 117 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, TestDialect *dialect) { |
| 118 | dialect->addInterfaces<TestDialectInterface>(); |
| 119 | }); |
| 120 | MLIRContext context(registry); |
| 121 | |
| 122 | // Load the TestDialect and check that the interface got registered for it. |
| 123 | Dialect *testDialect = context.getOrLoadDialect<TestDialect>(); |
| 124 | ASSERT_TRUE(testDialect != nullptr); |
| 125 | auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(Val: testDialect); |
| 126 | EXPECT_TRUE(testDialectInterface != nullptr); |
| 127 | |
| 128 | // Try adding the same dialect interface again and check that we don't crash |
| 129 | // on repeated interface registration. |
| 130 | DialectRegistry secondRegistry; |
| 131 | secondRegistry.insert<TestDialect>(); |
| 132 | secondRegistry.addExtension(extensionFn: +[](MLIRContext *ctx, TestDialect *dialect) { |
| 133 | dialect->addInterfaces<TestDialectInterface>(); |
| 134 | }); |
| 135 | context.appendDialectRegistry(registry: secondRegistry); |
| 136 | testDialectInterface = dyn_cast<TestDialectInterfaceBase>(Val: testDialect); |
| 137 | EXPECT_TRUE(testDialectInterface != nullptr); |
| 138 | } |
| 139 | |
| 140 | namespace { |
| 141 | /// A dummy extension that increases a counter when being applied and |
| 142 | /// recursively adds additional extensions. |
| 143 | struct DummyExtension : DialectExtension<DummyExtension, TestDialect> { |
| 144 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DummyExtension) |
| 145 | |
| 146 | DummyExtension(int *counter, int numRecursive) |
| 147 | : DialectExtension(), counter(counter), numRecursive(numRecursive) {} |
| 148 | |
| 149 | void apply(MLIRContext *ctx, TestDialect *dialect) const final { |
| 150 | ++(*counter); |
| 151 | DialectRegistry nestedRegistry; |
| 152 | for (int i = 0; i < numRecursive; ++i) { |
| 153 | // Create unique TypeIDs for these recursive extensions so they don't get |
| 154 | // de-duplicated. |
| 155 | auto extension = |
| 156 | std::make_unique<DummyExtension>(args: counter, /*numRecursive=*/args: 0); |
| 157 | auto typeID = TypeID::getFromOpaquePointer(pointer: extension.get()); |
| 158 | nestedRegistry.addExtension(extensionID: typeID, extension: std::move(extension)); |
| 159 | } |
| 160 | // Adding additional extensions may trigger a reallocation of the |
| 161 | // `extensions` vector in the dialect registry. |
| 162 | ctx->appendDialectRegistry(registry: nestedRegistry); |
| 163 | } |
| 164 | |
| 165 | private: |
| 166 | int *counter; |
| 167 | int numRecursive; |
| 168 | }; |
| 169 | } // namespace |
| 170 | |
| 171 | TEST(Dialect, NestedDialectExtension) { |
| 172 | DialectRegistry registry; |
| 173 | registry.insert<TestDialect>(); |
| 174 | |
| 175 | // Add an extension that adds 100 more extensions. |
| 176 | int counter1 = 0; |
| 177 | registry.addExtension(extensionID: TypeID::get<DummyExtension>(), |
| 178 | extension: std::make_unique<DummyExtension>(args: &counter1, args: 100)); |
| 179 | // Add one more extension. This should not crash. |
| 180 | int counter2 = 0; |
| 181 | registry.addExtension(extensionID: TypeID::getFromOpaquePointer(pointer: &counter2), |
| 182 | extension: std::make_unique<DummyExtension>(args: &counter2, args: 0)); |
| 183 | |
| 184 | // Load dialect and apply extensions. |
| 185 | MLIRContext context(registry); |
| 186 | Dialect *testDialect = context.getOrLoadDialect<TestDialect>(); |
| 187 | ASSERT_TRUE(testDialect != nullptr); |
| 188 | |
| 189 | // Extensions are de-duplicated by typeID. Make sure that each expected |
| 190 | // extension was applied at least once. |
| 191 | EXPECT_GE(counter1, 101); |
| 192 | EXPECT_GE(counter2, 1); |
| 193 | } |
| 194 | |
| 195 | TEST(Dialect, SubsetWithExtensions) { |
| 196 | DialectRegistry registry1, registry2; |
| 197 | registry1.insert<TestDialect>(); |
| 198 | registry2.insert<TestDialect>(); |
| 199 | |
| 200 | // Validate that the registries are equivalent. |
| 201 | ASSERT_TRUE(registry1.isSubsetOf(registry2)); |
| 202 | ASSERT_TRUE(registry2.isSubsetOf(registry1)); |
| 203 | |
| 204 | // Add extensions to registry2. |
| 205 | int counter = 0; |
| 206 | registry2.addExtension(extensionID: TypeID::get<DummyExtension>(), |
| 207 | extension: std::make_unique<DummyExtension>(args: &counter, args: 0)); |
| 208 | |
| 209 | // Expect that (1) is a subset of (2) but not the other way around. |
| 210 | ASSERT_TRUE(registry1.isSubsetOf(registry2)); |
| 211 | ASSERT_FALSE(registry2.isSubsetOf(registry1)); |
| 212 | |
| 213 | // Add extensions to registry1. |
| 214 | registry1.addExtension(extensionID: TypeID::get<DummyExtension>(), |
| 215 | extension: std::make_unique<DummyExtension>(args: &counter, args: 0)); |
| 216 | |
| 217 | // Expect that (1) and (2) are equivalent. |
| 218 | ASSERT_TRUE(registry1.isSubsetOf(registry2)); |
| 219 | ASSERT_TRUE(registry2.isSubsetOf(registry1)); |
| 220 | |
| 221 | // Load dialect and apply extensions. |
| 222 | MLIRContext context(registry1); |
| 223 | context.getOrLoadDialect<TestDialect>(); |
| 224 | context.appendDialectRegistry(registry: registry2); |
| 225 | // Expect that the extension as only invoked once. |
| 226 | ASSERT_EQ(counter, 1); |
| 227 | } |
| 228 | |
| 229 | } // namespace |
| 230 | |