1//===- InterfaceAttachmentTest.cpp - Test attaching interfaces ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements the tests for attaching interfaces to attributes, types
// and operations without having to specify them on the class directly.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/IR/BuiltinAttributes.h"
15#include "mlir/IR/BuiltinDialect.h"
16#include "mlir/IR/BuiltinOps.h"
17#include "mlir/IR/BuiltinTypes.h"
18#include "gtest/gtest.h"
19
20#include "../../test/lib/Dialect/Test/TestAttributes.h"
21#include "../../test/lib/Dialect/Test/TestDialect.h"
22#include "../../test/lib/Dialect/Test/TestTypes.h"
23#include "mlir/IR/OwningOpRef.h"
24
25using namespace mlir;
26using namespace test;
27
28namespace {
29
30/// External interface model for the integer type. Only provides non-default
31/// methods.
32struct Model
33 : public TestExternalTypeInterface::ExternalModel<Model, IntegerType> {
34 unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
35 return type.getIntOrFloatBitWidth() + arg;
36 }
37
38 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; }
39};
40
41/// External interface model for the float type. Provides non-deafult and
42/// overrides default methods.
43struct OverridingModel
44 : public TestExternalTypeInterface::ExternalModel<OverridingModel,
45 FloatType> {
46 unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
47 return type.getIntOrFloatBitWidth() + arg;
48 }
49
50 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; }
51
52 unsigned getBitwidthPlusDoubleArgument(Type type, unsigned arg) const {
53 return 128;
54 }
55
56 static unsigned staticGetArgument(unsigned arg) { return 420; }
57};
58
59TEST(InterfaceAttachment, Type) {
60 MLIRContext context;
61
62 // Check that the type has no interface.
63 IntegerType i8 = IntegerType::get(&context, 8);
64 ASSERT_FALSE(isa<TestExternalTypeInterface>(i8));
65
66 // Attach an interface and check that the type now has the interface.
67 IntegerType::attachInterface<Model>(context);
68 TestExternalTypeInterface iface = dyn_cast<TestExternalTypeInterface>(i8);
69 ASSERT_TRUE(iface != nullptr);
70 EXPECT_EQ(iface.getBitwidthPlusArg(10), 18u);
71 EXPECT_EQ(iface.staticGetSomeValuePlusArg(0), 42u);
72 EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(2), 12u);
73 EXPECT_EQ(iface.staticGetArgument(17), 17u);
74
75 // Same, but with the default implementation overridden.
76 FloatType flt = Float32Type::get(&context);
77 ASSERT_FALSE(isa<TestExternalTypeInterface>(flt));
78 Float32Type::attachInterface<OverridingModel>(context);
79 iface = dyn_cast<TestExternalTypeInterface>(flt);
80 ASSERT_TRUE(iface != nullptr);
81 EXPECT_EQ(iface.getBitwidthPlusArg(10), 42u);
82 EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 52u);
83 EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(3), 128u);
84 EXPECT_EQ(iface.staticGetArgument(17), 420u);
85
86 // Other contexts shouldn't have the attribute attached.
87 MLIRContext other;
88 IntegerType i8other = IntegerType::get(&other, 8);
89 EXPECT_FALSE(isa<TestExternalTypeInterface>(i8other));
90}
91
92/// External interface model for the test type from the test dialect.
93struct TestTypeModel
94 : public TestExternalTypeInterface::ExternalModel<TestTypeModel,
95 test::TestType> {
96 unsigned getBitwidthPlusArg(Type type, unsigned arg) const { return arg; }
97
98 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 10 + arg; }
99};
100
101TEST(InterfaceAttachment, TypeDelayedContextConstruct) {
102 // Put the interface in the registry.
103 DialectRegistry registry;
104 registry.insert<test::TestDialect>();
105 registry.addExtension(extensionFn: +[](MLIRContext *ctx, test::TestDialect *dialect) {
106 test::TestType::attachInterface<TestTypeModel>(*ctx);
107 });
108
109 // Check that when a context is constructed with the given registry, the type
110 // interface gets registered.
111 MLIRContext context(registry);
112 context.loadDialect<test::TestDialect>();
113 test::TestType testType = test::TestType::get(&context);
114 auto iface = dyn_cast<TestExternalTypeInterface>(testType);
115 ASSERT_TRUE(iface != nullptr);
116 EXPECT_EQ(iface.getBitwidthPlusArg(42), 42u);
117 EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 20u);
118}
119
120TEST(InterfaceAttachment, TypeDelayedContextAppend) {
121 // Put the interface in the registry.
122 DialectRegistry registry;
123 registry.insert<test::TestDialect>();
124 registry.addExtension(extensionFn: +[](MLIRContext *ctx, test::TestDialect *dialect) {
125 test::TestType::attachInterface<TestTypeModel>(*ctx);
126 });
127
128 // Check that when the registry gets appended to the context, the interface
129 // becomes available for objects in loaded dialects.
130 MLIRContext context;
131 context.loadDialect<test::TestDialect>();
132 test::TestType testType = test::TestType::get(&context);
133 EXPECT_FALSE(isa<TestExternalTypeInterface>(testType));
134 context.appendDialectRegistry(registry);
135 EXPECT_TRUE(isa<TestExternalTypeInterface>(testType));
136}
137
138TEST(InterfaceAttachment, RepeatedRegistration) {
139 DialectRegistry registry;
140 registry.addExtension(extensionFn: +[](MLIRContext *ctx, BuiltinDialect *dialect) {
141 IntegerType::attachInterface<Model>(*ctx);
142 });
143 MLIRContext context(registry);
144
145 // Should't fail on repeated registration through the dialect registry.
146 context.appendDialectRegistry(registry);
147}
148
149TEST(InterfaceAttachment, TypeBuiltinDelayed) {
150 // Builtin dialect needs to registration or loading, but delayed interface
151 // registration must still work.
152 DialectRegistry registry;
153 registry.addExtension(extensionFn: +[](MLIRContext *ctx, BuiltinDialect *dialect) {
154 IntegerType::attachInterface<Model>(*ctx);
155 });
156
157 MLIRContext context(registry);
158 IntegerType i16 = IntegerType::get(&context, 16);
159 EXPECT_TRUE(isa<TestExternalTypeInterface>(i16));
160
161 MLIRContext initiallyEmpty;
162 IntegerType i32 = IntegerType::get(&initiallyEmpty, 32);
163 EXPECT_FALSE(isa<TestExternalTypeInterface>(i32));
164 initiallyEmpty.appendDialectRegistry(registry);
165 EXPECT_TRUE(isa<TestExternalTypeInterface>(i32));
166}
167
168/// The interface provides a default implementation that expects
169/// ConcreteType::getWidth to exist, which is the case for IntegerType. So this
170/// just derives from the ExternalModel.
171struct TestExternalFallbackTypeIntegerModel
172 : public TestExternalFallbackTypeInterface::ExternalModel<
173 TestExternalFallbackTypeIntegerModel, IntegerType> {};
174
175/// The interface provides a default implementation that expects
176/// ConcreteType::getWidth to exist, which is *not* the case for VectorType. Use
177/// FallbackModel instead to override this and make sure the code still compiles
178/// because we never instantiate the ExternalModel class template with a
179/// template argument that would have led to compilation failures.
180struct TestExternalFallbackTypeVectorModel
181 : public TestExternalFallbackTypeInterface::FallbackModel<
182 TestExternalFallbackTypeVectorModel> {
183 unsigned getBitwidth(Type type) const {
184 IntegerType elementType =
185 dyn_cast_or_null<IntegerType>(cast<VectorType>(type).getElementType());
186 return elementType ? elementType.getWidth() : 0;
187 }
188};
189
190TEST(InterfaceAttachment, Fallback) {
191 MLIRContext context;
192
193 // Just check that we can attach the interface.
194 IntegerType i8 = IntegerType::get(&context, 8);
195 ASSERT_FALSE(isa<TestExternalFallbackTypeInterface>(i8));
196 IntegerType::attachInterface<TestExternalFallbackTypeIntegerModel>(context);
197 ASSERT_TRUE(isa<TestExternalFallbackTypeInterface>(i8));
198
199 // Call the method so it is guaranteed not to be instantiated.
200 VectorType vec = VectorType::get({42}, i8);
201 ASSERT_FALSE(isa<TestExternalFallbackTypeInterface>(vec));
202 VectorType::attachInterface<TestExternalFallbackTypeVectorModel>(context);
203 ASSERT_TRUE(isa<TestExternalFallbackTypeInterface>(vec));
204 EXPECT_EQ(cast<TestExternalFallbackTypeInterface>(vec).getBitwidth(), 8u);
205}
206
207/// External model for attribute interfaces.
208struct TestExternalIntegerAttrModel
209 : public TestExternalAttrInterface::ExternalModel<
210 TestExternalIntegerAttrModel, IntegerAttr> {
211 const Dialect *getDialectPtr(Attribute attr) const {
212 return &cast<IntegerAttr>(attr).getDialect();
213 }
214
215 static int getSomeNumber() { return 42; }
216};
217
218TEST(InterfaceAttachment, Attribute) {
219 MLIRContext context;
220
221 // Attribute interfaces use the exact same mechanism as types, so just check
222 // that the basics work for attributes.
223 IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42);
224 ASSERT_FALSE(isa<TestExternalAttrInterface>(attr));
225 IntegerAttr::attachInterface<TestExternalIntegerAttrModel>(context);
226 auto iface = dyn_cast<TestExternalAttrInterface>(attr);
227 ASSERT_TRUE(iface != nullptr);
228 EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect());
229 EXPECT_EQ(iface.getSomeNumber(), 42);
230}
231
232/// External model for an interface attachable to a non-builtin attribute.
233struct TestExternalSimpleAAttrModel
234 : public TestExternalAttrInterface::ExternalModel<
235 TestExternalSimpleAAttrModel, test::SimpleAAttr> {
236 const Dialect *getDialectPtr(Attribute attr) const {
237 return &attr.getDialect();
238 }
239
240 static int getSomeNumber() { return 21; }
241};
242
243TEST(InterfaceAttachmentTest, AttributeDelayed) {
244 // Attribute interfaces use the exact same mechanism as types, so just check
245 // that the delayed registration work for attributes.
246 DialectRegistry registry;
247 registry.insert<test::TestDialect>();
248 registry.addExtension(extensionFn: +[](MLIRContext *ctx, test::TestDialect *dialect) {
249 test::SimpleAAttr::attachInterface<TestExternalSimpleAAttrModel>(*ctx);
250 });
251
252 MLIRContext context(registry);
253 context.loadDialect<test::TestDialect>();
254 auto attr = test::SimpleAAttr::get(&context);
255 EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
256
257 MLIRContext initiallyEmpty;
258 initiallyEmpty.loadDialect<test::TestDialect>();
259 attr = test::SimpleAAttr::get(&initiallyEmpty);
260 EXPECT_FALSE(isa<TestExternalAttrInterface>(attr));
261 initiallyEmpty.appendDialectRegistry(registry);
262 EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
263}
264
265/// External interface model for the module operation. Only provides non-default
266/// methods.
267struct TestExternalOpModel
268 : public TestExternalOpInterface::ExternalModel<TestExternalOpModel,
269 ModuleOp> {
270 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
271 return op->getName().getStringRef().size() + arg;
272 }
273
274 static unsigned getNameLengthPlusArgTwice(unsigned arg) {
275 return ModuleOp::getOperationName().size() + 2 * arg;
276 }
277};
278
279/// External interface model for the func operation. Provides non-deafult and
280/// overrides default methods.
281struct TestExternalOpOverridingModel
282 : public TestExternalOpInterface::FallbackModel<
283 TestExternalOpOverridingModel> {
284 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
285 return op->getName().getStringRef().size() + arg;
286 }
287
288 static unsigned getNameLengthPlusArgTwice(unsigned arg) {
289 return UnrealizedConversionCastOp::getOperationName().size() + 2 * arg;
290 }
291
292 unsigned getNameLengthTimesArg(Operation *op, unsigned arg) const {
293 return 42;
294 }
295
296 static unsigned getNameLengthMinusArg(unsigned arg) { return 21; }
297};
298
299TEST(InterfaceAttachment, Operation) {
300 MLIRContext context;
301 OpBuilder builder(&context);
302
303 // Initially, the operation doesn't have the interface.
304 OwningOpRef<ModuleOp> moduleOp =
305 builder.create<ModuleOp>(UnknownLoc::get(&context));
306 ASSERT_FALSE(isa<TestExternalOpInterface>(moduleOp->getOperation()));
307
308 // We can attach an external interface and now the operaiton has it.
309 ModuleOp::attachInterface<TestExternalOpModel>(context);
310 auto iface = dyn_cast<TestExternalOpInterface>(moduleOp->getOperation());
311 ASSERT_TRUE(iface != nullptr);
312 EXPECT_EQ(iface.getNameLengthPlusArg(10), 24u);
313 EXPECT_EQ(iface.getNameLengthTimesArg(3), 42u);
314 EXPECT_EQ(iface.getNameLengthPlusArgTwice(18), 50u);
315 EXPECT_EQ(iface.getNameLengthMinusArg(5), 9u);
316
317 // Default implementation can be overridden.
318 OwningOpRef<UnrealizedConversionCastOp> castOp =
319 builder.create<UnrealizedConversionCastOp>(UnknownLoc::get(&context),
320 TypeRange(), ValueRange());
321 ASSERT_FALSE(isa<TestExternalOpInterface>(castOp->getOperation()));
322 UnrealizedConversionCastOp::attachInterface<TestExternalOpOverridingModel>(
323 context);
324 iface = dyn_cast<TestExternalOpInterface>(castOp->getOperation());
325 ASSERT_TRUE(iface != nullptr);
326 EXPECT_EQ(iface.getNameLengthPlusArg(10), 44u);
327 EXPECT_EQ(iface.getNameLengthTimesArg(0), 42u);
328 EXPECT_EQ(iface.getNameLengthPlusArgTwice(8), 50u);
329 EXPECT_EQ(iface.getNameLengthMinusArg(1000), 21u);
330
331 // Another context doesn't have the interfaces registered.
332 MLIRContext other;
333 OwningOpRef<ModuleOp> otherModuleOp =
334 ModuleOp::create(UnknownLoc::get(&other));
335 ASSERT_FALSE(isa<TestExternalOpInterface>(otherModuleOp->getOperation()));
336}
337
338template <class ConcreteOp>
339struct TestExternalTestOpModel
340 : public TestExternalOpInterface::ExternalModel<
341 TestExternalTestOpModel<ConcreteOp>, ConcreteOp> {
342 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
343 return op->getName().getStringRef().size() + arg;
344 }
345
346 static unsigned getNameLengthPlusArgTwice(unsigned arg) {
347 return ConcreteOp::getOperationName().size() + 2 * arg;
348 }
349};
350
351TEST(InterfaceAttachment, OperationDelayedContextConstruct) {
352 DialectRegistry registry;
353 registry.insert<test::TestDialect>();
354 registry.addExtension(extensionFn: +[](MLIRContext *ctx, BuiltinDialect *dialect) {
355 ModuleOp::attachInterface<TestExternalOpModel>(*ctx);
356 });
357 registry.addExtension(extensionFn: +[](MLIRContext *ctx, test::TestDialect *dialect) {
358 test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx);
359 test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx);
360 });
361
362 // Construct the context directly from a registry. The interfaces are
363 // expected to be readily available on operations.
364 MLIRContext context(registry);
365 context.loadDialect<test::TestDialect>();
366
367 OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));
368 OpBuilder builder(module->getBody(), module->getBody()->begin());
369 auto opJ =
370 builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
371 auto opH =
372 builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult());
373 auto opI =
374 builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult());
375
376 EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation()));
377 EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation()));
378 EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation()));
379 EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
380}
381
382TEST(InterfaceAttachment, OperationDelayedContextAppend) {
383 DialectRegistry registry;
384 registry.insert<test::TestDialect>();
385 registry.addExtension(extensionFn: +[](MLIRContext *ctx, BuiltinDialect *dialect) {
386 ModuleOp::attachInterface<TestExternalOpModel>(*ctx);
387 });
388 registry.addExtension(extensionFn: +[](MLIRContext *ctx, test::TestDialect *dialect) {
389 test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx);
390 test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx);
391 });
392
393 // Construct the context, create ops, and only then append the registry. The
394 // interfaces are expected to be available after appending the registry.
395 MLIRContext context;
396 context.loadDialect<test::TestDialect>();
397
398 OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));
399 OpBuilder builder(module->getBody(), module->getBody()->begin());
400 auto opJ =
401 builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
402 auto opH =
403 builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult());
404 auto opI =
405 builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult());
406
407 EXPECT_FALSE(isa<TestExternalOpInterface>(module->getOperation()));
408 EXPECT_FALSE(isa<TestExternalOpInterface>(opJ.getOperation()));
409 EXPECT_FALSE(isa<TestExternalOpInterface>(opH.getOperation()));
410 EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
411
412 context.appendDialectRegistry(registry);
413
414 EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation()));
415 EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation()));
416 EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation()));
417 EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
418}
419
420TEST(InterfaceAttachmentTest, PromisedInterfaces) {
421 // Attribute interfaces use the exact same mechanism as types, so just check
422 // that the promise mechanism works for attributes.
423 MLIRContext context;
424 auto testDialect = context.getOrLoadDialect<test::TestDialect>();
425 auto attr = test::SimpleAAttr::get(&context);
426
427 // `SimpleAAttr` doesn't implement nor promises the
428 // `TestExternalAttrInterface` interface.
429 EXPECT_FALSE(isa<TestExternalAttrInterface>(attr));
430 EXPECT_FALSE(
431 attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
432
433 // Add a promise `TestExternalAttrInterface`.
434 testDialect->declarePromisedInterface<test::SimpleAAttr,
435 TestExternalAttrInterface>();
436 EXPECT_TRUE(
437 attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
438
439 // Attach the interface.
440 test::SimpleAAttr::attachInterface<TestExternalAttrInterface>(context);
441 EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
442 EXPECT_TRUE(
443 attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
444}
445
446} // namespace
447

source code of mlir/unittests/IR/InterfaceAttachmentTest.cpp