1 | //===- InterfaceAttachmentTest.cpp - Test attaching interfaces ------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This implements the tests for attaching interfaces to attributes, types and
// operations without having to specify them on the attribute, type or
// operation class directly.
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/IR/BuiltinAttributes.h" |
15 | #include "mlir/IR/BuiltinDialect.h" |
16 | #include "mlir/IR/BuiltinOps.h" |
17 | #include "mlir/IR/BuiltinTypes.h" |
18 | #include "gtest/gtest.h" |
19 | |
20 | #include "../../test/lib/Dialect/Test/TestAttributes.h" |
21 | #include "../../test/lib/Dialect/Test/TestDialect.h" |
22 | #include "../../test/lib/Dialect/Test/TestTypes.h" |
23 | #include "mlir/IR/OwningOpRef.h" |
24 | |
25 | using namespace mlir; |
26 | using namespace test; |
27 | |
28 | namespace { |
29 | |
30 | /// External interface model for the integer type. Only provides non-default |
31 | /// methods. |
32 | struct Model |
33 | : public TestExternalTypeInterface::ExternalModel<Model, IntegerType> { |
34 | unsigned getBitwidthPlusArg(Type type, unsigned arg) const { |
35 | return type.getIntOrFloatBitWidth() + arg; |
36 | } |
37 | |
38 | static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; } |
39 | }; |
40 | |
41 | /// External interface model for the float type. Provides non-deafult and |
42 | /// overrides default methods. |
43 | struct OverridingModel |
44 | : public TestExternalTypeInterface::ExternalModel<OverridingModel, |
45 | FloatType> { |
46 | unsigned getBitwidthPlusArg(Type type, unsigned arg) const { |
47 | return type.getIntOrFloatBitWidth() + arg; |
48 | } |
49 | |
50 | static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 42 + arg; } |
51 | |
52 | unsigned getBitwidthPlusDoubleArgument(Type type, unsigned arg) const { |
53 | return 128; |
54 | } |
55 | |
56 | static unsigned staticGetArgument(unsigned arg) { return 420; } |
57 | }; |
58 | |
59 | TEST(InterfaceAttachment, Type) { |
60 | MLIRContext context; |
61 | |
62 | // Check that the type has no interface. |
63 | IntegerType i8 = IntegerType::get(&context, 8); |
64 | ASSERT_FALSE(isa<TestExternalTypeInterface>(i8)); |
65 | |
66 | // Attach an interface and check that the type now has the interface. |
67 | IntegerType::attachInterface<Model>(context); |
68 | TestExternalTypeInterface iface = dyn_cast<TestExternalTypeInterface>(i8); |
69 | ASSERT_TRUE(iface != nullptr); |
70 | EXPECT_EQ(iface.getBitwidthPlusArg(10), 18u); |
71 | EXPECT_EQ(iface.staticGetSomeValuePlusArg(0), 42u); |
72 | EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(2), 12u); |
73 | EXPECT_EQ(iface.staticGetArgument(17), 17u); |
74 | |
75 | // Same, but with the default implementation overridden. |
76 | FloatType flt = Float32Type::get(&context); |
77 | ASSERT_FALSE(isa<TestExternalTypeInterface>(flt)); |
78 | Float32Type::attachInterface<OverridingModel>(context); |
79 | iface = dyn_cast<TestExternalTypeInterface>(flt); |
80 | ASSERT_TRUE(iface != nullptr); |
81 | EXPECT_EQ(iface.getBitwidthPlusArg(10), 42u); |
82 | EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 52u); |
83 | EXPECT_EQ(iface.getBitwidthPlusDoubleArgument(3), 128u); |
84 | EXPECT_EQ(iface.staticGetArgument(17), 420u); |
85 | |
86 | // Other contexts shouldn't have the attribute attached. |
87 | MLIRContext other; |
88 | IntegerType i8other = IntegerType::get(&other, 8); |
89 | EXPECT_FALSE(isa<TestExternalTypeInterface>(i8other)); |
90 | } |
91 | |
92 | /// External interface model for the test type from the test dialect. |
93 | struct TestTypeModel |
94 | : public TestExternalTypeInterface::ExternalModel<TestTypeModel, |
95 | test::TestType> { |
96 | unsigned getBitwidthPlusArg(Type type, unsigned arg) const { return arg; } |
97 | |
98 | static unsigned staticGetSomeValuePlusArg(unsigned arg) { return 10 + arg; } |
99 | }; |
100 | |
101 | TEST(InterfaceAttachment, TypeDelayedContextConstruct) { |
102 | // Put the interface in the registry. |
103 | DialectRegistry registry; |
104 | registry.insert<test::TestDialect>(); |
105 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, test::TestDialect *dialect) { |
106 | test::TestType::attachInterface<TestTypeModel>(*ctx); |
107 | }); |
108 | |
109 | // Check that when a context is constructed with the given registry, the type |
110 | // interface gets registered. |
111 | MLIRContext context(registry); |
112 | context.loadDialect<test::TestDialect>(); |
113 | test::TestType testType = test::TestType::get(&context); |
114 | auto iface = dyn_cast<TestExternalTypeInterface>(testType); |
115 | ASSERT_TRUE(iface != nullptr); |
116 | EXPECT_EQ(iface.getBitwidthPlusArg(42), 42u); |
117 | EXPECT_EQ(iface.staticGetSomeValuePlusArg(10), 20u); |
118 | } |
119 | |
120 | TEST(InterfaceAttachment, TypeDelayedContextAppend) { |
121 | // Put the interface in the registry. |
122 | DialectRegistry registry; |
123 | registry.insert<test::TestDialect>(); |
124 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, test::TestDialect *dialect) { |
125 | test::TestType::attachInterface<TestTypeModel>(*ctx); |
126 | }); |
127 | |
128 | // Check that when the registry gets appended to the context, the interface |
129 | // becomes available for objects in loaded dialects. |
130 | MLIRContext context; |
131 | context.loadDialect<test::TestDialect>(); |
132 | test::TestType testType = test::TestType::get(&context); |
133 | EXPECT_FALSE(isa<TestExternalTypeInterface>(testType)); |
134 | context.appendDialectRegistry(registry); |
135 | EXPECT_TRUE(isa<TestExternalTypeInterface>(testType)); |
136 | } |
137 | |
138 | TEST(InterfaceAttachment, RepeatedRegistration) { |
139 | DialectRegistry registry; |
140 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, BuiltinDialect *dialect) { |
141 | IntegerType::attachInterface<Model>(*ctx); |
142 | }); |
143 | MLIRContext context(registry); |
144 | |
145 | // Should't fail on repeated registration through the dialect registry. |
146 | context.appendDialectRegistry(registry); |
147 | } |
148 | |
149 | TEST(InterfaceAttachment, TypeBuiltinDelayed) { |
150 | // Builtin dialect needs to registration or loading, but delayed interface |
151 | // registration must still work. |
152 | DialectRegistry registry; |
153 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, BuiltinDialect *dialect) { |
154 | IntegerType::attachInterface<Model>(*ctx); |
155 | }); |
156 | |
157 | MLIRContext context(registry); |
158 | IntegerType i16 = IntegerType::get(&context, 16); |
159 | EXPECT_TRUE(isa<TestExternalTypeInterface>(i16)); |
160 | |
161 | MLIRContext initiallyEmpty; |
162 | IntegerType i32 = IntegerType::get(&initiallyEmpty, 32); |
163 | EXPECT_FALSE(isa<TestExternalTypeInterface>(i32)); |
164 | initiallyEmpty.appendDialectRegistry(registry); |
165 | EXPECT_TRUE(isa<TestExternalTypeInterface>(i32)); |
166 | } |
167 | |
168 | /// The interface provides a default implementation that expects |
169 | /// ConcreteType::getWidth to exist, which is the case for IntegerType. So this |
170 | /// just derives from the ExternalModel. |
171 | struct TestExternalFallbackTypeIntegerModel |
172 | : public TestExternalFallbackTypeInterface::ExternalModel< |
173 | TestExternalFallbackTypeIntegerModel, IntegerType> {}; |
174 | |
175 | /// The interface provides a default implementation that expects |
176 | /// ConcreteType::getWidth to exist, which is *not* the case for VectorType. Use |
177 | /// FallbackModel instead to override this and make sure the code still compiles |
178 | /// because we never instantiate the ExternalModel class template with a |
179 | /// template argument that would have led to compilation failures. |
180 | struct TestExternalFallbackTypeVectorModel |
181 | : public TestExternalFallbackTypeInterface::FallbackModel< |
182 | TestExternalFallbackTypeVectorModel> { |
183 | unsigned getBitwidth(Type type) const { |
184 | IntegerType elementType = |
185 | dyn_cast_or_null<IntegerType>(cast<VectorType>(type).getElementType()); |
186 | return elementType ? elementType.getWidth() : 0; |
187 | } |
188 | }; |
189 | |
190 | TEST(InterfaceAttachment, Fallback) { |
191 | MLIRContext context; |
192 | |
193 | // Just check that we can attach the interface. |
194 | IntegerType i8 = IntegerType::get(&context, 8); |
195 | ASSERT_FALSE(isa<TestExternalFallbackTypeInterface>(i8)); |
196 | IntegerType::attachInterface<TestExternalFallbackTypeIntegerModel>(context); |
197 | ASSERT_TRUE(isa<TestExternalFallbackTypeInterface>(i8)); |
198 | |
199 | // Call the method so it is guaranteed not to be instantiated. |
200 | VectorType vec = VectorType::get({42}, i8); |
201 | ASSERT_FALSE(isa<TestExternalFallbackTypeInterface>(vec)); |
202 | VectorType::attachInterface<TestExternalFallbackTypeVectorModel>(context); |
203 | ASSERT_TRUE(isa<TestExternalFallbackTypeInterface>(vec)); |
204 | EXPECT_EQ(cast<TestExternalFallbackTypeInterface>(vec).getBitwidth(), 8u); |
205 | } |
206 | |
207 | /// External model for attribute interfaces. |
208 | struct TestExternalIntegerAttrModel |
209 | : public TestExternalAttrInterface::ExternalModel< |
210 | TestExternalIntegerAttrModel, IntegerAttr> { |
211 | const Dialect *getDialectPtr(Attribute attr) const { |
212 | return &cast<IntegerAttr>(attr).getDialect(); |
213 | } |
214 | |
215 | static int () { return 42; } |
216 | }; |
217 | |
218 | TEST(InterfaceAttachment, Attribute) { |
219 | MLIRContext context; |
220 | |
221 | // Attribute interfaces use the exact same mechanism as types, so just check |
222 | // that the basics work for attributes. |
223 | IntegerAttr attr = IntegerAttr::get(IntegerType::get(&context, 32), 42); |
224 | ASSERT_FALSE(isa<TestExternalAttrInterface>(attr)); |
225 | IntegerAttr::attachInterface<TestExternalIntegerAttrModel>(context); |
226 | auto iface = dyn_cast<TestExternalAttrInterface>(attr); |
227 | ASSERT_TRUE(iface != nullptr); |
228 | EXPECT_EQ(iface.getDialectPtr(), &attr.getDialect()); |
229 | EXPECT_EQ(iface.getSomeNumber(), 42); |
230 | } |
231 | |
232 | /// External model for an interface attachable to a non-builtin attribute. |
233 | struct TestExternalSimpleAAttrModel |
234 | : public TestExternalAttrInterface::ExternalModel< |
235 | TestExternalSimpleAAttrModel, test::SimpleAAttr> { |
236 | const Dialect *getDialectPtr(Attribute attr) const { |
237 | return &attr.getDialect(); |
238 | } |
239 | |
240 | static int () { return 21; } |
241 | }; |
242 | |
243 | TEST(InterfaceAttachmentTest, AttributeDelayed) { |
244 | // Attribute interfaces use the exact same mechanism as types, so just check |
245 | // that the delayed registration work for attributes. |
246 | DialectRegistry registry; |
247 | registry.insert<test::TestDialect>(); |
248 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, test::TestDialect *dialect) { |
249 | test::SimpleAAttr::attachInterface<TestExternalSimpleAAttrModel>(*ctx); |
250 | }); |
251 | |
252 | MLIRContext context(registry); |
253 | context.loadDialect<test::TestDialect>(); |
254 | auto attr = test::SimpleAAttr::get(&context); |
255 | EXPECT_TRUE(isa<TestExternalAttrInterface>(attr)); |
256 | |
257 | MLIRContext initiallyEmpty; |
258 | initiallyEmpty.loadDialect<test::TestDialect>(); |
259 | attr = test::SimpleAAttr::get(&initiallyEmpty); |
260 | EXPECT_FALSE(isa<TestExternalAttrInterface>(attr)); |
261 | initiallyEmpty.appendDialectRegistry(registry); |
262 | EXPECT_TRUE(isa<TestExternalAttrInterface>(attr)); |
263 | } |
264 | |
265 | /// External interface model for the module operation. Only provides non-default |
266 | /// methods. |
267 | struct TestExternalOpModel |
268 | : public TestExternalOpInterface::ExternalModel<TestExternalOpModel, |
269 | ModuleOp> { |
270 | unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const { |
271 | return op->getName().getStringRef().size() + arg; |
272 | } |
273 | |
274 | static unsigned getNameLengthPlusArgTwice(unsigned arg) { |
275 | return ModuleOp::getOperationName().size() + 2 * arg; |
276 | } |
277 | }; |
278 | |
279 | /// External interface model for the func operation. Provides non-deafult and |
280 | /// overrides default methods. |
281 | struct TestExternalOpOverridingModel |
282 | : public TestExternalOpInterface::FallbackModel< |
283 | TestExternalOpOverridingModel> { |
284 | unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const { |
285 | return op->getName().getStringRef().size() + arg; |
286 | } |
287 | |
288 | static unsigned getNameLengthPlusArgTwice(unsigned arg) { |
289 | return UnrealizedConversionCastOp::getOperationName().size() + 2 * arg; |
290 | } |
291 | |
292 | unsigned getNameLengthTimesArg(Operation *op, unsigned arg) const { |
293 | return 42; |
294 | } |
295 | |
296 | static unsigned getNameLengthMinusArg(unsigned arg) { return 21; } |
297 | }; |
298 | |
299 | TEST(InterfaceAttachment, Operation) { |
300 | MLIRContext context; |
301 | OpBuilder builder(&context); |
302 | |
303 | // Initially, the operation doesn't have the interface. |
304 | OwningOpRef<ModuleOp> moduleOp = |
305 | builder.create<ModuleOp>(UnknownLoc::get(&context)); |
306 | ASSERT_FALSE(isa<TestExternalOpInterface>(moduleOp->getOperation())); |
307 | |
308 | // We can attach an external interface and now the operaiton has it. |
309 | ModuleOp::attachInterface<TestExternalOpModel>(context); |
310 | auto iface = dyn_cast<TestExternalOpInterface>(moduleOp->getOperation()); |
311 | ASSERT_TRUE(iface != nullptr); |
312 | EXPECT_EQ(iface.getNameLengthPlusArg(10), 24u); |
313 | EXPECT_EQ(iface.getNameLengthTimesArg(3), 42u); |
314 | EXPECT_EQ(iface.getNameLengthPlusArgTwice(18), 50u); |
315 | EXPECT_EQ(iface.getNameLengthMinusArg(5), 9u); |
316 | |
317 | // Default implementation can be overridden. |
318 | OwningOpRef<UnrealizedConversionCastOp> castOp = |
319 | builder.create<UnrealizedConversionCastOp>(UnknownLoc::get(&context), |
320 | TypeRange(), ValueRange()); |
321 | ASSERT_FALSE(isa<TestExternalOpInterface>(castOp->getOperation())); |
322 | UnrealizedConversionCastOp::attachInterface<TestExternalOpOverridingModel>( |
323 | context); |
324 | iface = dyn_cast<TestExternalOpInterface>(castOp->getOperation()); |
325 | ASSERT_TRUE(iface != nullptr); |
326 | EXPECT_EQ(iface.getNameLengthPlusArg(10), 44u); |
327 | EXPECT_EQ(iface.getNameLengthTimesArg(0), 42u); |
328 | EXPECT_EQ(iface.getNameLengthPlusArgTwice(8), 50u); |
329 | EXPECT_EQ(iface.getNameLengthMinusArg(1000), 21u); |
330 | |
331 | // Another context doesn't have the interfaces registered. |
332 | MLIRContext other; |
333 | OwningOpRef<ModuleOp> otherModuleOp = |
334 | ModuleOp::create(UnknownLoc::get(&other)); |
335 | ASSERT_FALSE(isa<TestExternalOpInterface>(otherModuleOp->getOperation())); |
336 | } |
337 | |
338 | template <class ConcreteOp> |
339 | struct TestExternalTestOpModel |
340 | : public TestExternalOpInterface::ExternalModel< |
341 | TestExternalTestOpModel<ConcreteOp>, ConcreteOp> { |
342 | unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const { |
343 | return op->getName().getStringRef().size() + arg; |
344 | } |
345 | |
346 | static unsigned getNameLengthPlusArgTwice(unsigned arg) { |
347 | return ConcreteOp::getOperationName().size() + 2 * arg; |
348 | } |
349 | }; |
350 | |
351 | TEST(InterfaceAttachment, OperationDelayedContextConstruct) { |
352 | DialectRegistry registry; |
353 | registry.insert<test::TestDialect>(); |
354 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, BuiltinDialect *dialect) { |
355 | ModuleOp::attachInterface<TestExternalOpModel>(*ctx); |
356 | }); |
357 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, test::TestDialect *dialect) { |
358 | test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx); |
359 | test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx); |
360 | }); |
361 | |
362 | // Construct the context directly from a registry. The interfaces are |
363 | // expected to be readily available on operations. |
364 | MLIRContext context(registry); |
365 | context.loadDialect<test::TestDialect>(); |
366 | |
367 | OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context)); |
368 | OpBuilder builder(module->getBody(), module->getBody()->begin()); |
369 | auto opJ = |
370 | builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type()); |
371 | auto opH = |
372 | builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult()); |
373 | auto opI = |
374 | builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult()); |
375 | |
376 | EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation())); |
377 | EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation())); |
378 | EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation())); |
379 | EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation())); |
380 | } |
381 | |
382 | TEST(InterfaceAttachment, OperationDelayedContextAppend) { |
383 | DialectRegistry registry; |
384 | registry.insert<test::TestDialect>(); |
385 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, BuiltinDialect *dialect) { |
386 | ModuleOp::attachInterface<TestExternalOpModel>(*ctx); |
387 | }); |
388 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, test::TestDialect *dialect) { |
389 | test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx); |
390 | test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx); |
391 | }); |
392 | |
393 | // Construct the context, create ops, and only then append the registry. The |
394 | // interfaces are expected to be available after appending the registry. |
395 | MLIRContext context; |
396 | context.loadDialect<test::TestDialect>(); |
397 | |
398 | OwningOpRef<ModuleOp> module = ModuleOp::create(UnknownLoc::get(&context)); |
399 | OpBuilder builder(module->getBody(), module->getBody()->begin()); |
400 | auto opJ = |
401 | builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type()); |
402 | auto opH = |
403 | builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult()); |
404 | auto opI = |
405 | builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult()); |
406 | |
407 | EXPECT_FALSE(isa<TestExternalOpInterface>(module->getOperation())); |
408 | EXPECT_FALSE(isa<TestExternalOpInterface>(opJ.getOperation())); |
409 | EXPECT_FALSE(isa<TestExternalOpInterface>(opH.getOperation())); |
410 | EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation())); |
411 | |
412 | context.appendDialectRegistry(registry); |
413 | |
414 | EXPECT_TRUE(isa<TestExternalOpInterface>(module->getOperation())); |
415 | EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation())); |
416 | EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation())); |
417 | EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation())); |
418 | } |
419 | |
420 | TEST(InterfaceAttachmentTest, PromisedInterfaces) { |
421 | // Attribute interfaces use the exact same mechanism as types, so just check |
422 | // that the promise mechanism works for attributes. |
423 | MLIRContext context; |
424 | auto testDialect = context.getOrLoadDialect<test::TestDialect>(); |
425 | auto attr = test::SimpleAAttr::get(&context); |
426 | |
427 | // `SimpleAAttr` doesn't implement nor promises the |
428 | // `TestExternalAttrInterface` interface. |
429 | EXPECT_FALSE(isa<TestExternalAttrInterface>(attr)); |
430 | EXPECT_FALSE( |
431 | attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>()); |
432 | |
433 | // Add a promise `TestExternalAttrInterface`. |
434 | testDialect->declarePromisedInterface<test::SimpleAAttr, |
435 | TestExternalAttrInterface>(); |
436 | EXPECT_TRUE( |
437 | attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>()); |
438 | |
439 | // Attach the interface. |
440 | test::SimpleAAttr::attachInterface<TestExternalAttrInterface>(context); |
441 | EXPECT_TRUE(isa<TestExternalAttrInterface>(attr)); |
442 | EXPECT_TRUE( |
443 | attr.hasPromiseOrImplementsInterface<TestExternalAttrInterface>()); |
444 | } |
445 | |
446 | } // namespace |
447 | |