1//===- PassGenTest.cpp - TableGen PassGen Tests ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Pass/Pass.h"
10#include "llvm/ADT/STLExtras.h"
11
12#include "gmock/gmock.h"
13
14std::unique_ptr<mlir::Pass> createTestPassWithCustomConstructor(int v = 0);
15
16#define GEN_PASS_DECL
17#define GEN_PASS_REGISTRATION
18#include "PassGenTest.h.inc"
19
20#define GEN_PASS_DEF_TESTPASS
21#define GEN_PASS_DEF_TESTPASSWITHOPTIONS
22#define GEN_PASS_DEF_TESTPASSWITHCUSTOMCONSTRUCTOR
23#include "PassGenTest.h.inc"
24
25struct TestPass : public impl::TestPassBase<TestPass> {
26 using TestPassBase::TestPassBase;
27
28 void runOnOperation() override {}
29
30 std::unique_ptr<mlir::Pass> clone() const {
31 return TestPassBase<TestPass>::clone();
32 }
33};
34
35TEST(PassGenTest, defaultGeneratedConstructor) {
36 std::unique_ptr<mlir::Pass> pass = createTestPass();
37 EXPECT_TRUE(pass.get() != nullptr);
38}
39
40TEST(PassGenTest, PassClone) {
41 mlir::MLIRContext context;
42
43 const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) {
44 return static_cast<const TestPass *>(pass.get());
45 };
46
47 const auto origPass = createTestPass();
48 const auto clonePass = unwrap(origPass)->clone();
49
50 EXPECT_TRUE(clonePass.get() != nullptr);
51 EXPECT_TRUE(origPass.get() != clonePass.get());
52}
53
54struct TestPassWithOptions
55 : public impl::TestPassWithOptionsBase<TestPassWithOptions> {
56 using TestPassWithOptionsBase::TestPassWithOptionsBase;
57
58 void runOnOperation() override {}
59
60 std::unique_ptr<mlir::Pass> clone() const {
61 return TestPassWithOptionsBase<TestPassWithOptions>::clone();
62 }
63
64 int getTestOption() const { return testOption; }
65
66 llvm::ArrayRef<int64_t> getTestListOption() const { return testListOption; }
67};
68
69TEST(PassGenTest, PassOptions) {
70 mlir::MLIRContext context;
71
72 TestPassWithOptionsOptions options;
73 options.testOption = 57;
74
75 llvm::SmallVector<int64_t, 2> testListOption = {1, 2};
76 options.testListOption = testListOption;
77
78 const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) {
79 return static_cast<const TestPassWithOptions *>(pass.get());
80 };
81
82 const auto pass = createTestPassWithOptions(options);
83
84 EXPECT_EQ(unwrap(pass)->getTestOption(), 57);
85 EXPECT_EQ(unwrap(pass)->getTestListOption()[0], 1);
86 EXPECT_EQ(unwrap(pass)->getTestListOption()[1], 2);
87}
88
89struct TestPassWithCustomConstructor
90 : public impl::TestPassWithCustomConstructorBase<
91 TestPassWithCustomConstructor> {
92 explicit TestPassWithCustomConstructor(int v) : extraVal(v) {}
93
94 void runOnOperation() override {}
95
96 std::unique_ptr<mlir::Pass> clone() const {
97 return TestPassWithCustomConstructorBase<
98 TestPassWithCustomConstructor>::clone();
99 }
100
101 unsigned int extraVal = 23;
102};
103
104std::unique_ptr<mlir::Pass> createTestPassWithCustomConstructor(int v) {
105 return std::make_unique<TestPassWithCustomConstructor>(args&: v);
106}
107
108TEST(PassGenTest, PassCloneWithCustomConstructor) {
109 mlir::MLIRContext context;
110
111 const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) {
112 return static_cast<const TestPassWithCustomConstructor *>(pass.get());
113 };
114
115 const auto origPass = createTestPassWithCustomConstructor(v: 10);
116 const auto clonePass = unwrap(origPass)->clone();
117
118 EXPECT_EQ(unwrap(origPass)->extraVal, unwrap(clonePass)->extraVal);
119}
120

source code of mlir/unittests/TableGen/PassGenTest.cpp