1 | //===- PassGenTest.cpp - TableGen PassGen Tests ---------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Pass/Pass.h" |
10 | #include "llvm/ADT/STLExtras.h" |
11 | |
12 | #include "gmock/gmock.h" |
13 | |
14 | std::unique_ptr<mlir::Pass> createTestPassWithCustomConstructor(int v = 0); |
15 | |
16 | #define GEN_PASS_DECL |
17 | #define GEN_PASS_REGISTRATION |
18 | #include "PassGenTest.h.inc" |
19 | |
20 | #define GEN_PASS_DEF_TESTPASS |
21 | #define GEN_PASS_DEF_TESTPASSWITHOPTIONS |
22 | #define GEN_PASS_DEF_TESTPASSWITHCUSTOMCONSTRUCTOR |
23 | #include "PassGenTest.h.inc" |
24 | |
25 | struct TestPass : public impl::TestPassBase<TestPass> { |
26 | using TestPassBase::TestPassBase; |
27 | |
28 | void runOnOperation() override {} |
29 | |
30 | std::unique_ptr<mlir::Pass> clone() const { |
31 | return TestPassBase<TestPass>::clone(); |
32 | } |
33 | }; |
34 | |
35 | TEST(PassGenTest, defaultGeneratedConstructor) { |
36 | std::unique_ptr<mlir::Pass> pass = createTestPass(); |
37 | EXPECT_TRUE(pass.get() != nullptr); |
38 | } |
39 | |
40 | TEST(PassGenTest, PassClone) { |
41 | mlir::MLIRContext context; |
42 | |
43 | const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) { |
44 | return static_cast<const TestPass *>(pass.get()); |
45 | }; |
46 | |
47 | const auto origPass = createTestPass(); |
48 | const auto clonePass = unwrap(origPass)->clone(); |
49 | |
50 | EXPECT_TRUE(clonePass.get() != nullptr); |
51 | EXPECT_TRUE(origPass.get() != clonePass.get()); |
52 | } |
53 | |
54 | struct TestPassWithOptions |
55 | : public impl::TestPassWithOptionsBase<TestPassWithOptions> { |
56 | using TestPassWithOptionsBase::TestPassWithOptionsBase; |
57 | |
58 | void runOnOperation() override {} |
59 | |
60 | std::unique_ptr<mlir::Pass> clone() const { |
61 | return TestPassWithOptionsBase<TestPassWithOptions>::clone(); |
62 | } |
63 | |
64 | int getTestOption() const { return testOption; } |
65 | |
66 | llvm::ArrayRef<int64_t> getTestListOption() const { return testListOption; } |
67 | }; |
68 | |
69 | TEST(PassGenTest, PassOptions) { |
70 | mlir::MLIRContext context; |
71 | |
72 | TestPassWithOptionsOptions options; |
73 | options.testOption = 57; |
74 | |
75 | llvm::SmallVector<int64_t, 2> testListOption = {1, 2}; |
76 | options.testListOption = testListOption; |
77 | |
78 | const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) { |
79 | return static_cast<const TestPassWithOptions *>(pass.get()); |
80 | }; |
81 | |
82 | const auto pass = createTestPassWithOptions(options); |
83 | |
84 | EXPECT_EQ(unwrap(pass)->getTestOption(), 57); |
85 | EXPECT_EQ(unwrap(pass)->getTestListOption()[0], 1); |
86 | EXPECT_EQ(unwrap(pass)->getTestListOption()[1], 2); |
87 | } |
88 | |
89 | struct TestPassWithCustomConstructor |
90 | : public impl::TestPassWithCustomConstructorBase< |
91 | TestPassWithCustomConstructor> { |
92 | explicit TestPassWithCustomConstructor(int v) : extraVal(v) {} |
93 | |
94 | void runOnOperation() override {} |
95 | |
96 | std::unique_ptr<mlir::Pass> clone() const { |
97 | return TestPassWithCustomConstructorBase< |
98 | TestPassWithCustomConstructor>::clone(); |
99 | } |
100 | |
101 | unsigned int = 23; |
102 | }; |
103 | |
104 | std::unique_ptr<mlir::Pass> createTestPassWithCustomConstructor(int v) { |
105 | return std::make_unique<TestPassWithCustomConstructor>(args&: v); |
106 | } |
107 | |
108 | TEST(PassGenTest, PassCloneWithCustomConstructor) { |
109 | mlir::MLIRContext context; |
110 | |
111 | const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) { |
112 | return static_cast<const TestPassWithCustomConstructor *>(pass.get()); |
113 | }; |
114 | |
115 | const auto origPass = createTestPassWithCustomConstructor(v: 10); |
116 | const auto clonePass = unwrap(origPass)->clone(); |
117 | |
118 | EXPECT_EQ(unwrap(origPass)->extraVal, unwrap(clonePass)->extraVal); |
119 | } |
120 | |