1//===- PassManagerTest.cpp - PassManager unit tests -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Pass/PassManager.h"
10#include "mlir/Debug/BreakpointManagers/TagBreakpointManager.h"
11#include "mlir/Debug/ExecutionContext.h"
12#include "mlir/Dialect/Func/IR/FuncOps.h"
13#include "mlir/IR/Builders.h"
14#include "mlir/IR/BuiltinOps.h"
15#include "mlir/IR/Diagnostics.h"
16#include "mlir/Pass/Pass.h"
17#include "gtest/gtest.h"
18
19#include <memory>
20
21using namespace mlir;
22using namespace mlir::detail;
23
24namespace {
25/// Analysis that operates on any operation.
26struct GenericAnalysis {
27 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GenericAnalysis)
28
29 GenericAnalysis(Operation *op) : isFunc(isa<func::FuncOp>(op)) {}
30 const bool isFunc;
31};
32
33/// Analysis that operates on a specific operation.
34struct OpSpecificAnalysis {
35 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpSpecificAnalysis)
36
37 OpSpecificAnalysis(func::FuncOp op) : isSecret(op.getName() == "secret") {}
38 const bool isSecret;
39};
40
41/// Simple pass to annotate a func::FuncOp with the results of analysis.
42struct AnnotateFunctionPass
43 : public PassWrapper<AnnotateFunctionPass, OperationPass<func::FuncOp>> {
44 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnnotateFunctionPass)
45
46 void runOnOperation() override {
47 func::FuncOp op = getOperation();
48 Builder builder(op->getParentOfType<ModuleOp>());
49
50 auto &ga = getAnalysis<GenericAnalysis>();
51 auto &sa = getAnalysis<OpSpecificAnalysis>();
52
53 op->setAttr("isFunc", builder.getBoolAttr(value: ga.isFunc));
54 op->setAttr("isSecret", builder.getBoolAttr(value: sa.isSecret));
55 }
56};
57
58TEST(PassManagerTest, OpSpecificAnalysis) {
59 MLIRContext context;
60 context.loadDialect<func::FuncDialect>();
61 Builder builder(&context);
62
63 // Create a module with 2 functions.
64 OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
65 for (StringRef name : {"secret", "not_secret"}) {
66 auto func = func::FuncOp::create(
67 builder.getUnknownLoc(), name,
68 builder.getFunctionType(std::nullopt, std::nullopt));
69 func.setPrivate();
70 module->push_back(func);
71 }
72
73 // Instantiate and run our pass.
74 auto pm = PassManager::on<ModuleOp>(ctx: &context);
75 pm.addNestedPass<func::FuncOp>(std::make_unique<AnnotateFunctionPass>());
76 LogicalResult result = pm.run(op: module.get());
77 EXPECT_TRUE(succeeded(result));
78
79 // Verify that each function got annotated with expected attributes.
80 for (func::FuncOp func : module->getOps<func::FuncOp>()) {
81 ASSERT_TRUE(isa<BoolAttr>(func->getDiscardableAttr("isFunc")));
82 EXPECT_TRUE(cast<BoolAttr>(func->getDiscardableAttr("isFunc")).getValue());
83
84 bool isSecret = func.getName() == "secret";
85 ASSERT_TRUE(isa<BoolAttr>(func->getDiscardableAttr("isSecret")));
86 EXPECT_EQ(cast<BoolAttr>(func->getDiscardableAttr("isSecret")).getValue(),
87 isSecret);
88 }
89}
90
91/// Simple pass to annotate a func::FuncOp with a single attribute `didProcess`.
92struct AddAttrFunctionPass
93 : public PassWrapper<AddAttrFunctionPass, OperationPass<func::FuncOp>> {
94 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddAttrFunctionPass)
95
96 void runOnOperation() override {
97 func::FuncOp op = getOperation();
98 Builder builder(op->getParentOfType<ModuleOp>());
99 if (op->hasAttr("didProcess"))
100 op->setAttr("didProcessAgain", builder.getUnitAttr());
101
102 // We always want to set this one.
103 op->setAttr("didProcess", builder.getUnitAttr());
104 }
105};
106
107/// Simple pass to annotate a func::FuncOp with a single attribute
108/// `didProcess2`.
109struct AddSecondAttrFunctionPass
110 : public PassWrapper<AddSecondAttrFunctionPass,
111 OperationPass<func::FuncOp>> {
112 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddSecondAttrFunctionPass)
113
114 void runOnOperation() override {
115 func::FuncOp op = getOperation();
116 Builder builder(op->getParentOfType<ModuleOp>());
117 op->setAttr("didProcess2", builder.getUnitAttr());
118 }
119};
120
121TEST(PassManagerTest, ExecutionAction) {
122 MLIRContext context;
123 context.loadDialect<func::FuncDialect>();
124 Builder builder(&context);
125
126 // Create a module with 2 functions.
127 OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
128 auto f =
129 func::FuncOp::create(builder.getUnknownLoc(), "process_me_once",
130 builder.getFunctionType(std::nullopt, std::nullopt));
131 f.setPrivate();
132 module->push_back(f);
133
134 // Instantiate our passes.
135 auto pm = PassManager::on<ModuleOp>(ctx: &context);
136 auto pass = std::make_unique<AddAttrFunctionPass>();
137 auto *passPtr = pass.get();
138 pm.addNestedPass<func::FuncOp>(std::move(pass));
139 pm.addNestedPass<func::FuncOp>(std::make_unique<AddSecondAttrFunctionPass>());
140 // Duplicate the first pass to ensure that we *only* run the *first* pass, not
141 // all instances of this pass kind. Notice that this pass (and the test as a
142 // whole) are built to ensure that we can run just a single pass out of a
143 // pipeline that may contain duplicates.
144 pm.addNestedPass<func::FuncOp>(std::make_unique<AddAttrFunctionPass>());
145
146 // Use the action manager to only hit the first pass, not the second one.
147 auto onBreakpoint = [&](const tracing::ActionActiveStack *backtrace)
148 -> tracing::ExecutionContext::Control {
149 // Not a PassExecutionAction, apply the action.
150 auto *passExec = dyn_cast<PassExecutionAction>(Val: &backtrace->getAction());
151 if (!passExec)
152 return tracing::ExecutionContext::Next;
153
154 // If this isn't a function, apply the action.
155 if (!isa<func::FuncOp>(passExec->getOp()))
156 return tracing::ExecutionContext::Next;
157
158 // Only apply the first function pass. Not all instances of the first pass,
159 // only the first pass.
160 if (passExec->getPass().getThreadingSiblingOrThis() == passPtr)
161 return tracing::ExecutionContext::Next;
162
163 // Do not apply any other passes in the pass manager.
164 return tracing::ExecutionContext::Skip;
165 };
166
167 // Set up our breakpoint manager.
168 tracing::TagBreakpointManager simpleManager;
169 tracing::ExecutionContext executionCtx(onBreakpoint);
170 executionCtx.addBreakpointManager(manager: &simpleManager);
171 simpleManager.addBreakpoint(tag: PassExecutionAction::tag);
172
173 // Register the execution context in the MLIRContext.
174 context.registerActionHandler(handler: executionCtx);
175
176 // Run the pass manager, expecting our handler to be called.
177 LogicalResult result = pm.run(op: module.get());
178 EXPECT_TRUE(succeeded(result));
179
180 // Verify that each function got annotated with `didProcess` and *not*
181 // `didProcess2`.
182 for (func::FuncOp func : module->getOps<func::FuncOp>()) {
183 ASSERT_TRUE(func->getDiscardableAttr("didProcess"));
184 ASSERT_FALSE(func->getDiscardableAttr("didProcess2"));
185 ASSERT_FALSE(func->getDiscardableAttr("didProcessAgain"));
186 }
187}
188
189namespace {
190struct InvalidPass : Pass {
191 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InvalidPass)
192
193 InvalidPass() : Pass(TypeID::get<InvalidPass>(), StringRef("invalid_op")) {}
194 StringRef getName() const override { return "Invalid Pass"; }
195 void runOnOperation() override {}
196 bool canScheduleOn(RegisteredOperationName opName) const override {
197 return true;
198 }
199
200 /// A clone method to create a copy of this pass.
201 std::unique_ptr<Pass> clonePass() const override {
202 return std::make_unique<InvalidPass>(
203 args: *static_cast<const InvalidPass *>(this));
204 }
205};
206} // namespace
207
208TEST(PassManagerTest, InvalidPass) {
209 MLIRContext context;
210 context.allowUnregisteredDialects();
211
212 // Create a module
213 OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
214
215 // Add a single "invalid_op" operation
216 OpBuilder builder(&module->getBodyRegion());
217 OperationState state(UnknownLoc::get(&context), "invalid_op");
218 builder.insert(op: Operation::create(state));
219
220 // Register a diagnostic handler to capture the diagnostic so that we can
221 // check it later.
222 std::unique_ptr<Diagnostic> diagnostic;
223 context.getDiagEngine().registerHandler(handler: [&](Diagnostic &diag) {
224 diagnostic = std::make_unique<Diagnostic>(args: std::move(diag));
225 });
226
227 // Instantiate and run our pass.
228 auto pm = PassManager::on<ModuleOp>(ctx: &context);
229 pm.nest(nestedName: "invalid_op").addPass(pass: std::make_unique<InvalidPass>());
230 LogicalResult result = pm.run(op: module.get());
231 EXPECT_TRUE(failed(result));
232 ASSERT_TRUE(diagnostic.get() != nullptr);
233 EXPECT_EQ(
234 diagnostic->str(),
235 "'invalid_op' op trying to schedule a pass on an unregistered operation");
236
237 // Check that clearing the pass manager effectively removed the pass.
238 pm.clear();
239 result = pm.run(op: module.get());
240 EXPECT_TRUE(succeeded(result));
241
242 // Check that adding the pass at the top-level triggers a fatal error.
243 ASSERT_DEATH(pm.addPass(std::make_unique<InvalidPass>()),
244 "Can't add pass 'Invalid Pass' restricted to 'invalid_op' on a "
245 "PassManager intended to run on 'builtin.module', did you "
246 "intend to nest?");
247}
248
249/// Simple pass to annotate a func::FuncOp with the results of analysis.
250struct InitializeCheckingPass
251 : public PassWrapper<InitializeCheckingPass, OperationPass<ModuleOp>> {
252 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InitializeCheckingPass)
253 LogicalResult initialize(MLIRContext *ctx) final {
254 initialized = true;
255 return success();
256 }
257 bool initialized = false;
258
259 void runOnOperation() override {
260 if (!initialized) {
261 getOperation()->emitError() << "Pass isn't initialized!";
262 signalPassFailure();
263 }
264 }
265};
266
267TEST(PassManagerTest, PassInitialization) {
268 MLIRContext context;
269 context.allowUnregisteredDialects();
270
271 // Create a module
272 OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
273
274 // Instantiate and run our pass.
275 auto pm = PassManager::on<ModuleOp>(ctx: &context);
276 pm.addPass(pass: std::make_unique<InitializeCheckingPass>());
277 EXPECT_TRUE(succeeded(pm.run(module.get())));
278
279 // Adding a second copy of the pass, we should also initialize it!
280 pm.addPass(pass: std::make_unique<InitializeCheckingPass>());
281 EXPECT_TRUE(succeeded(pm.run(module.get())));
282}
283
284} // namespace
285

source code of mlir/unittests/Pass/PassManagerTest.cpp