1//===- InterfaceAttachmentTest.cpp - Test attaching interfaces ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This implements the tests for attaching interfaces to attributes and types
10// without having to specify them on the attribute or type class directly.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/IR/BuiltinAttributes.h"
15#include "mlir/IR/BuiltinDialect.h"
16#include "mlir/IR/BuiltinOps.h"
17#include "mlir/IR/BuiltinTypes.h"
18#include "gtest/gtest.h"
19
20#include "../../test/lib/Dialect/Test/TestAttributes.h"
21#include "../../test/lib/Dialect/Test/TestDialect.h"
22#include "../../test/lib/Dialect/Test/TestOps.h"
23#include "../../test/lib/Dialect/Test/TestTypes.h"
24#include "mlir/IR/OwningOpRef.h"
25
26using namespace mlir;
27using namespace test;
28
29namespace {
30
31/// External interface model for the integer type. Only provides non-default
32/// methods.
33struct Model
34 : public TestExternalTypeInterface::ExternalModel<Model, IntegerType> {
35 unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
36 return type.getIntOrFloatBitWidth() + arg;
37 }
38
39 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; }
40};
41
42/// External interface model for the float type. Provides non-deafult and
43/// overrides default methods.
44struct OverridingModel
45 : public TestExternalTypeInterface::ExternalModel<OverridingModel,
46 FloatType> {
47 unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
48 return type.getIntOrFloatBitWidth() + arg;
49 }
50
51 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; }
52
53 unsigned getBitwidthPlusDoubleArgument(Type type, unsigned arg) const {
54 return 128;
55 }
56
57 static unsigned staticGetArgument(unsigned arg) { return 420; }
58};
59
60TEST(InterfaceAttachment, Type) {
61 MLIRContext context;
62
63 // Check that the type has no interface.
64 IntegerType i8 = IntegerType::get(&context, 8);
65 ASSERT_FALSE(isa<TestExternalTypeInterface>(i8));
66
67 // Attach an interface and check that the type now has the interface.
68 IntegerType::attachInterface<Model>(context);
69 TestExternalTypeInterface iface = dyn_cast<TestExternalTypeInterface>(i8);
70 ASSERT_TRUE(iface != nullptr);
71 EXPECT_EQ(iface.getBitwidthPlusArg(10), 18u);
72 EXPECT_EQ(iface.staticGetSomeValuePlusArg(0), 42u);
73 EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(2), 12u);
74 EXPECT_EQ(iface.staticGetArgument(17), 17u);
75
76 // Same, but with the default implementation overridden.
77 FloatType flt = Float32Type::get(&context);
78 ASSERT_FALSE(isa<TestExternalTypeInterface>(flt));
79 Float32Type::attachInterface<OverridingModel>(context);
80 iface = dyn_cast<TestExternalTypeInterface>(flt);
81 ASSERT_TRUE(iface != nullptr);
82 EXPECT_EQ(iface.getBitwidthPlusArg(10), 42u);
83 EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 52u);
84 EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(3), 128u);
85 EXPECT_EQ(iface.staticGetArgument(17), 420u);
86
87 // Other contexts shouldn't have the attribute attached.
88 MLIRContext other;
89 IntegerType i8other = IntegerType::get(&other, 8);
90 EXPECT_FALSE(isa<TestExternalTypeInterface>(i8other));
91}
92
93/// External interface model for the test type from the test dialect.
94struct TestTypeModel
95 : public TestExternalTypeInterface::ExternalModel<TestTypeModel,
96 test::TestType> {
97 unsigned getBitwidthPlusArg(Type type, unsigned arg) const { return arg; }
98
99 static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 10 + arg; }
100};
101
102TEST(InterfaceAttachment, TypeDelayedContextConstruct) {
103 // Put the interface in the registry.
104 DialectRegistry registry;
105 registry.insert<test::TestDialect>();
106 registry.addExtension(extensionFn: +[](MLIRContext *ctx, test::TestDialect *dialect) {
107 test::TestType::attachInterface<TestTypeModel>(*ctx);
108 });
109
110 // Check that when a context is constructed with the given registry, the type
111 // interface gets registered.
112 MLIRContext context(registry);
113 context.loadDialect<test::TestDialect>();
114 test::TestType testType = test::TestType::get(&context);
115 auto iface = dyn_cast<TestExternalTypeInterface>(testType);
116 ASSERT_TRUE(iface != nullptr);
117 EXPECT_EQ(iface.getBitwidthPlusArg(42), 42u);
118 EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 20u);
119}
120
121TEST(InterfaceAttachment, TypeDelayedContextAppend) {
122 // Put the interface in the registry.
123 DialectRegistry registry;
124 registry.insert<test::TestDialect>();
125 registry.addExtension(extensionFn: +[](MLIRContext *ctx, test::TestDialect *dialect) {
126 test::TestType::attachInterface<TestTypeModel>(*ctx);
127 });
128
129 // Check that when the registry gets appended to the context, the interface
130 // becomes available for objects in loaded dialects.
131 MLIRContext context;
132 context.loadDialect<test::TestDialect>();
133 test::TestType testType = test::TestType::get(&context);
134 EXPECT_FALSE(isa<TestExternalTypeInterface>(testType));
135 context.appendDialectRegistry(registry);
136 EXPECT_TRUE(isa<TestExternalTypeInterface>(testType));
137}
138
139TEST(InterfaceAttachment, RepeatedRegistration) {
140 DialectRegistry registry;
141 registry.addExtension(extensionFn: +[](MLIRContext *ctx, BuiltinDialect *dialect) {
142 IntegerType::attachInterface<Model>(*ctx);
143 });
144 MLIRContext context(registry);
145
146 // Should't fail on repeated registration through the dialect registry.
147 context.appendDialectRegistry(registry);
148}
149
150TEST(InterfaceAttachment, TypeBuiltinDelayed) {
151 // Builtin dialect needs to registration or loading, but delayed interface
152 // registration must still work.
153 DialectRegistry registry;
154 registry.addExtension(extensionFn: +[](MLIRContext *ctx, BuiltinDialect *dialect) {
155 IntegerType::attachInterface<Model>(*ctx);
156 });
157
158 MLIRContext context(registry);
159 IntegerType i16 = IntegerType::get(&context, 16);
160 EXPECT_TRUE(isa<TestExternalTypeInterface>(i16));
161
162 MLIRContext initiallyEmpty;
163 IntegerType i32 = IntegerType::get(&initiallyEmpty, 32);
164 EXPECT_FALSE(isa<TestExternalTypeInterface>(i32));
165 initiallyEmpty.appendDialectRegistry(registry);
166 EXPECT_TRUE(isa<TestExternalTypeInterface>(i32));
167}
168
169/// The interface provides a default implementation that expects
170/// ConcreteType::getWidth to exist, which is the case for IntegerType. So this
171/// just derives from the ExternalModel.
172struct TestExternalFallbackTypeIntegerModel
173 : public TestExternalFallbackTypeInterface::ExternalModel<
174 TestExternalFallbackTypeIntegerModel, IntegerType> {};
175
176/// The interface provides a default implementation that expects
177/// ConcreteType::getWidth to exist, which is *not* the case for VectorType. Use
178/// FallbackModel instead to override this and make sure the code still compiles
179/// because we never instantiate the ExternalModel class template with a
180/// template argument that would have led to compilation failures.
181struct TestExternalFallbackTypeVectorModel
182 : public TestExternalFallbackTypeInterface::FallbackModel<
183 TestExternalFallbackTypeVectorModel> {
184 unsigned getBitwidth(Type type) const {
185 IntegerType elementType =
186 dyn_cast_or_null<IntegerType>(cast<VectorType>(type).getElementType());
187 return elementType ? elementType.getWidth() : 0;
188 }
189};
190
191TEST(InterfaceAttachment, Fallback) {
192 MLIRContext context;
193
194 // Just check that we can attach the interface.
195 IntegerType i8 = IntegerType::get(&context, 8);
196 ASSERT_FALSE(isa<TestExternalFallbackTypeInterface>(i8));
197 IntegerType::attachInterface<TestExternalFallbackTypeIntegerModel>(context);
198 ASSERT_TRUE(isa<TestExternalFallbackTypeInterface>(i8));
199
200 // Call the method so it is guaranteed not to be instantiated.
201 VectorType vec = VectorType::get({42}, i8);
202 ASSERT_FALSE(isa<TestExternalFallbackTypeInterface>(vec));
203 VectorType::attachInterface<TestExternalFallbackTypeVectorModel>(context);
204 ASSERT_TRUE(isa<TestExternalFallbackTypeInterface>(vec));
205 EXPECT_EQ(cast<TestExternalFallbackTypeInterface>(vec).getBitwidth(), 8u);
206}
207
208/// External model for attribute interfaces.
209struct TestExternalIntegerAttrModel
210 : public TestExternalAttrInterface::ExternalModel<
211 TestExternalIntegerAttrModel, IntegerAttr> {
212 const Dialect *getDialectPtr(Attribute attr) const {
213 return &cast<IntegerAttr>(attr).getDialect();
214 }
215
216 static int getSomeNumber() { return 42; }
217};
218
219TEST(InterfaceAttachment, Attribute) {
220 MLIRContext context;
221
222 // Attribute interfaces use the exact same mechanism as types, so just check
223 // that the basics work for attributes.
224 IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42);
225 ASSERT_FALSE(isa<TestExternalAttrInterface>(attr));
226 IntegerAttr::attachInterface<TestExternalIntegerAttrModel>(context);
227 auto iface = dyn_cast<TestExternalAttrInterface>(attr);
228 ASSERT_TRUE(iface != nullptr);
229 EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect());
230 EXPECT_EQ(iface.getSomeNumber(), 42);
231}
232
233/// External model for an interface attachable to a non-builtin attribute.
234struct TestExternalSimpleAAttrModel
235 : public TestExternalAttrInterface::ExternalModel<
236 TestExternalSimpleAAttrModel, test::SimpleAAttr> {
237 const Dialect *getDialectPtr(Attribute attr) const {
238 return &attr.getDialect();
239 }
240
241 static int getSomeNumber() { return 21; }
242};
243
244TEST(InterfaceAttachmentTest, AttributeDelayed) {
245 // Attribute interfaces use the exact same mechanism as types, so just check
246 // that the delayed registration work for attributes.
247 DialectRegistry registry;
248 registry.insert<test::TestDialect>();
249 registry.addExtension(extensionFn: +[](MLIRContext *ctx, test::TestDialect *dialect) {
250 test::SimpleAAttr::attachInterface<TestExternalSimpleAAttrModel>(*ctx);
251 });
252
253 MLIRContext context(registry);
254 context.loadDialect<test::TestDialect>();
255 auto attr = test::SimpleAAttr::get(&context);
256 EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
257
258 MLIRContext initiallyEmpty;
259 initiallyEmpty.loadDialect<test::TestDialect>();
260 attr = test::SimpleAAttr::get(&initiallyEmpty);
261 EXPECT_FALSE(isa<TestExternalAttrInterface>(attr));
262 initiallyEmpty.appendDialectRegistry(registry);
263 EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
264}
265
266/// External interface model for the module operation. Only provides non-default
267/// methods.
268struct TestExternalOpModel
269 : public TestExternalOpInterface::ExternalModel<TestExternalOpModel,
270 ModuleOp> {
271 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
272 return op->getName().getStringRef().size() + arg;
273 }
274
275 static unsigned getNameLengthPlusArgTwice(unsigned arg) {
276 return ModuleOp::getOperationName().size() + 2 * arg;
277 }
278};
279
280/// External interface model for the func operation. Provides non-deafult and
281/// overrides default methods.
282struct TestExternalOpOverridingModel
283 : public TestExternalOpInterface::FallbackModel<
284 TestExternalOpOverridingModel> {
285 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
286 return op->getName().getStringRef().size() + arg;
287 }
288
289 static unsigned getNameLengthPlusArgTwice(unsigned arg) {
290 return UnrealizedConversionCastOp::getOperationName().size() + 2 * arg;
291 }
292
293 unsigned getNameLengthTimesArg(Operation *op, unsigned arg) const {
294 return 42;
295 }
296
297 static unsigned getNameLengthMinusArg(unsigned arg) { return 21; }
298};
299
300TEST(InterfaceAttachment, Operation) {
301 MLIRContext context;
302 OpBuilder builder(&context);
303
304 // Initially, the operation doesn't have the interface.
305 OwningOpRef<ModuleOp> moduleOp =
306 builder.create<ModuleOp>(UnknownLoc::get(&context));
307 ASSERT_FALSE(isa<TestExternalOpInterface>(moduleOp->getOperation()));
308
309 // We can attach an external interface and now the operaiton has it.
310 ModuleOp::attachInterface<TestExternalOpModel>(context);
311 auto iface = dyn_cast<TestExternalOpInterface>(moduleOp->getOperation());
312 ASSERT_TRUE(iface != nullptr);
313 EXPECT_EQ(iface.getNameLengthPlusArg(10), 24u);
314 EXPECT_EQ(iface.getNameLengthTimesArg(3), 42u);
315 EXPECT_EQ(iface.getNameLengthPlusArgTwice(18), 50u);
316 EXPECT_EQ(iface.getNameLengthMinusArg(5), 9u);
317
318 // Default implementation can be overridden.
319 OwningOpRef<UnrealizedConversionCastOp> castOp =
320 builder.create<UnrealizedConversionCastOp>(UnknownLoc::get(&context),
321 TypeRange(), ValueRange());
322 ASSERT_FALSE(isa<TestExternalOpInterface>(castOp->getOperation()));
323 UnrealizedConversionCastOp::attachInterface<TestExternalOpOverridingModel>(
324 context);
325 iface = dyn_cast<TestExternalOpInterface>(castOp->getOperation());
326 ASSERT_TRUE(iface != nullptr);
327 EXPECT_EQ(iface.getNameLengthPlusArg(10), 44u);
328 EXPECT_EQ(iface.getNameLengthTimesArg(0), 42u);
329 EXPECT_EQ(iface.getNameLengthPlusArgTwice(8), 50u);
330 EXPECT_EQ(iface.getNameLengthMinusArg(1000), 21u);
331
332 // Another context doesn't have the interfaces registered.
333 MLIRContext other;
334 OwningOpRef<ModuleOp> otherModuleOp =
335 ModuleOp::create(UnknownLoc::get(&other));
336 ASSERT_FALSE(isa<TestExternalOpInterface>(otherModuleOp->getOperation()));
337}
338
339template <class ConcreteOp>
340struct TestExternalTestOpModel
341 : public TestExternalOpInterface::ExternalModel<
342 TestExternalTestOpModel<ConcreteOp>, ConcreteOp> {
343 unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
344 return op->getName().getStringRef().size() + arg;
345 }
346
347 static unsigned getNameLengthPlusArgTwice(unsigned arg) {
348 return ConcreteOp::getOperationName().size() + 2 * arg;
349 }
350};
351
352TEST(InterfaceAttachment, OperationDelayedContextConstruct) {
353 DialectRegistry registry;
354 registry.insert<test::TestDialect>();
355 registry.addExtension(extensionFn: +[](MLIRContext *ctx, BuiltinDialect *dialect) {
356 ModuleOp::attachInterface<TestExternalOpModel>(*ctx);
357 });
358 registry.addExtension(extensionFn: +[](MLIRContext *ctx, test::TestDialect *dialect) {
359 test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx);
360 test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx);
361 });
362
363 // Construct the context directly from a registry. The interfaces are
364 // expected to be readily available on operations.
365 MLIRContext context(registry);
366 context.loadDialect<test::TestDialect>();
367
368 OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));
369 OpBuilder builder(module->getBody(), module->getBody()->begin());
370 auto opJ =
371 builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
372 auto opH =
373 builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult());
374 auto opI =
375 builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult());
376
377 EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation()));
378 EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation()));
379 EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation()));
380 EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
381}
382
383TEST(InterfaceAttachment, OperationDelayedContextAppend) {
384 DialectRegistry registry;
385 registry.insert<test::TestDialect>();
386 registry.addExtension(extensionFn: +[](MLIRContext *ctx, BuiltinDialect *dialect) {
387 ModuleOp::attachInterface<TestExternalOpModel>(*ctx);
388 });
389 registry.addExtension(extensionFn: +[](MLIRContext *ctx, test::TestDialect *dialect) {
390 test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx);
391 test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx);
392 });
393
394 // Construct the context, create ops, and only then append the registry. The
395 // interfaces are expected to be available after appending the registry.
396 MLIRContext context;
397 context.loadDialect<test::TestDialect>();
398
399 OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context));
400 OpBuilder builder(module->getBody(), module->getBody()->begin());
401 auto opJ =
402 builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
403 auto opH =
404 builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult());
405 auto opI =
406 builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult());
407
408 EXPECT_FALSE(isa<TestExternalOpInterface>(module->getOperation()));
409 EXPECT_FALSE(isa<TestExternalOpInterface>(opJ.getOperation()));
410 EXPECT_FALSE(isa<TestExternalOpInterface>(opH.getOperation()));
411 EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
412
413 context.appendDialectRegistry(registry);
414
415 EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation()));
416 EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation()));
417 EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation()));
418 EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
419}
420
421TEST(InterfaceAttachmentTest, PromisedInterfaces) {
422 // Attribute interfaces use the exact same mechanism as types, so just check
423 // that the promise mechanism works for attributes.
424 MLIRContext context;
425 auto *testDialect = context.getOrLoadDialect<test::TestDialect>();
426 auto attr = test::SimpleAAttr::get(&context);
427
428 // `SimpleAAttr` doesn't implement nor promises the
429 // `TestExternalAttrInterface` interface.
430 EXPECT_FALSE(isa<TestExternalAttrInterface>(attr));
431 EXPECT_FALSE(
432 attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
433
434 // Add a promise `TestExternalAttrInterface`.
435 testDialect->declarePromisedInterface<TestExternalAttrInterface,
436 test::SimpleAAttr>();
437 EXPECT_TRUE(
438 attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
439
440 // Attach the interface.
441 test::SimpleAAttr::attachInterface<TestExternalAttrInterface>(context);
442 EXPECT_TRUE(isa<TestExternalAttrInterface>(attr));
443 EXPECT_TRUE(
444 attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>());
445}
446
447} // namespace
448

// source code of mlir/unittests/IR/InterfaceAttachmentTest.cpp