1//===- DialectConversion.cpp - Dialect conversion unit tests --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Transforms/DialectConversion.h"
#include "gtest/gtest.h"

#include <memory>
#include <optional>
11
12using namespace mlir;
13
14static Operation *createOp(MLIRContext *context) {
15 context->allowUnregisteredDialects();
16 return Operation::create(
17 UnknownLoc::get(context), OperationName("foo.bar", context), std::nullopt,
18 std::nullopt, std::nullopt, /*properties=*/nullptr, std::nullopt, 0);
19}
20
21namespace {
22struct DummyOp {
23 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DummyOp)
24
25 static StringRef getOperationName() { return "foo.bar"; }
26};
27
28TEST(DialectConversionTest, DynamicallyLegalOpCallbackOrder) {
29 MLIRContext context;
30 ConversionTarget target(context);
31
32 int index = 0;
33 int callbackCalled1 = 0;
34 target.addDynamicallyLegalOp<DummyOp>(callback: [&](Operation *) {
35 callbackCalled1 = ++index;
36 return true;
37 });
38
39 int callbackCalled2 = 0;
40 target.addDynamicallyLegalOp<DummyOp>(
41 callback: [&](Operation *) -> std::optional<bool> {
42 callbackCalled2 = ++index;
43 return std::nullopt;
44 });
45
46 auto *op = createOp(context: &context);
47 EXPECT_TRUE(target.isLegal(op));
48 EXPECT_EQ(2, callbackCalled1);
49 EXPECT_EQ(1, callbackCalled2);
50 EXPECT_FALSE(target.isIllegal(op));
51 EXPECT_EQ(4, callbackCalled1);
52 EXPECT_EQ(3, callbackCalled2);
53 op->destroy();
54}
55
56TEST(DialectConversionTest, DynamicallyLegalOpCallbackSkip) {
57 MLIRContext context;
58 ConversionTarget target(context);
59
60 int index = 0;
61 int callbackCalled = 0;
62 target.addDynamicallyLegalOp<DummyOp>(
63 callback: [&](Operation *) -> std::optional<bool> {
64 callbackCalled = ++index;
65 return std::nullopt;
66 });
67
68 auto *op = createOp(context: &context);
69 EXPECT_FALSE(target.isLegal(op));
70 EXPECT_EQ(1, callbackCalled);
71 EXPECT_FALSE(target.isIllegal(op));
72 EXPECT_EQ(2, callbackCalled);
73 op->destroy();
74}
75
76TEST(DialectConversionTest, DynamicallyLegalUnknownOpCallbackOrder) {
77 MLIRContext context;
78 ConversionTarget target(context);
79
80 int index = 0;
81 int callbackCalled1 = 0;
82 target.markUnknownOpDynamicallyLegal(fn: [&](Operation *) {
83 callbackCalled1 = ++index;
84 return true;
85 });
86
87 int callbackCalled2 = 0;
88 target.markUnknownOpDynamicallyLegal(fn: [&](Operation *) -> std::optional<bool> {
89 callbackCalled2 = ++index;
90 return std::nullopt;
91 });
92
93 auto *op = createOp(context: &context);
94 EXPECT_TRUE(target.isLegal(op));
95 EXPECT_EQ(2, callbackCalled1);
96 EXPECT_EQ(1, callbackCalled2);
97 EXPECT_FALSE(target.isIllegal(op));
98 EXPECT_EQ(4, callbackCalled1);
99 EXPECT_EQ(3, callbackCalled2);
100 op->destroy();
101}
102
103TEST(DialectConversionTest, DynamicallyLegalReturnNone) {
104 MLIRContext context;
105 ConversionTarget target(context);
106
107 target.addDynamicallyLegalOp<DummyOp>(
108 callback: [&](Operation *) -> std::optional<bool> { return std::nullopt; });
109
110 auto *op = createOp(context: &context);
111 EXPECT_FALSE(target.isLegal(op));
112 EXPECT_FALSE(target.isIllegal(op));
113
114 EXPECT_TRUE(succeeded(applyPartialConversion(op, target, {})));
115 EXPECT_TRUE(failed(applyFullConversion(op, target, {})));
116
117 op->destroy();
118}
119
120TEST(DialectConversionTest, DynamicallyLegalUnknownReturnNone) {
121 MLIRContext context;
122 ConversionTarget target(context);
123
124 target.markUnknownOpDynamicallyLegal(
125 fn: [&](Operation *) -> std::optional<bool> { return std::nullopt; });
126
127 auto *op = createOp(context: &context);
128 EXPECT_FALSE(target.isLegal(op));
129 EXPECT_FALSE(target.isIllegal(op));
130
131 EXPECT_TRUE(succeeded(applyPartialConversion(op, target, {})));
132 EXPECT_TRUE(failed(applyFullConversion(op, target, {})));
133
134 op->destroy();
135}
136} // namespace
137

source code of mlir/unittests/Transforms/DialectConversion.cpp