1//===- DialectTest.cpp - Dialect unit tests -------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/Dialect.h"
10#include "mlir/IR/DialectInterface.h"
11#include "gtest/gtest.h"
12
13using namespace mlir;
14using namespace mlir::detail;
15
16namespace {
17struct TestDialect : public Dialect {
18 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialect)
19
20 static StringRef getDialectNamespace() { return "test"; };
21 TestDialect(MLIRContext *context)
22 : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {}
23};
24struct AnotherTestDialect : public Dialect {
25 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnotherTestDialect)
26
27 static StringRef getDialectNamespace() { return "test"; };
28 AnotherTestDialect(MLIRContext *context)
29 : Dialect(getDialectNamespace(), context,
30 TypeID::get<AnotherTestDialect>()) {}
31};
32
33TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) {
34 MLIRContext context;
35
36 // Registering a dialect with the same namespace twice should result in a
37 // failure.
38 context.loadDialect<TestDialect>();
39 ASSERT_DEATH(context.loadDialect<AnotherTestDialect>(), "");
40}
41
42struct SecondTestDialect : public Dialect {
43 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialect)
44
45 static StringRef getDialectNamespace() { return "test2"; }
46 SecondTestDialect(MLIRContext *context)
47 : Dialect(getDialectNamespace(), context,
48 TypeID::get<SecondTestDialect>()) {}
49};
50
51struct TestDialectInterfaceBase
52 : public DialectInterface::Base<TestDialectInterfaceBase> {
53 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterfaceBase)
54
55 TestDialectInterfaceBase(Dialect *dialect) : Base(dialect) {}
56 virtual int function() const { return 42; }
57};
58
59struct TestDialectInterface : public TestDialectInterfaceBase {
60 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterface)
61
62 using TestDialectInterfaceBase::TestDialectInterfaceBase;
63 int function() const final { return 56; }
64};
65
66struct SecondTestDialectInterface : public TestDialectInterfaceBase {
67 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialectInterface)
68
69 using TestDialectInterfaceBase::TestDialectInterfaceBase;
70 int function() const final { return 78; }
71};
72
73TEST(Dialect, DelayedInterfaceRegistration) {
74 DialectRegistry registry;
75 registry.insert<TestDialect, SecondTestDialect>();
76
77 // Delayed registration of an interface for TestDialect.
78 registry.addExtension(extensionFn: +[](MLIRContext *ctx, TestDialect *dialect) {
79 dialect->addInterfaces<TestDialectInterface>();
80 });
81
82 MLIRContext context(registry);
83
84 // Load the TestDialect and check that the interface got registered for it.
85 Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
86 ASSERT_TRUE(testDialect != nullptr);
87 auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(Val: testDialect);
88 EXPECT_TRUE(testDialectInterface != nullptr);
89
90 // Load the SecondTestDialect and check that the interface is not registered
91 // for it.
92 Dialect *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
93 ASSERT_TRUE(secondTestDialect != nullptr);
94 auto *secondTestDialectInterface =
95 dyn_cast<SecondTestDialectInterface>(Val: secondTestDialect);
96 EXPECT_TRUE(secondTestDialectInterface == nullptr);
97
98 // Use the same mechanism as for delayed registration but for an already
99 // loaded dialect and check that the interface is now registered.
100 DialectRegistry secondRegistry;
101 secondRegistry.insert<SecondTestDialect>();
102 secondRegistry.addExtension(
103 extensionFn: +[](MLIRContext *ctx, SecondTestDialect *dialect) {
104 dialect->addInterfaces<SecondTestDialectInterface>();
105 });
106 context.appendDialectRegistry(registry: secondRegistry);
107 secondTestDialectInterface =
108 dyn_cast<SecondTestDialectInterface>(Val: secondTestDialect);
109 EXPECT_TRUE(secondTestDialectInterface != nullptr);
110}
111
112TEST(Dialect, RepeatedDelayedRegistration) {
113 // Set up the delayed registration.
114 DialectRegistry registry;
115 registry.insert<TestDialect>();
116 registry.addExtension(extensionFn: +[](MLIRContext *ctx, TestDialect *dialect) {
117 dialect->addInterfaces<TestDialectInterface>();
118 });
119 MLIRContext context(registry);
120
121 // Load the TestDialect and check that the interface got registered for it.
122 Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
123 ASSERT_TRUE(testDialect != nullptr);
124 auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(Val: testDialect);
125 EXPECT_TRUE(testDialectInterface != nullptr);
126
127 // Try adding the same dialect interface again and check that we don't crash
128 // on repeated interface registration.
129 DialectRegistry secondRegistry;
130 secondRegistry.insert<TestDialect>();
131 secondRegistry.addExtension(extensionFn: +[](MLIRContext *ctx, TestDialect *dialect) {
132 dialect->addInterfaces<TestDialectInterface>();
133 });
134 context.appendDialectRegistry(registry: secondRegistry);
135 testDialectInterface = dyn_cast<TestDialectInterfaceBase>(Val: testDialect);
136 EXPECT_TRUE(testDialectInterface != nullptr);
137}
138
139namespace {
140/// A dummy extension that increases a counter when being applied and
141/// recursively adds additional extensions.
142struct DummyExtension : DialectExtension<DummyExtension, TestDialect> {
143 DummyExtension(int *counter, int numRecursive)
144 : DialectExtension(), counter(counter), numRecursive(numRecursive) {}
145
146 void apply(MLIRContext *ctx, TestDialect *dialect) const final {
147 ++(*counter);
148 DialectRegistry nestedRegistry;
149 for (int i = 0; i < numRecursive; ++i)
150 nestedRegistry.addExtension(
151 extension: std::make_unique<DummyExtension>(args: counter, /*numRecursive=*/args: 0));
152 // Adding additional extensions may trigger a reallocation of the
153 // `extensions` vector in the dialect registry.
154 ctx->appendDialectRegistry(registry: nestedRegistry);
155 }
156
157private:
158 int *counter;
159 int numRecursive;
160};
161} // namespace
162
163TEST(Dialect, NestedDialectExtension) {
164 DialectRegistry registry;
165 registry.insert<TestDialect>();
166
167 // Add an extension that adds 100 more extensions.
168 int counter1 = 0;
169 registry.addExtension(extension: std::make_unique<DummyExtension>(args: &counter1, args: 100));
170 // Add one more extension. This should not crash.
171 int counter2 = 0;
172 registry.addExtension(extension: std::make_unique<DummyExtension>(args: &counter2, args: 0));
173
174 // Load dialect and apply extensions.
175 MLIRContext context(registry);
176 Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
177 ASSERT_TRUE(testDialect != nullptr);
178
179 // Extensions may be applied multiple times. Make sure that each expected
180 // extension was applied at least once.
181 EXPECT_GE(counter1, 101);
182 EXPECT_GE(counter2, 1);
183}
184
185} // namespace
186

source code of mlir/unittests/IR/DialectTest.cpp