1 | //===- DialectConversion.cpp - Dialect conversion unit tests --------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Transforms/DialectConversion.h" |
10 | #include "gtest/gtest.h" |
11 | |
12 | using namespace mlir; |
13 | |
14 | static Operation *createOp(MLIRContext *context) { |
15 | context->allowUnregisteredDialects(); |
16 | return Operation::create( |
17 | UnknownLoc::get(context), OperationName("foo.bar" , context), std::nullopt, |
18 | std::nullopt, std::nullopt, /*properties=*/nullptr, std::nullopt, 0); |
19 | } |
20 | |
21 | namespace { |
22 | struct DummyOp { |
23 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DummyOp) |
24 | |
25 | static StringRef getOperationName() { return "foo.bar" ; } |
26 | }; |
27 | |
28 | TEST(DialectConversionTest, DynamicallyLegalOpCallbackOrder) { |
29 | MLIRContext context; |
30 | ConversionTarget target(context); |
31 | |
32 | int index = 0; |
33 | int callbackCalled1 = 0; |
34 | target.addDynamicallyLegalOp<DummyOp>(callback: [&](Operation *) { |
35 | callbackCalled1 = ++index; |
36 | return true; |
37 | }); |
38 | |
39 | int callbackCalled2 = 0; |
40 | target.addDynamicallyLegalOp<DummyOp>( |
41 | callback: [&](Operation *) -> std::optional<bool> { |
42 | callbackCalled2 = ++index; |
43 | return std::nullopt; |
44 | }); |
45 | |
46 | auto *op = createOp(context: &context); |
47 | EXPECT_TRUE(target.isLegal(op)); |
48 | EXPECT_EQ(2, callbackCalled1); |
49 | EXPECT_EQ(1, callbackCalled2); |
50 | EXPECT_FALSE(target.isIllegal(op)); |
51 | EXPECT_EQ(4, callbackCalled1); |
52 | EXPECT_EQ(3, callbackCalled2); |
53 | op->destroy(); |
54 | } |
55 | |
56 | TEST(DialectConversionTest, DynamicallyLegalOpCallbackSkip) { |
57 | MLIRContext context; |
58 | ConversionTarget target(context); |
59 | |
60 | int index = 0; |
61 | int callbackCalled = 0; |
62 | target.addDynamicallyLegalOp<DummyOp>( |
63 | callback: [&](Operation *) -> std::optional<bool> { |
64 | callbackCalled = ++index; |
65 | return std::nullopt; |
66 | }); |
67 | |
68 | auto *op = createOp(context: &context); |
69 | EXPECT_FALSE(target.isLegal(op)); |
70 | EXPECT_EQ(1, callbackCalled); |
71 | EXPECT_FALSE(target.isIllegal(op)); |
72 | EXPECT_EQ(2, callbackCalled); |
73 | op->destroy(); |
74 | } |
75 | |
76 | TEST(DialectConversionTest, DynamicallyLegalUnknownOpCallbackOrder) { |
77 | MLIRContext context; |
78 | ConversionTarget target(context); |
79 | |
80 | int index = 0; |
81 | int callbackCalled1 = 0; |
82 | target.markUnknownOpDynamicallyLegal(fn: [&](Operation *) { |
83 | callbackCalled1 = ++index; |
84 | return true; |
85 | }); |
86 | |
87 | int callbackCalled2 = 0; |
88 | target.markUnknownOpDynamicallyLegal(fn: [&](Operation *) -> std::optional<bool> { |
89 | callbackCalled2 = ++index; |
90 | return std::nullopt; |
91 | }); |
92 | |
93 | auto *op = createOp(context: &context); |
94 | EXPECT_TRUE(target.isLegal(op)); |
95 | EXPECT_EQ(2, callbackCalled1); |
96 | EXPECT_EQ(1, callbackCalled2); |
97 | EXPECT_FALSE(target.isIllegal(op)); |
98 | EXPECT_EQ(4, callbackCalled1); |
99 | EXPECT_EQ(3, callbackCalled2); |
100 | op->destroy(); |
101 | } |
102 | |
103 | TEST(DialectConversionTest, DynamicallyLegalReturnNone) { |
104 | MLIRContext context; |
105 | ConversionTarget target(context); |
106 | |
107 | target.addDynamicallyLegalOp<DummyOp>( |
108 | callback: [&](Operation *) -> std::optional<bool> { return std::nullopt; }); |
109 | |
110 | auto *op = createOp(context: &context); |
111 | EXPECT_FALSE(target.isLegal(op)); |
112 | EXPECT_FALSE(target.isIllegal(op)); |
113 | |
114 | EXPECT_TRUE(succeeded(applyPartialConversion(op, target, {}))); |
115 | EXPECT_TRUE(failed(applyFullConversion(op, target, {}))); |
116 | |
117 | op->destroy(); |
118 | } |
119 | |
120 | TEST(DialectConversionTest, DynamicallyLegalUnknownReturnNone) { |
121 | MLIRContext context; |
122 | ConversionTarget target(context); |
123 | |
124 | target.markUnknownOpDynamicallyLegal( |
125 | fn: [&](Operation *) -> std::optional<bool> { return std::nullopt; }); |
126 | |
127 | auto *op = createOp(context: &context); |
128 | EXPECT_FALSE(target.isLegal(op)); |
129 | EXPECT_FALSE(target.isIllegal(op)); |
130 | |
131 | EXPECT_TRUE(succeeded(applyPartialConversion(op, target, {}))); |
132 | EXPECT_TRUE(failed(applyFullConversion(op, target, {}))); |
133 | |
134 | op->destroy(); |
135 | } |
136 | } // namespace |
137 | |