1//===- PassManagerTest.cpp - PassManager unit tests -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Pass/PassManager.h"
10#include "mlir/Debug/BreakpointManagers/TagBreakpointManager.h"
11#include "mlir/Debug/ExecutionContext.h"
12#include "mlir/Dialect/Func/IR/FuncOps.h"
13#include "mlir/IR/Builders.h"
14#include "mlir/IR/BuiltinOps.h"
15#include "mlir/IR/Diagnostics.h"
16#include "mlir/Pass/Pass.h"
17#include "gtest/gtest.h"
18
19#include <memory>
20
21using namespace mlir;
22using namespace mlir::detail;
23
24namespace {
25/// Analysis that operates on any operation.
26struct GenericAnalysis {
27 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GenericAnalysis)
28
29 GenericAnalysis(Operation *op) : isFunc(isa<func::FuncOp>(Val: op)) {}
30 const bool isFunc;
31};
32
33/// Analysis that operates on a specific operation.
34struct OpSpecificAnalysis {
35 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpSpecificAnalysis)
36
37 OpSpecificAnalysis(func::FuncOp op) : isSecret(op.getName() == "secret") {}
38 const bool isSecret;
39};
40
41/// Simple pass to annotate a func::FuncOp with the results of analysis.
42struct AnnotateFunctionPass
43 : public PassWrapper<AnnotateFunctionPass, OperationPass<func::FuncOp>> {
44 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnnotateFunctionPass)
45
46 void runOnOperation() override {
47 func::FuncOp op = getOperation();
48 Builder builder(op->getParentOfType<ModuleOp>());
49
50 auto &ga = getAnalysis<GenericAnalysis>();
51 auto &sa = getAnalysis<OpSpecificAnalysis>();
52
53 op->setAttr(name: "isFunc", value: builder.getBoolAttr(value: ga.isFunc));
54 op->setAttr(name: "isSecret", value: builder.getBoolAttr(value: sa.isSecret));
55 }
56};
57
58TEST(PassManagerTest, OpSpecificAnalysis) {
59 MLIRContext context;
60 context.loadDialect<func::FuncDialect>();
61 Builder builder(&context);
62
63 // Create a module with 2 functions.
64 OwningOpRef<ModuleOp> module(ModuleOp::create(loc: UnknownLoc::get(context: &context)));
65 for (StringRef name : {"secret", "not_secret"}) {
66 auto func = func::FuncOp::create(location: builder.getUnknownLoc(), name,
67 type: builder.getFunctionType(inputs: {}, results: {}));
68 func.setPrivate();
69 module->push_back(op: func);
70 }
71
72 // Instantiate and run our pass.
73 auto pm = PassManager::on<ModuleOp>(ctx: &context);
74 pm.addNestedPass<func::FuncOp>(pass: std::make_unique<AnnotateFunctionPass>());
75 LogicalResult result = pm.run(op: module.get());
76 EXPECT_TRUE(succeeded(result));
77
78 // Verify that each function got annotated with expected attributes.
79 for (func::FuncOp func : module->getOps<func::FuncOp>()) {
80 ASSERT_TRUE(isa<BoolAttr>(func->getDiscardableAttr("isFunc")));
81 EXPECT_TRUE(cast<BoolAttr>(func->getDiscardableAttr("isFunc")).getValue());
82
83 bool isSecret = func.getName() == "secret";
84 ASSERT_TRUE(isa<BoolAttr>(func->getDiscardableAttr("isSecret")));
85 EXPECT_EQ(cast<BoolAttr>(func->getDiscardableAttr("isSecret")).getValue(),
86 isSecret);
87 }
88}
89
90/// Simple pass to annotate a func::FuncOp with a single attribute `didProcess`.
91struct AddAttrFunctionPass
92 : public PassWrapper<AddAttrFunctionPass, OperationPass<func::FuncOp>> {
93 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddAttrFunctionPass)
94
95 void runOnOperation() override {
96 func::FuncOp op = getOperation();
97 Builder builder(op->getParentOfType<ModuleOp>());
98 if (op->hasAttr(name: "didProcess"))
99 op->setAttr(name: "didProcessAgain", value: builder.getUnitAttr());
100
101 // We always want to set this one.
102 op->setAttr(name: "didProcess", value: builder.getUnitAttr());
103 }
104};
105
106/// Simple pass to annotate a func::FuncOp with a single attribute
107/// `didProcess2`.
108struct AddSecondAttrFunctionPass
109 : public PassWrapper<AddSecondAttrFunctionPass,
110 OperationPass<func::FuncOp>> {
111 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddSecondAttrFunctionPass)
112
113 void runOnOperation() override {
114 func::FuncOp op = getOperation();
115 Builder builder(op->getParentOfType<ModuleOp>());
116 op->setAttr(name: "didProcess2", value: builder.getUnitAttr());
117 }
118};
119
120TEST(PassManagerTest, ExecutionAction) {
121 MLIRContext context;
122 context.loadDialect<func::FuncDialect>();
123 Builder builder(&context);
124
125 // Create a module with 2 functions.
126 OwningOpRef<ModuleOp> module(ModuleOp::create(loc: UnknownLoc::get(context: &context)));
127 auto f = func::FuncOp::create(location: builder.getUnknownLoc(), name: "process_me_once",
128 type: builder.getFunctionType(inputs: {}, results: {}));
129 f.setPrivate();
130 module->push_back(op: f);
131
132 // Instantiate our passes.
133 auto pm = PassManager::on<ModuleOp>(ctx: &context);
134 auto pass = std::make_unique<AddAttrFunctionPass>();
135 auto *passPtr = pass.get();
136 pm.addNestedPass<func::FuncOp>(pass: std::move(pass));
137 pm.addNestedPass<func::FuncOp>(pass: std::make_unique<AddSecondAttrFunctionPass>());
138 // Duplicate the first pass to ensure that we *only* run the *first* pass, not
139 // all instances of this pass kind. Notice that this pass (and the test as a
140 // whole) are built to ensure that we can run just a single pass out of a
141 // pipeline that may contain duplicates.
142 pm.addNestedPass<func::FuncOp>(pass: std::make_unique<AddAttrFunctionPass>());
143
144 // Use the action manager to only hit the first pass, not the second one.
145 auto onBreakpoint = [&](const tracing::ActionActiveStack *backtrace)
146 -> tracing::ExecutionContext::Control {
147 // Not a PassExecutionAction, apply the action.
148 auto *passExec = dyn_cast<PassExecutionAction>(Val: &backtrace->getAction());
149 if (!passExec)
150 return tracing::ExecutionContext::Next;
151
152 // If this isn't a function, apply the action.
153 if (!isa<func::FuncOp>(Val: passExec->getOp()))
154 return tracing::ExecutionContext::Next;
155
156 // Only apply the first function pass. Not all instances of the first pass,
157 // only the first pass.
158 if (passExec->getPass().getThreadingSiblingOrThis() == passPtr)
159 return tracing::ExecutionContext::Next;
160
161 // Do not apply any other passes in the pass manager.
162 return tracing::ExecutionContext::Skip;
163 };
164
165 // Set up our breakpoint manager.
166 tracing::TagBreakpointManager simpleManager;
167 tracing::ExecutionContext executionCtx(onBreakpoint);
168 executionCtx.addBreakpointManager(manager: &simpleManager);
169 simpleManager.addBreakpoint(tag: PassExecutionAction::tag);
170
171 // Register the execution context in the MLIRContext.
172 context.registerActionHandler(handler: executionCtx);
173
174 // Run the pass manager, expecting our handler to be called.
175 LogicalResult result = pm.run(op: module.get());
176 EXPECT_TRUE(succeeded(result));
177
178 // Verify that each function got annotated with `didProcess` and *not*
179 // `didProcess2`.
180 for (func::FuncOp func : module->getOps<func::FuncOp>()) {
181 ASSERT_TRUE(func->getDiscardableAttr("didProcess"));
182 ASSERT_FALSE(func->getDiscardableAttr("didProcess2"));
183 ASSERT_FALSE(func->getDiscardableAttr("didProcessAgain"));
184 }
185}
186
187namespace {
188struct InvalidPass : Pass {
189 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InvalidPass)
190
191 InvalidPass() : Pass(TypeID::get<InvalidPass>(), StringRef("invalid_op")) {}
192 StringRef getName() const override { return "Invalid Pass"; }
193 void runOnOperation() override {}
194 bool canScheduleOn(RegisteredOperationName opName) const override {
195 return true;
196 }
197
198 /// A clone method to create a copy of this pass.
199 std::unique_ptr<Pass> clonePass() const override {
200 return std::make_unique<InvalidPass>(
201 args: *static_cast<const InvalidPass *>(this));
202 }
203};
204} // namespace
205
206TEST(PassManagerTest, InvalidPass) {
207 MLIRContext context;
208 context.allowUnregisteredDialects();
209
210 // Create a module
211 OwningOpRef<ModuleOp> module(ModuleOp::create(loc: UnknownLoc::get(context: &context)));
212
213 // Add a single "invalid_op" operation
214 OpBuilder builder(&module->getBodyRegion());
215 OperationState state(UnknownLoc::get(context: &context), "invalid_op");
216 builder.insert(op: Operation::create(state));
217
218 // Register a diagnostic handler to capture the diagnostic so that we can
219 // check it later.
220 std::unique_ptr<Diagnostic> diagnostic;
221 context.getDiagEngine().registerHandler(handler: [&](Diagnostic &diag) {
222 diagnostic = std::make_unique<Diagnostic>(args: std::move(diag));
223 });
224
225 // Instantiate and run our pass.
226 auto pm = PassManager::on<ModuleOp>(ctx: &context);
227 pm.nest(nestedName: "invalid_op").addPass(pass: std::make_unique<InvalidPass>());
228 LogicalResult result = pm.run(op: module.get());
229 EXPECT_TRUE(failed(result));
230 ASSERT_TRUE(diagnostic.get() != nullptr);
231 EXPECT_EQ(
232 diagnostic->str(),
233 "'invalid_op' op trying to schedule a pass on an unregistered operation");
234
235 // Check that clearing the pass manager effectively removed the pass.
236 pm.clear();
237 result = pm.run(op: module.get());
238 EXPECT_TRUE(succeeded(result));
239
240 // Check that adding the pass at the top-level triggers a fatal error.
241 ASSERT_DEATH(pm.addPass(std::make_unique<InvalidPass>()),
242 "Can't add pass 'Invalid Pass' restricted to 'invalid_op' on a "
243 "PassManager intended to run on 'builtin.module', did you "
244 "intend to nest?");
245}
246
247/// Simple pass to annotate a func::FuncOp with the results of analysis.
248struct InitializeCheckingPass
249 : public PassWrapper<InitializeCheckingPass, OperationPass<ModuleOp>> {
250 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InitializeCheckingPass)
251 LogicalResult initialize(MLIRContext *ctx) final {
252 initialized = true;
253 return success();
254 }
255 bool initialized = false;
256
257 void runOnOperation() override {
258 if (!initialized) {
259 getOperation()->emitError() << "Pass isn't initialized!";
260 signalPassFailure();
261 }
262 }
263};
264
265TEST(PassManagerTest, PassInitialization) {
266 MLIRContext context;
267 context.allowUnregisteredDialects();
268
269 // Create a module
270 OwningOpRef<ModuleOp> module(ModuleOp::create(loc: UnknownLoc::get(context: &context)));
271
272 // Instantiate and run our pass.
273 auto pm = PassManager::on<ModuleOp>(ctx: &context);
274 pm.addPass(pass: std::make_unique<InitializeCheckingPass>());
275 EXPECT_TRUE(succeeded(pm.run(module.get())));
276
277 // Adding a second copy of the pass, we should also initialize it!
278 pm.addPass(pass: std::make_unique<InitializeCheckingPass>());
279 EXPECT_TRUE(succeeded(pm.run(module.get())));
280}
281
282} // namespace
283

source code of mlir/unittests/Pass/PassManagerTest.cpp