1 | //===- InterfaceAttachmentTest.cpp - Test attaching interfaces ------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This implements the tests for attaching interfaces to attributes, types and
// operations without having to specify them on the class directly.
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/IR/BuiltinAttributes.h" |
15 | #include "mlir/IR/BuiltinDialect.h" |
16 | #include "mlir/IR/BuiltinOps.h" |
17 | #include "mlir/IR/BuiltinTypes.h" |
18 | #include "gtest/gtest.h" |
19 | |
20 | #include "../../test/lib/Dialect/Test/TestAttributes.h" |
21 | #include "../../test/lib/Dialect/Test/TestDialect.h" |
22 | #include "../../test/lib/Dialect/Test/TestOps.h" |
23 | #include "../../test/lib/Dialect/Test/TestTypes.h" |
24 | #include "mlir/IR/OwningOpRef.h" |
25 | |
26 | using namespace mlir; |
27 | using namespace test; |
28 | |
29 | namespace { |
30 | |
31 | /// External interface model for the integer type. Only provides non-default |
32 | /// methods. |
33 | struct Model |
34 | : public TestExternalTypeInterface::ExternalModel<Model, IntegerType> { |
35 | unsigned getBitwidthPlusArg(Type type, unsigned arg) const { |
36 | return type.getIntOrFloatBitWidth() + arg; |
37 | } |
38 | |
39 | static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; } |
40 | }; |
41 | |
42 | /// External interface model for the float type. Provides non-deafult and |
43 | /// overrides default methods. |
44 | struct OverridingModel |
45 | : public TestExternalTypeInterface::ExternalModel<OverridingModel, |
46 | FloatType> { |
47 | unsigned getBitwidthPlusArg(Type type, unsigned arg) const { |
48 | return type.getIntOrFloatBitWidth() + arg; |
49 | } |
50 | |
51 | static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; } |
52 | |
53 | unsigned getBitwidthPlusDoubleArgument(Type type, unsigned arg) const { |
54 | return 128; |
55 | } |
56 | |
57 | static unsigned staticGetArgument(unsigned arg) { return 420; } |
58 | }; |
59 | |
60 | TEST(InterfaceAttachment, Type) { |
61 | MLIRContext context; |
62 | |
63 | // Check that the type has no interface. |
64 | IntegerType i8 = IntegerType::get(&context, 8); |
65 | ASSERT_FALSE(isa<TestExternalTypeInterface>(i8)); |
66 | |
67 | // Attach an interface and check that the type now has the interface. |
68 | IntegerType::attachInterface<Model>(context); |
69 | TestExternalTypeInterface iface = dyn_cast<TestExternalTypeInterface>(i8); |
70 | ASSERT_TRUE(iface != nullptr); |
71 | EXPECT_EQ(iface.getBitwidthPlusArg(10), 18u); |
72 | EXPECT_EQ(iface.staticGetSomeValuePlusArg(0), 42u); |
73 | EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(2), 12u); |
74 | EXPECT_EQ(iface.staticGetArgument(17), 17u); |
75 | |
76 | // Same, but with the default implementation overridden. |
77 | FloatType flt = Float32Type::get(&context); |
78 | ASSERT_FALSE(isa<TestExternalTypeInterface>(flt)); |
79 | Float32Type::attachInterface<OverridingModel>(context); |
80 | iface = dyn_cast<TestExternalTypeInterface>(flt); |
81 | ASSERT_TRUE(iface != nullptr); |
82 | EXPECT_EQ(iface.getBitwidthPlusArg(10), 42u); |
83 | EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 52u); |
84 | EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(3), 128u); |
85 | EXPECT_EQ(iface.staticGetArgument(17), 420u); |
86 | |
87 | // Other contexts shouldn't have the attribute attached. |
88 | MLIRContext other; |
89 | IntegerType i8other = IntegerType::get(&other, 8); |
90 | EXPECT_FALSE(isa<TestExternalTypeInterface>(i8other)); |
91 | } |
92 | |
93 | /// External interface model for the test type from the test dialect. |
94 | struct TestTypeModel |
95 | : public TestExternalTypeInterface::ExternalModel<TestTypeModel, |
96 | test::TestType> { |
97 | unsigned getBitwidthPlusArg(Type type, unsigned arg) const { return arg; } |
98 | |
99 | static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 10 + arg; } |
100 | }; |
101 | |
102 | TEST(InterfaceAttachment, TypeDelayedContextConstruct) { |
103 | // Put the interface in the registry. |
104 | DialectRegistry registry; |
105 | registry.insert<test::TestDialect>(); |
106 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, test::TestDialect *dialect) { |
107 | test::TestType::attachInterface<TestTypeModel>(*ctx); |
108 | }); |
109 | |
110 | // Check that when a context is constructed with the given registry, the type |
111 | // interface gets registered. |
112 | MLIRContext context(registry); |
113 | context.loadDialect<test::TestDialect>(); |
114 | test::TestType testType = test::TestType::get(&context); |
115 | auto iface = dyn_cast<TestExternalTypeInterface>(testType); |
116 | ASSERT_TRUE(iface != nullptr); |
117 | EXPECT_EQ(iface.getBitwidthPlusArg(42), 42u); |
118 | EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 20u); |
119 | } |
120 | |
121 | TEST(InterfaceAttachment, TypeDelayedContextAppend) { |
122 | // Put the interface in the registry. |
123 | DialectRegistry registry; |
124 | registry.insert<test::TestDialect>(); |
125 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, test::TestDialect *dialect) { |
126 | test::TestType::attachInterface<TestTypeModel>(*ctx); |
127 | }); |
128 | |
129 | // Check that when the registry gets appended to the context, the interface |
130 | // becomes available for objects in loaded dialects. |
131 | MLIRContext context; |
132 | context.loadDialect<test::TestDialect>(); |
133 | test::TestType testType = test::TestType::get(&context); |
134 | EXPECT_FALSE(isa<TestExternalTypeInterface>(testType)); |
135 | context.appendDialectRegistry(registry); |
136 | EXPECT_TRUE(isa<TestExternalTypeInterface>(testType)); |
137 | } |
138 | |
139 | TEST(InterfaceAttachment, RepeatedRegistration) { |
140 | DialectRegistry registry; |
141 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, BuiltinDialect *dialect) { |
142 | IntegerType::attachInterface<Model>(*ctx); |
143 | }); |
144 | MLIRContext context(registry); |
145 | |
146 | // Should't fail on repeated registration through the dialect registry. |
147 | context.appendDialectRegistry(registry); |
148 | } |
149 | |
150 | TEST(InterfaceAttachment, TypeBuiltinDelayed) { |
151 | // Builtin dialect needs to registration or loading, but delayed interface |
152 | // registration must still work. |
153 | DialectRegistry registry; |
154 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, BuiltinDialect *dialect) { |
155 | IntegerType::attachInterface<Model>(*ctx); |
156 | }); |
157 | |
158 | MLIRContext context(registry); |
159 | IntegerType i16 = IntegerType::get(&context, 16); |
160 | EXPECT_TRUE(isa<TestExternalTypeInterface>(i16)); |
161 | |
162 | MLIRContext initiallyEmpty; |
163 | IntegerType i32 = IntegerType::get(&initiallyEmpty, 32); |
164 | EXPECT_FALSE(isa<TestExternalTypeInterface>(i32)); |
165 | initiallyEmpty.appendDialectRegistry(registry); |
166 | EXPECT_TRUE(isa<TestExternalTypeInterface>(i32)); |
167 | } |
168 | |
169 | /// The interface provides a default implementation that expects |
170 | /// ConcreteType::getWidth to exist, which is the case for IntegerType. So this |
171 | /// just derives from the ExternalModel. |
172 | struct TestExternalFallbackTypeIntegerModel |
173 | : public TestExternalFallbackTypeInterface::ExternalModel< |
174 | TestExternalFallbackTypeIntegerModel, IntegerType> {}; |
175 | |
176 | /// The interface provides a default implementation that expects |
177 | /// ConcreteType::getWidth to exist, which is *not* the case for VectorType. Use |
178 | /// FallbackModel instead to override this and make sure the code still compiles |
179 | /// because we never instantiate the ExternalModel class template with a |
180 | /// template argument that would have led to compilation failures. |
181 | struct TestExternalFallbackTypeVectorModel |
182 | : public TestExternalFallbackTypeInterface::FallbackModel< |
183 | TestExternalFallbackTypeVectorModel> { |
184 | unsigned getBitwidth(Type type) const { |
185 | IntegerType elementType = |
186 | dyn_cast_or_null<IntegerType>(cast<VectorType>(type).getElementType()); |
187 | return elementType ? elementType.getWidth() : 0; |
188 | } |
189 | }; |
190 | |
191 | TEST(InterfaceAttachment, Fallback) { |
192 | MLIRContext context; |
193 | |
194 | // Just check that we can attach the interface. |
195 | IntegerType i8 = IntegerType::get(&context, 8); |
196 | ASSERT_FALSE(isa<TestExternalFallbackTypeInterface>(i8)); |
197 | IntegerType::attachInterface<TestExternalFallbackTypeIntegerModel>(context); |
198 | ASSERT_TRUE(isa<TestExternalFallbackTypeInterface>(i8)); |
199 | |
200 | // Call the method so it is guaranteed not to be instantiated. |
201 | VectorType vec = VectorType::get({42}, i8); |
202 | ASSERT_FALSE(isa<TestExternalFallbackTypeInterface>(vec)); |
203 | VectorType::attachInterface<TestExternalFallbackTypeVectorModel>(context); |
204 | ASSERT_TRUE(isa<TestExternalFallbackTypeInterface>(vec)); |
205 | EXPECT_EQ(cast<TestExternalFallbackTypeInterface>(vec).getBitwidth(), 8u); |
206 | } |
207 | |
208 | /// External model for attribute interfaces. |
209 | struct TestExternalIntegerAttrModel |
210 | : public TestExternalAttrInterface::ExternalModel< |
211 | TestExternalIntegerAttrModel, IntegerAttr> { |
212 | const Dialect *getDialectPtr(Attribute attr) const { |
213 | return &cast<IntegerAttr>(attr).getDialect(); |
214 | } |
215 | |
216 | static int () { return 42; } |
217 | }; |
218 | |
219 | TEST(InterfaceAttachment, Attribute) { |
220 | MLIRContext context; |
221 | |
222 | // Attribute interfaces use the exact same mechanism as types, so just check |
223 | // that the basics work for attributes. |
224 | IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42); |
225 | ASSERT_FALSE(isa<TestExternalAttrInterface>(attr)); |
226 | IntegerAttr::attachInterface<TestExternalIntegerAttrModel>(context); |
227 | auto iface = dyn_cast<TestExternalAttrInterface>(attr); |
228 | ASSERT_TRUE(iface != nullptr); |
229 | EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect()); |
230 | EXPECT_EQ(iface.getSomeNumber(), 42); |
231 | } |
232 | |
233 | /// External model for an interface attachable to a non-builtin attribute. |
234 | struct TestExternalSimpleAAttrModel |
235 | : public TestExternalAttrInterface::ExternalModel< |
236 | TestExternalSimpleAAttrModel, test::SimpleAAttr> { |
237 | const Dialect *getDialectPtr(Attribute attr) const { |
238 | return &attr.getDialect(); |
239 | } |
240 | |
241 | static int () { return 21; } |
242 | }; |
243 | |
244 | TEST(InterfaceAttachmentTest, AttributeDelayed) { |
245 | // Attribute interfaces use the exact same mechanism as types, so just check |
246 | // that the delayed registration work for attributes. |
247 | DialectRegistry registry; |
248 | registry.insert<test::TestDialect>(); |
249 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, test::TestDialect *dialect) { |
250 | test::SimpleAAttr::attachInterface<TestExternalSimpleAAttrModel>(*ctx); |
251 | }); |
252 | |
253 | MLIRContext context(registry); |
254 | context.loadDialect<test::TestDialect>(); |
255 | auto attr = test::SimpleAAttr::get(&context); |
256 | EXPECT_TRUE(isa<TestExternalAttrInterface>(attr)); |
257 | |
258 | MLIRContext initiallyEmpty; |
259 | initiallyEmpty.loadDialect<test::TestDialect>(); |
260 | attr = test::SimpleAAttr::get(&initiallyEmpty); |
261 | EXPECT_FALSE(isa<TestExternalAttrInterface>(attr)); |
262 | initiallyEmpty.appendDialectRegistry(registry); |
263 | EXPECT_TRUE(isa<TestExternalAttrInterface>(attr)); |
264 | } |
265 | |
266 | /// External interface model for the module operation. Only provides non-default |
267 | /// methods. |
268 | struct TestExternalOpModel |
269 | : public TestExternalOpInterface::ExternalModel<TestExternalOpModel, |
270 | ModuleOp> { |
271 | unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const { |
272 | return op->getName().getStringRef().size() + arg; |
273 | } |
274 | |
275 | static unsigned getNameLengthPlusArgTwice(unsigned arg) { |
276 | return ModuleOp::getOperationName().size() + 2 * arg; |
277 | } |
278 | }; |
279 | |
280 | /// External interface model for the func operation. Provides non-deafult and |
281 | /// overrides default methods. |
282 | struct TestExternalOpOverridingModel |
283 | : public TestExternalOpInterface::FallbackModel< |
284 | TestExternalOpOverridingModel> { |
285 | unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const { |
286 | return op->getName().getStringRef().size() + arg; |
287 | } |
288 | |
289 | static unsigned getNameLengthPlusArgTwice(unsigned arg) { |
290 | return UnrealizedConversionCastOp::getOperationName().size() + 2 * arg; |
291 | } |
292 | |
293 | unsigned getNameLengthTimesArg(Operation *op, unsigned arg) const { |
294 | return 42; |
295 | } |
296 | |
297 | static unsigned getNameLengthMinusArg(unsigned arg) { return 21; } |
298 | }; |
299 | |
300 | TEST(InterfaceAttachment, Operation) { |
301 | MLIRContext context; |
302 | OpBuilder builder(&context); |
303 | |
304 | // Initially, the operation doesn't have the interface. |
305 | OwningOpRef<ModuleOp> moduleOp = |
306 | builder.create<ModuleOp>(UnknownLoc::get(&context)); |
307 | ASSERT_FALSE(isa<TestExternalOpInterface>(moduleOp->getOperation())); |
308 | |
309 | // We can attach an external interface and now the operaiton has it. |
310 | ModuleOp::attachInterface<TestExternalOpModel>(context); |
311 | auto iface = dyn_cast<TestExternalOpInterface>(moduleOp->getOperation()); |
312 | ASSERT_TRUE(iface != nullptr); |
313 | EXPECT_EQ(iface.getNameLengthPlusArg(10), 24u); |
314 | EXPECT_EQ(iface.getNameLengthTimesArg(3), 42u); |
315 | EXPECT_EQ(iface.getNameLengthPlusArgTwice(18), 50u); |
316 | EXPECT_EQ(iface.getNameLengthMinusArg(5), 9u); |
317 | |
318 | // Default implementation can be overridden. |
319 | OwningOpRef<UnrealizedConversionCastOp> castOp = |
320 | builder.create<UnrealizedConversionCastOp>(UnknownLoc::get(&context), |
321 | TypeRange(), ValueRange()); |
322 | ASSERT_FALSE(isa<TestExternalOpInterface>(castOp->getOperation())); |
323 | UnrealizedConversionCastOp::attachInterface<TestExternalOpOverridingModel>( |
324 | context); |
325 | iface = dyn_cast<TestExternalOpInterface>(castOp->getOperation()); |
326 | ASSERT_TRUE(iface != nullptr); |
327 | EXPECT_EQ(iface.getNameLengthPlusArg(10), 44u); |
328 | EXPECT_EQ(iface.getNameLengthTimesArg(0), 42u); |
329 | EXPECT_EQ(iface.getNameLengthPlusArgTwice(8), 50u); |
330 | EXPECT_EQ(iface.getNameLengthMinusArg(1000), 21u); |
331 | |
332 | // Another context doesn't have the interfaces registered. |
333 | MLIRContext other; |
334 | OwningOpRef<ModuleOp> otherModuleOp = |
335 | ModuleOp::create(UnknownLoc::get(&other)); |
336 | ASSERT_FALSE(isa<TestExternalOpInterface>(otherModuleOp->getOperation())); |
337 | } |
338 | |
339 | template <class ConcreteOp> |
340 | struct TestExternalTestOpModel |
341 | : public TestExternalOpInterface::ExternalModel< |
342 | TestExternalTestOpModel<ConcreteOp>, ConcreteOp> { |
343 | unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const { |
344 | return op->getName().getStringRef().size() + arg; |
345 | } |
346 | |
347 | static unsigned getNameLengthPlusArgTwice(unsigned arg) { |
348 | return ConcreteOp::getOperationName().size() + 2 * arg; |
349 | } |
350 | }; |
351 | |
352 | TEST(InterfaceAttachment, OperationDelayedContextConstruct) { |
353 | DialectRegistry registry; |
354 | registry.insert<test::TestDialect>(); |
355 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, BuiltinDialect *dialect) { |
356 | ModuleOp::attachInterface<TestExternalOpModel>(*ctx); |
357 | }); |
358 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, test::TestDialect *dialect) { |
359 | test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx); |
360 | test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx); |
361 | }); |
362 | |
363 | // Construct the context directly from a registry. The interfaces are |
364 | // expected to be readily available on operations. |
365 | MLIRContext context(registry); |
366 | context.loadDialect<test::TestDialect>(); |
367 | |
368 | OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context)); |
369 | OpBuilder builder(module->getBody(), module->getBody()->begin()); |
370 | auto opJ = |
371 | builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type()); |
372 | auto opH = |
373 | builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult()); |
374 | auto opI = |
375 | builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult()); |
376 | |
377 | EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation())); |
378 | EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation())); |
379 | EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation())); |
380 | EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation())); |
381 | } |
382 | |
383 | TEST(InterfaceAttachment, OperationDelayedContextAppend) { |
384 | DialectRegistry registry; |
385 | registry.insert<test::TestDialect>(); |
386 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, BuiltinDialect *dialect) { |
387 | ModuleOp::attachInterface<TestExternalOpModel>(*ctx); |
388 | }); |
389 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, test::TestDialect *dialect) { |
390 | test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx); |
391 | test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx); |
392 | }); |
393 | |
394 | // Construct the context, create ops, and only then append the registry. The |
395 | // interfaces are expected to be available after appending the registry. |
396 | MLIRContext context; |
397 | context.loadDialect<test::TestDialect>(); |
398 | |
399 | OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context)); |
400 | OpBuilder builder(module->getBody(), module->getBody()->begin()); |
401 | auto opJ = |
402 | builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type()); |
403 | auto opH = |
404 | builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult()); |
405 | auto opI = |
406 | builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult()); |
407 | |
408 | EXPECT_FALSE(isa<TestExternalOpInterface>(module->getOperation())); |
409 | EXPECT_FALSE(isa<TestExternalOpInterface>(opJ.getOperation())); |
410 | EXPECT_FALSE(isa<TestExternalOpInterface>(opH.getOperation())); |
411 | EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation())); |
412 | |
413 | context.appendDialectRegistry(registry); |
414 | |
415 | EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation())); |
416 | EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation())); |
417 | EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation())); |
418 | EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation())); |
419 | } |
420 | |
421 | TEST(InterfaceAttachmentTest, PromisedInterfaces) { |
422 | // Attribute interfaces use the exact same mechanism as types, so just check |
423 | // that the promise mechanism works for attributes. |
424 | MLIRContext context; |
425 | auto *testDialect = context.getOrLoadDialect<test::TestDialect>(); |
426 | auto attr = test::SimpleAAttr::get(&context); |
427 | |
428 | // `SimpleAAttr` doesn't implement nor promises the |
429 | // `TestExternalAttrInterface` interface. |
430 | EXPECT_FALSE(isa<TestExternalAttrInterface>(attr)); |
431 | EXPECT_FALSE( |
432 | attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>()); |
433 | |
434 | // Add a promise `TestExternalAttrInterface`. |
435 | testDialect->declarePromisedInterface<TestExternalAttrInterface, |
436 | test::SimpleAAttr>(); |
437 | EXPECT_TRUE( |
438 | attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>()); |
439 | |
440 | // Attach the interface. |
441 | test::SimpleAAttr::attachInterface<TestExternalAttrInterface>(context); |
442 | EXPECT_TRUE(isa<TestExternalAttrInterface>(attr)); |
443 | EXPECT_TRUE( |
444 | attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>()); |
445 | } |
446 | |
447 | } // namespace |
448 | |