1//===- Pass.cpp - Pass infrastructure implementation ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements common pass infrastructure.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Pass/Pass.h"
14#include "PassDetail.h"
15#include "mlir/IR/Diagnostics.h"
16#include "mlir/IR/Dialect.h"
17#include "mlir/IR/OpDefinition.h"
18#include "mlir/IR/Threading.h"
19#include "mlir/IR/Verifier.h"
20#include "mlir/Support/FileUtilities.h"
21#include "llvm/ADT/Hashing.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/ADT/ScopeExit.h"
24#include "llvm/Support/CommandLine.h"
25#include "llvm/Support/CrashRecoveryContext.h"
26#include "llvm/Support/Mutex.h"
27#include "llvm/Support/Signals.h"
28#include "llvm/Support/Threading.h"
29#include "llvm/Support/ToolOutputFile.h"
30#include <optional>
31
32using namespace mlir;
33using namespace mlir::detail;
34
35//===----------------------------------------------------------------------===//
36// PassExecutionAction
37//===----------------------------------------------------------------------===//
38
39PassExecutionAction::PassExecutionAction(ArrayRef<IRUnit> irUnits,
40 const Pass &pass)
41 : Base(irUnits), pass(pass) {}
42
43void PassExecutionAction::print(raw_ostream &os) const {
44 os << llvm::formatv(Fmt: "`{0}` running `{1}` on Operation `{2}`", Vals: tag,
45 Vals: pass.getName(), Vals: getOp()->getName());
46}
47
48Operation *PassExecutionAction::getOp() const {
49 ArrayRef<IRUnit> irUnits = getContextIRUnits();
50 return irUnits.empty() ? nullptr
51 : llvm::dyn_cast_if_present<Operation *>(Val: irUnits[0]);
52}
53
54//===----------------------------------------------------------------------===//
55// Pass
56//===----------------------------------------------------------------------===//
57
58/// Out of line virtual method to ensure vtables and metadata are emitted to a
59/// single .o file.
60void Pass::anchor() {}
61
62/// Attempt to initialize the options of this pass from the given string.
63LogicalResult Pass::initializeOptions(
64 StringRef options,
65 function_ref<LogicalResult(const Twine &)> errorHandler) {
66 std::string errStr;
67 llvm::raw_string_ostream os(errStr);
68 if (failed(result: passOptions.parseFromString(options, errorStream&: os))) {
69 os.flush();
70 return errorHandler(errStr);
71 }
72 return success();
73}
74
75/// Copy the option values from 'other', which is another instance of this
76/// pass.
77void Pass::copyOptionValuesFrom(const Pass *other) {
78 passOptions.copyOptionValuesFrom(other: other->passOptions);
79}
80
81/// Prints out the pass in the textual representation of pipelines. If this is
82/// an adaptor pass, print its pass managers.
83void Pass::printAsTextualPipeline(raw_ostream &os) {
84 // Special case for adaptors to print its pass managers.
85 if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(Val: this)) {
86 llvm::interleave(
87 c: adaptor->getPassManagers(),
88 each_fn: [&](OpPassManager &pm) { pm.printAsTextualPipeline(os); },
89 between_fn: [&] { os << ","; });
90 return;
91 }
92 // Otherwise, print the pass argument followed by its options. If the pass
93 // doesn't have an argument, print the name of the pass to give some indicator
94 // of what pass was run.
95 StringRef argument = getArgument();
96 if (!argument.empty())
97 os << argument;
98 else
99 os << "unknown<" << getName() << ">";
100 passOptions.print(os);
101}
102
103//===----------------------------------------------------------------------===//
104// OpPassManagerImpl
105//===----------------------------------------------------------------------===//
106
107namespace mlir {
108namespace detail {
109struct OpPassManagerImpl {
110 OpPassManagerImpl(OperationName opName, OpPassManager::Nesting nesting)
111 : name(opName.getStringRef().str()), opName(opName),
112 initializationGeneration(0), nesting(nesting) {}
113 OpPassManagerImpl(StringRef name, OpPassManager::Nesting nesting)
114 : name(name == OpPassManager::getAnyOpAnchorName() ? "" : name.str()),
115 initializationGeneration(0), nesting(nesting) {}
116 OpPassManagerImpl(OpPassManager::Nesting nesting)
117 : initializationGeneration(0), nesting(nesting) {}
118 OpPassManagerImpl(const OpPassManagerImpl &rhs)
119 : name(rhs.name), opName(rhs.opName),
120 initializationGeneration(rhs.initializationGeneration),
121 nesting(rhs.nesting) {
122 for (const std::unique_ptr<Pass> &pass : rhs.passes) {
123 std::unique_ptr<Pass> newPass = pass->clone();
124 newPass->threadingSibling = pass.get();
125 passes.push_back(x: std::move(newPass));
126 }
127 }
128
129 /// Merge the passes of this pass manager into the one provided.
130 void mergeInto(OpPassManagerImpl &rhs);
131
132 /// Nest a new operation pass manager for the given operation kind under this
133 /// pass manager.
134 OpPassManager &nest(OperationName nestedName) {
135 return nest(nested: OpPassManager(nestedName, nesting));
136 }
137 OpPassManager &nest(StringRef nestedName) {
138 return nest(nested: OpPassManager(nestedName, nesting));
139 }
140 OpPassManager &nestAny() { return nest(nested: OpPassManager(nesting)); }
141
142 /// Nest the given pass manager under this pass manager.
143 OpPassManager &nest(OpPassManager &&nested);
144
145 /// Add the given pass to this pass manager. If this pass has a concrete
146 /// operation type, it must be the same type as this pass manager.
147 void addPass(std::unique_ptr<Pass> pass);
148
149 /// Clear the list of passes in this pass manager, other options are
150 /// preserved.
151 void clear();
152
153 /// Finalize the pass list in preparation for execution. This includes
154 /// coalescing adjacent pass managers when possible, verifying scheduled
155 /// passes, etc.
156 LogicalResult finalizePassList(MLIRContext *ctx);
157
158 /// Return the operation name of this pass manager.
159 std::optional<OperationName> getOpName(MLIRContext &context) {
160 if (!name.empty() && !opName)
161 opName = OperationName(name, &context);
162 return opName;
163 }
164 std::optional<StringRef> getOpName() const {
165 return name.empty() ? std::optional<StringRef>()
166 : std::optional<StringRef>(name);
167 }
168
169 /// Return the name used to anchor this pass manager. This is either the name
170 /// of an operation, or the result of `getAnyOpAnchorName()` in the case of an
171 /// op-agnostic pass manager.
172 StringRef getOpAnchorName() const {
173 return getOpName().value_or(u: OpPassManager::getAnyOpAnchorName());
174 }
175
176 /// Indicate if the current pass manager can be scheduled on the given
177 /// operation type.
178 bool canScheduleOn(MLIRContext &context, OperationName opName);
179
180 /// The name of the operation that passes of this pass manager operate on.
181 std::string name;
182
183 /// The cached OperationName (internalized in the context) for the name of the
184 /// operation that passes of this pass manager operate on.
185 std::optional<OperationName> opName;
186
187 /// The set of passes to run as part of this pass manager.
188 std::vector<std::unique_ptr<Pass>> passes;
189
190 /// The current initialization generation of this pass manager. This is used
191 /// to indicate when a pass manager should be reinitialized.
192 unsigned initializationGeneration;
193
194 /// Control the implicit nesting of passes that mismatch the name set for this
195 /// OpPassManager.
196 OpPassManager::Nesting nesting;
197};
198} // namespace detail
199} // namespace mlir
200
201void OpPassManagerImpl::mergeInto(OpPassManagerImpl &rhs) {
202 assert(name == rhs.name && "merging unrelated pass managers");
203 for (auto &pass : passes)
204 rhs.passes.push_back(x: std::move(pass));
205 passes.clear();
206}
207
208OpPassManager &OpPassManagerImpl::nest(OpPassManager &&nested) {
209 auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
210 addPass(pass: std::unique_ptr<Pass>(adaptor));
211 return adaptor->getPassManagers().front();
212}
213
214void OpPassManagerImpl::addPass(std::unique_ptr<Pass> pass) {
215 // If this pass runs on a different operation than this pass manager, then
216 // implicitly nest a pass manager for this operation if enabled.
217 std::optional<StringRef> pmOpName = getOpName();
218 std::optional<StringRef> passOpName = pass->getOpName();
219 if (pmOpName && passOpName && *pmOpName != *passOpName) {
220 if (nesting == OpPassManager::Nesting::Implicit)
221 return nest(nestedName: *passOpName).addPass(pass: std::move(pass));
222 llvm::report_fatal_error(reason: llvm::Twine("Can't add pass '") + pass->getName() +
223 "' restricted to '" + *passOpName +
224 "' on a PassManager intended to run on '" +
225 getOpAnchorName() + "', did you intend to nest?");
226 }
227
228 passes.emplace_back(args: std::move(pass));
229}
230
231void OpPassManagerImpl::clear() { passes.clear(); }
232
233LogicalResult OpPassManagerImpl::finalizePassList(MLIRContext *ctx) {
234 auto finalizeAdaptor = [ctx](OpToOpPassAdaptor *adaptor) {
235 for (auto &pm : adaptor->getPassManagers())
236 if (failed(result: pm.getImpl().finalizePassList(ctx)))
237 return failure();
238 return success();
239 };
240
241 // Walk the pass list and merge adjacent adaptors.
242 OpToOpPassAdaptor *lastAdaptor = nullptr;
243 for (auto &pass : passes) {
244 // Check to see if this pass is an adaptor.
245 if (auto *currentAdaptor = dyn_cast<OpToOpPassAdaptor>(Val: pass.get())) {
246 // If it is the first adaptor in a possible chain, remember it and
247 // continue.
248 if (!lastAdaptor) {
249 lastAdaptor = currentAdaptor;
250 continue;
251 }
252
253 // Otherwise, try to merge into the existing adaptor and delete the
254 // current one. If merging fails, just remember this as the last adaptor.
255 if (succeeded(result: currentAdaptor->tryMergeInto(ctx, rhs&: *lastAdaptor)))
256 pass.reset();
257 else
258 lastAdaptor = currentAdaptor;
259 } else if (lastAdaptor) {
260 // If this pass isn't an adaptor, finalize it and forget the last adaptor.
261 if (failed(result: finalizeAdaptor(lastAdaptor)))
262 return failure();
263 lastAdaptor = nullptr;
264 }
265 }
266
267 // If there was an adaptor at the end of the manager, finalize it as well.
268 if (lastAdaptor && failed(result: finalizeAdaptor(lastAdaptor)))
269 return failure();
270
271 // Now that the adaptors have been merged, erase any empty slots corresponding
272 // to the merged adaptors that were nulled-out in the loop above.
273 llvm::erase_if(C&: passes, P: std::logical_not<std::unique_ptr<Pass>>());
274
275 // If this is a op-agnostic pass manager, there is nothing left to do.
276 std::optional<OperationName> rawOpName = getOpName(context&: *ctx);
277 if (!rawOpName)
278 return success();
279
280 // Otherwise, verify that all of the passes are valid for the current
281 // operation anchor.
282 std::optional<RegisteredOperationName> opName =
283 rawOpName->getRegisteredInfo();
284 for (std::unique_ptr<Pass> &pass : passes) {
285 if (opName && !pass->canScheduleOn(opName: *opName)) {
286 return emitError(UnknownLoc::get(ctx))
287 << "unable to schedule pass '" << pass->getName()
288 << "' on a PassManager intended to run on '" << getOpAnchorName()
289 << "'!";
290 }
291 }
292 return success();
293}
294
295bool OpPassManagerImpl::canScheduleOn(MLIRContext &context,
296 OperationName opName) {
297 // If this pass manager is op-specific, we simply check if the provided
298 // operation name is the same as this one.
299 std::optional<OperationName> pmOpName = getOpName(context);
300 if (pmOpName)
301 return pmOpName == opName;
302
303 // Otherwise, this is an op-agnostic pass manager. Check that the operation
304 // can be scheduled on all passes within the manager.
305 std::optional<RegisteredOperationName> registeredInfo =
306 opName.getRegisteredInfo();
307 if (!registeredInfo ||
308 !registeredInfo->hasTrait<OpTrait::IsIsolatedFromAbove>())
309 return false;
310 return llvm::all_of(Range&: passes, P: [&](const std::unique_ptr<Pass> &pass) {
311 return pass->canScheduleOn(opName: *registeredInfo);
312 });
313}
314
315//===----------------------------------------------------------------------===//
316// OpPassManager
317//===----------------------------------------------------------------------===//
318
319OpPassManager::OpPassManager(Nesting nesting)
320 : impl(new OpPassManagerImpl(nesting)) {}
321OpPassManager::OpPassManager(StringRef name, Nesting nesting)
322 : impl(new OpPassManagerImpl(name, nesting)) {}
323OpPassManager::OpPassManager(OperationName name, Nesting nesting)
324 : impl(new OpPassManagerImpl(name, nesting)) {}
325OpPassManager::OpPassManager(OpPassManager &&rhs) { *this = std::move(rhs); }
326OpPassManager::OpPassManager(const OpPassManager &rhs) { *this = rhs; }
327OpPassManager &OpPassManager::operator=(const OpPassManager &rhs) {
328 impl = std::make_unique<OpPassManagerImpl>(args&: *rhs.impl);
329 return *this;
330}
331OpPassManager &OpPassManager::operator=(OpPassManager &&rhs) {
332 impl = std::move(rhs.impl);
333 return *this;
334}
335
336OpPassManager::~OpPassManager() = default;
337
338OpPassManager::pass_iterator OpPassManager::begin() {
339 return MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.begin();
340}
341OpPassManager::pass_iterator OpPassManager::end() {
342 return MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.end();
343}
344
345OpPassManager::const_pass_iterator OpPassManager::begin() const {
346 return ArrayRef<std::unique_ptr<Pass>>{impl->passes}.begin();
347}
348OpPassManager::const_pass_iterator OpPassManager::end() const {
349 return ArrayRef<std::unique_ptr<Pass>>{impl->passes}.end();
350}
351
352/// Nest a new operation pass manager for the given operation kind under this
353/// pass manager.
354OpPassManager &OpPassManager::nest(OperationName nestedName) {
355 return impl->nest(nestedName);
356}
357OpPassManager &OpPassManager::nest(StringRef nestedName) {
358 return impl->nest(nestedName);
359}
360OpPassManager &OpPassManager::nestAny() { return impl->nestAny(); }
361
362/// Add the given pass to this pass manager. If this pass has a concrete
363/// operation type, it must be the same type as this pass manager.
364void OpPassManager::addPass(std::unique_ptr<Pass> pass) {
365 impl->addPass(pass: std::move(pass));
366}
367
368void OpPassManager::clear() { impl->clear(); }
369
370/// Returns the number of passes held by this manager.
371size_t OpPassManager::size() const { return impl->passes.size(); }
372
373/// Returns the internal implementation instance.
374OpPassManagerImpl &OpPassManager::getImpl() { return *impl; }
375
376/// Return the operation name that this pass manager operates on.
377std::optional<StringRef> OpPassManager::getOpName() const {
378 return impl->getOpName();
379}
380
381/// Return the operation name that this pass manager operates on.
382std::optional<OperationName>
383OpPassManager::getOpName(MLIRContext &context) const {
384 return impl->getOpName(context);
385}
386
387StringRef OpPassManager::getOpAnchorName() const {
388 return impl->getOpAnchorName();
389}
390
391/// Prints out the passes of the pass manager as the textual representation
392/// of pipelines.
393void printAsTextualPipeline(
394 raw_ostream &os, StringRef anchorName,
395 const llvm::iterator_range<OpPassManager::pass_iterator> &passes) {
396 os << anchorName << "(";
397 llvm::interleave(
398 c: passes, each_fn: [&](mlir::Pass &pass) { pass.printAsTextualPipeline(os); },
399 between_fn: [&]() { os << ","; });
400 os << ")";
401}
402void OpPassManager::printAsTextualPipeline(raw_ostream &os) const {
403 StringRef anchorName = getOpAnchorName();
404 ::printAsTextualPipeline(
405 os, anchorName,
406 passes: {MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.begin(),
407 MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.end()});
408}
409
410void OpPassManager::dump() {
411 llvm::errs() << "Pass Manager with " << impl->passes.size() << " passes:\n";
412 printAsTextualPipeline(os&: llvm::errs());
413 llvm::errs() << "\n";
414}
415
416static void registerDialectsForPipeline(const OpPassManager &pm,
417 DialectRegistry &dialects) {
418 for (const Pass &pass : pm.getPasses())
419 pass.getDependentDialects(registry&: dialects);
420}
421
422void OpPassManager::getDependentDialects(DialectRegistry &dialects) const {
423 registerDialectsForPipeline(pm: *this, dialects);
424}
425
426void OpPassManager::setNesting(Nesting nesting) { impl->nesting = nesting; }
427
428OpPassManager::Nesting OpPassManager::getNesting() { return impl->nesting; }
429
430LogicalResult OpPassManager::initialize(MLIRContext *context,
431 unsigned newInitGeneration) {
432 if (impl->initializationGeneration == newInitGeneration)
433 return success();
434 impl->initializationGeneration = newInitGeneration;
435 for (Pass &pass : getPasses()) {
436 // If this pass isn't an adaptor, directly initialize it.
437 auto *adaptor = dyn_cast<OpToOpPassAdaptor>(Val: &pass);
438 if (!adaptor) {
439 if (failed(result: pass.initialize(context)))
440 return failure();
441 continue;
442 }
443
444 // Otherwise, initialize each of the adaptors pass managers.
445 for (OpPassManager &adaptorPM : adaptor->getPassManagers())
446 if (failed(result: adaptorPM.initialize(context, newInitGeneration)))
447 return failure();
448 }
449 return success();
450}
451
452llvm::hash_code OpPassManager::hash() {
453 llvm::hash_code hashCode{};
454 for (Pass &pass : getPasses()) {
455 // If this pass isn't an adaptor, directly hash it.
456 auto *adaptor = dyn_cast<OpToOpPassAdaptor>(Val: &pass);
457 if (!adaptor) {
458 hashCode = llvm::hash_combine(args: hashCode, args: &pass);
459 continue;
460 }
461 // Otherwise, hash recursively each of the adaptors pass managers.
462 for (OpPassManager &adaptorPM : adaptor->getPassManagers())
463 llvm::hash_combine(args: hashCode, args: adaptorPM.hash());
464 }
465 return hashCode;
466}
467
468
469//===----------------------------------------------------------------------===//
470// OpToOpPassAdaptor
471//===----------------------------------------------------------------------===//
472
473LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
474 AnalysisManager am, bool verifyPasses,
475 unsigned parentInitGeneration) {
476 std::optional<RegisteredOperationName> opInfo = op->getRegisteredInfo();
477 if (!opInfo)
478 return op->emitOpError()
479 << "trying to schedule a pass on an unregistered operation";
480 if (!opInfo->hasTrait<OpTrait::IsIsolatedFromAbove>())
481 return op->emitOpError() << "trying to schedule a pass on an operation not "
482 "marked as 'IsolatedFromAbove'";
483 if (!pass->canScheduleOn(opName: *op->getName().getRegisteredInfo()))
484 return op->emitOpError()
485 << "trying to schedule a pass on an unsupported operation";
486
487 // Initialize the pass state with a callback for the pass to dynamically
488 // execute a pipeline on the currently visited operation.
489 PassInstrumentor *pi = am.getPassInstrumentor();
490 PassInstrumentation::PipelineParentInfo parentInfo = {.parentThreadID: llvm::get_threadid(),
491 .parentPass: pass};
492 auto dynamicPipelineCallback = [&](OpPassManager &pipeline,
493 Operation *root) -> LogicalResult {
494 if (!op->isAncestor(other: root))
495 return root->emitOpError()
496 << "Trying to schedule a dynamic pipeline on an "
497 "operation that isn't "
498 "nested under the current operation the pass is processing";
499 assert(
500 pipeline.getImpl().canScheduleOn(*op->getContext(), root->getName()));
501
502 // Before running, finalize the passes held by the pipeline.
503 if (failed(result: pipeline.getImpl().finalizePassList(ctx: root->getContext())))
504 return failure();
505
506 // Initialize the user provided pipeline and execute the pipeline.
507 if (failed(result: pipeline.initialize(context: root->getContext(), newInitGeneration: parentInitGeneration)))
508 return failure();
509 AnalysisManager nestedAm = root == op ? am : am.nest(op: root);
510 return OpToOpPassAdaptor::runPipeline(pm&: pipeline, op: root, am: nestedAm,
511 verifyPasses, parentInitGeneration,
512 instrumentor: pi, parentInfo: &parentInfo);
513 };
514 pass->passState.emplace(args&: op, args&: am, args&: dynamicPipelineCallback);
515
516 // Instrument before the pass has run.
517 if (pi)
518 pi->runBeforePass(pass, op);
519
520 bool passFailed = false;
521 op->getContext()->executeAction<PassExecutionAction>(
522 actionFn: [&]() {
523 // Invoke the virtual runOnOperation method.
524 if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(Val: pass))
525 adaptor->runOnOperation(verifyPasses);
526 else
527 pass->runOnOperation();
528 passFailed = pass->passState->irAndPassFailed.getInt();
529 },
530 irUnits: {op}, args&: *pass);
531
532 // Invalidate any non preserved analyses.
533 am.invalidate(pa: pass->passState->preservedAnalyses);
534
535 // When verifyPasses is specified, we run the verifier (unless the pass
536 // failed).
537 if (!passFailed && verifyPasses) {
538 bool runVerifierNow = true;
539
540 // If the pass is an adaptor pass, we don't run the verifier recursively
541 // because the nested operations should have already been verified after
542 // nested passes had run.
543 bool runVerifierRecursively = !isa<OpToOpPassAdaptor>(Val: pass);
544
545 // Reduce compile time by avoiding running the verifier if the pass didn't
546 // change the IR since the last time the verifier was run:
547 //
548 // 1) If the pass said that it preserved all analyses then it can't have
549 // permuted the IR.
550 //
551 // We run these checks in EXPENSIVE_CHECKS mode out of caution.
552#ifndef EXPENSIVE_CHECKS
553 runVerifierNow = !pass->passState->preservedAnalyses.isAll();
554#endif
555 if (runVerifierNow)
556 passFailed = failed(result: verify(op, verifyRecursively: runVerifierRecursively));
557 }
558
559 // Instrument after the pass has run.
560 if (pi) {
561 if (passFailed)
562 pi->runAfterPassFailed(pass, op);
563 else
564 pi->runAfterPass(pass, op);
565 }
566
567 // Return if the pass signaled a failure.
568 return failure(isFailure: passFailed);
569}
570
571/// Run the given operation and analysis manager on a provided op pass manager.
572LogicalResult OpToOpPassAdaptor::runPipeline(
573 OpPassManager &pm, Operation *op, AnalysisManager am, bool verifyPasses,
574 unsigned parentInitGeneration, PassInstrumentor *instrumentor,
575 const PassInstrumentation::PipelineParentInfo *parentInfo) {
576 assert((!instrumentor || parentInfo) &&
577 "expected parent info if instrumentor is provided");
578 auto scopeExit = llvm::make_scope_exit(F: [&] {
579 // Clear out any computed operation analyses. These analyses won't be used
580 // any more in this pipeline, and this helps reduce the current working set
581 // of memory. If preserving these analyses becomes important in the future
582 // we can re-evaluate this.
583 am.clear();
584 });
585
586 // Run the pipeline over the provided operation.
587 if (instrumentor) {
588 instrumentor->runBeforePipeline(name: pm.getOpName(context&: *op->getContext()),
589 parentInfo: *parentInfo);
590 }
591
592 for (Pass &pass : pm.getPasses())
593 if (failed(result: run(pass: &pass, op, am, verifyPasses, parentInitGeneration)))
594 return failure();
595
596 if (instrumentor) {
597 instrumentor->runAfterPipeline(name: pm.getOpName(context&: *op->getContext()),
598 parentInfo: *parentInfo);
599 }
600 return success();
601}
602
603/// Find an operation pass manager with the given anchor name, or nullptr if one
604/// does not exist.
605static OpPassManager *
606findPassManagerWithAnchor(MutableArrayRef<OpPassManager> mgrs, StringRef name) {
607 auto *it = llvm::find_if(
608 Range&: mgrs, P: [&](OpPassManager &mgr) { return mgr.getOpAnchorName() == name; });
609 return it == mgrs.end() ? nullptr : &*it;
610}
611
612/// Find an operation pass manager that can operate on an operation of the given
613/// type, or nullptr if one does not exist.
614static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
615 OperationName name,
616 MLIRContext &context) {
617 auto *it = llvm::find_if(Range&: mgrs, P: [&](OpPassManager &mgr) {
618 return mgr.getImpl().canScheduleOn(context, opName: name);
619 });
620 return it == mgrs.end() ? nullptr : &*it;
621}
622
623OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr) {
624 mgrs.emplace_back(Args: std::move(mgr));
625}
626
627void OpToOpPassAdaptor::getDependentDialects(DialectRegistry &dialects) const {
628 for (auto &pm : mgrs)
629 pm.getDependentDialects(dialects);
630}
631
632LogicalResult OpToOpPassAdaptor::tryMergeInto(MLIRContext *ctx,
633 OpToOpPassAdaptor &rhs) {
634 // Functor used to check if a pass manager is generic, i.e. op-agnostic.
635 auto isGenericPM = [&](OpPassManager &pm) { return !pm.getOpName(); };
636
637 // Functor used to detect if the given generic pass manager will have a
638 // potential schedule conflict with the given `otherPMs`.
639 auto hasScheduleConflictWith = [&](OpPassManager &genericPM,
640 MutableArrayRef<OpPassManager> otherPMs) {
641 return llvm::any_of(Range&: otherPMs, P: [&](OpPassManager &pm) {
642 // If this is a non-generic pass manager, a conflict will arise if a
643 // non-generic pass manager's operation name can be scheduled on the
644 // generic passmanager.
645 if (std::optional<OperationName> pmOpName = pm.getOpName(context&: *ctx))
646 return genericPM.getImpl().canScheduleOn(context&: *ctx, opName: *pmOpName);
647 // Otherwise, this is a generic pass manager. We current can't determine
648 // when generic pass managers can be merged, so conservatively assume they
649 // conflict.
650 return true;
651 });
652 };
653
654 // Check that if either adaptor has a generic pass manager, that pm is
655 // compatible within any non-generic pass managers.
656 //
657 // Check the current adaptor.
658 auto *lhsGenericPMIt = llvm::find_if(Range&: mgrs, P: isGenericPM);
659 if (lhsGenericPMIt != mgrs.end() &&
660 hasScheduleConflictWith(*lhsGenericPMIt, rhs.mgrs))
661 return failure();
662 // Check the rhs adaptor.
663 auto *rhsGenericPMIt = llvm::find_if(Range&: rhs.mgrs, P: isGenericPM);
664 if (rhsGenericPMIt != rhs.mgrs.end() &&
665 hasScheduleConflictWith(*rhsGenericPMIt, mgrs))
666 return failure();
667
668 for (auto &pm : mgrs) {
669 // If an existing pass manager exists, then merge the given pass manager
670 // into it.
671 if (auto *existingPM =
672 findPassManagerWithAnchor(mgrs: rhs.mgrs, name: pm.getOpAnchorName())) {
673 pm.getImpl().mergeInto(rhs&: existingPM->getImpl());
674 } else {
675 // Otherwise, add the given pass manager to the list.
676 rhs.mgrs.emplace_back(Args: std::move(pm));
677 }
678 }
679 mgrs.clear();
680
681 // After coalescing, sort the pass managers within rhs by name.
682 auto compareFn = [](const OpPassManager *lhs, const OpPassManager *rhs) {
683 // Order op-specific pass managers first and op-agnostic pass managers last.
684 if (std::optional<StringRef> lhsName = lhs->getOpName()) {
685 if (std::optional<StringRef> rhsName = rhs->getOpName())
686 return lhsName->compare(RHS: *rhsName);
687 return -1; // lhs(op-specific) < rhs(op-agnostic)
688 }
689 return 1; // lhs(op-agnostic) > rhs(op-specific)
690 };
691 llvm::array_pod_sort(Start: rhs.mgrs.begin(), End: rhs.mgrs.end(), Compare: compareFn);
692 return success();
693}
694
695/// Returns the adaptor pass name.
696std::string OpToOpPassAdaptor::getAdaptorName() {
697 std::string name = "Pipeline Collection : [";
698 llvm::raw_string_ostream os(name);
699 llvm::interleaveComma(c: getPassManagers(), os, each_fn: [&](OpPassManager &pm) {
700 os << '\'' << pm.getOpAnchorName() << '\'';
701 });
702 os << ']';
703 return os.str();
704}
705
706void OpToOpPassAdaptor::runOnOperation() {
707 llvm_unreachable(
708 "Unexpected call to Pass::runOnOperation() on OpToOpPassAdaptor");
709}
710
711/// Run the held pipeline over all nested operations.
712void OpToOpPassAdaptor::runOnOperation(bool verifyPasses) {
713 if (getContext().isMultithreadingEnabled())
714 runOnOperationAsyncImpl(verifyPasses);
715 else
716 runOnOperationImpl(verifyPasses);
717}
718
719/// Run this pass adaptor synchronously.
720void OpToOpPassAdaptor::runOnOperationImpl(bool verifyPasses) {
721 auto am = getAnalysisManager();
722 PassInstrumentation::PipelineParentInfo parentInfo = {.parentThreadID: llvm::get_threadid(),
723 .parentPass: this};
724 auto *instrumentor = am.getPassInstrumentor();
725 for (auto &region : getOperation()->getRegions()) {
726 for (auto &block : region) {
727 for (auto &op : block) {
728 auto *mgr = findPassManagerFor(mgrs, name: op.getName(), context&: *op.getContext());
729 if (!mgr)
730 continue;
731
732 // Run the held pipeline over the current operation.
733 unsigned initGeneration = mgr->impl->initializationGeneration;
734 if (failed(result: runPipeline(pm&: *mgr, op: &op, am: am.nest(op: &op), verifyPasses,
735 parentInitGeneration: initGeneration, instrumentor, parentInfo: &parentInfo)))
736 return signalPassFailure();
737 }
738 }
739 }
740}
741
742/// Utility functor that checks if the two ranges of pass managers have a size
743/// mismatch.
744static bool hasSizeMismatch(ArrayRef<OpPassManager> lhs,
745 ArrayRef<OpPassManager> rhs) {
746 return lhs.size() != rhs.size() ||
747 llvm::any_of(Range: llvm::seq<size_t>(Begin: 0, End: lhs.size()),
748 P: [&](size_t i) { return lhs[i].size() != rhs[i].size(); });
749}
750
751/// Run this pass adaptor synchronously.
752void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
753 AnalysisManager am = getAnalysisManager();
754 MLIRContext *context = &getContext();
755
756 // Create the async executors if they haven't been created, or if the main
757 // pipeline has changed.
758 if (asyncExecutors.empty() || hasSizeMismatch(lhs: asyncExecutors.front(), rhs: mgrs))
759 asyncExecutors.assign(NumElts: context->getThreadPool().getMaxConcurrency(), Elt: mgrs);
760
761 // This struct represents the information for a single operation to be
762 // scheduled on a pass manager.
763 struct OpPMInfo {
764 OpPMInfo(unsigned passManagerIdx, Operation *op, AnalysisManager am)
765 : passManagerIdx(passManagerIdx), op(op), am(am) {}
766
767 /// The index of the pass manager to schedule the operation on.
768 unsigned passManagerIdx;
769 /// The operation to schedule.
770 Operation *op;
771 /// The analysis manager for the operation.
772 AnalysisManager am;
773 };
774
775 // Run a prepass over the operation to collect the nested operations to
776 // execute over. This ensures that an analysis manager exists for each
777 // operation, as well as providing a queue of operations to execute over.
778 std::vector<OpPMInfo> opInfos;
779 DenseMap<OperationName, std::optional<unsigned>> knownOpPMIdx;
780 for (auto &region : getOperation()->getRegions()) {
781 for (Operation &op : region.getOps()) {
782 // Get the pass manager index for this operation type.
783 auto pmIdxIt = knownOpPMIdx.try_emplace(Key: op.getName(), Args: std::nullopt);
784 if (pmIdxIt.second) {
785 if (auto *mgr = findPassManagerFor(mgrs, name: op.getName(), context&: *context))
786 pmIdxIt.first->second = std::distance(first: mgrs.begin(), last: mgr);
787 }
788
789 // If this operation can be scheduled, add it to the list.
790 if (pmIdxIt.first->second)
791 opInfos.emplace_back(args&: *pmIdxIt.first->second, args: &op, args: am.nest(op: &op));
792 }
793 }
794
795 // Get the current thread for this adaptor.
796 PassInstrumentation::PipelineParentInfo parentInfo = {.parentThreadID: llvm::get_threadid(),
797 .parentPass: this};
798 auto *instrumentor = am.getPassInstrumentor();
799
800 // An atomic failure variable for the async executors.
801 std::vector<std::atomic<bool>> activePMs(asyncExecutors.size());
802 std::fill(first: activePMs.begin(), last: activePMs.end(), value: false);
803 auto processFn = [&](OpPMInfo &opInfo) {
804 // Find an executor for this operation.
805 auto it = llvm::find_if(Range&: activePMs, P: [](std::atomic<bool> &isActive) {
806 bool expectedInactive = false;
807 return isActive.compare_exchange_strong(i1&: expectedInactive, i2: true);
808 });
809 unsigned pmIndex = it - activePMs.begin();
810
811 // Get the pass manager for this operation and execute it.
812 OpPassManager &pm = asyncExecutors[pmIndex][opInfo.passManagerIdx];
813 LogicalResult pipelineResult = runPipeline(
814 pm, op: opInfo.op, am: opInfo.am, verifyPasses,
815 parentInitGeneration: pm.impl->initializationGeneration, instrumentor, parentInfo: &parentInfo);
816
817 // Reset the active bit for this pass manager.
818 activePMs[pmIndex].store(i: false);
819 return pipelineResult;
820 };
821
822 // Signal a failure if any of the executors failed.
823 if (failed(result: failableParallelForEach(context, range&: opInfos, func&: processFn)))
824 signalPassFailure();
825}
826
827//===----------------------------------------------------------------------===//
828// PassManager
829//===----------------------------------------------------------------------===//
830
831PassManager::PassManager(MLIRContext *ctx, StringRef operationName,
832 Nesting nesting)
833 : OpPassManager(operationName, nesting), context(ctx), passTiming(false),
834 verifyPasses(true) {}
835
836PassManager::PassManager(OperationName operationName, Nesting nesting)
837 : OpPassManager(operationName, nesting),
838 context(operationName.getContext()), passTiming(false),
839 verifyPasses(true) {}
840
841PassManager::~PassManager() = default;
842
843void PassManager::enableVerifier(bool enabled) { verifyPasses = enabled; }
844
845/// Run the passes within this manager on the provided operation.
846LogicalResult PassManager::run(Operation *op) {
847 MLIRContext *context = getContext();
848 std::optional<OperationName> anchorOp = getOpName(context&: *context);
849 if (anchorOp && anchorOp != op->getName())
850 return emitError(loc: op->getLoc())
851 << "can't run '" << getOpAnchorName() << "' pass manager on '"
852 << op->getName() << "' op";
853
854 // Register all dialects for the current pipeline.
855 DialectRegistry dependentDialects;
856 getDependentDialects(dialects&: dependentDialects);
857 context->appendDialectRegistry(registry: dependentDialects);
858 for (StringRef name : dependentDialects.getDialectNames())
859 context->getOrLoadDialect(name);
860
861 // Before running, make sure to finalize the pipeline pass list.
862 if (failed(result: getImpl().finalizePassList(ctx: context)))
863 return failure();
864
865 // Notify the context that we start running a pipeline for bookkeeping.
866 context->enterMultiThreadedExecution();
867
868 // Initialize all of the passes within the pass manager with a new generation.
869 llvm::hash_code newInitKey = context->getRegistryHash();
870 llvm::hash_code pipelineKey = hash();
871 if (newInitKey != initializationKey || pipelineKey != pipelineInitializationKey) {
872 if (failed(result: initialize(context, newInitGeneration: impl->initializationGeneration + 1)))
873 return failure();
874 initializationKey = newInitKey;
875 pipelineKey = pipelineInitializationKey;
876 }
877
878 // Construct a top level analysis manager for the pipeline.
879 ModuleAnalysisManager am(op, instrumentor.get());
880
881 // If reproducer generation is enabled, run the pass manager with crash
882 // handling enabled.
883 LogicalResult result =
884 crashReproGenerator ? runWithCrashRecovery(op, am) : runPasses(op, am);
885
886 // Notify the context that the run is done.
887 context->exitMultiThreadedExecution();
888
889 // Dump all of the pass statistics if necessary.
890 if (passStatisticsMode)
891 dumpStatistics();
892 return result;
893}
894
895/// Add the provided instrumentation to the pass manager.
896void PassManager::addInstrumentation(std::unique_ptr<PassInstrumentation> pi) {
897 if (!instrumentor)
898 instrumentor = std::make_unique<PassInstrumentor>();
899
900 instrumentor->addInstrumentation(pi: std::move(pi));
901}
902
903LogicalResult PassManager::runPasses(Operation *op, AnalysisManager am) {
904 return OpToOpPassAdaptor::runPipeline(pm&: *this, op, am, verifyPasses,
905 parentInitGeneration: impl->initializationGeneration);
906}
907
908//===----------------------------------------------------------------------===//
909// AnalysisManager
910//===----------------------------------------------------------------------===//
911
912/// Get an analysis manager for the given operation, which must be a proper
913/// descendant of the current operation represented by this analysis manager.
914AnalysisManager AnalysisManager::nest(Operation *op) {
915 Operation *currentOp = impl->getOperation();
916 assert(currentOp->isProperAncestor(op) &&
917 "expected valid descendant operation");
918
919 // Check for the base case where the provided operation is immediately nested.
920 if (currentOp == op->getParentOp())
921 return nestImmediate(op);
922
923 // Otherwise, we need to collect all ancestors up to the current operation.
924 SmallVector<Operation *, 4> opAncestors;
925 do {
926 opAncestors.push_back(Elt: op);
927 op = op->getParentOp();
928 } while (op != currentOp);
929
930 AnalysisManager result = *this;
931 for (Operation *op : llvm::reverse(C&: opAncestors))
932 result = result.nestImmediate(op);
933 return result;
934}
935
936/// Get an analysis manager for the given immediately nested child operation.
937AnalysisManager AnalysisManager::nestImmediate(Operation *op) {
938 assert(impl->getOperation() == op->getParentOp() &&
939 "expected immediate child operation");
940
941 auto it = impl->childAnalyses.find(Val: op);
942 if (it == impl->childAnalyses.end())
943 it = impl->childAnalyses
944 .try_emplace(Key: op, Args: std::make_unique<NestedAnalysisMap>(args&: op, args&: impl))
945 .first;
946 return {it->second.get()};
947}
948
949/// Invalidate any non preserved analyses.
950void detail::NestedAnalysisMap::invalidate(
951 const detail::PreservedAnalyses &pa) {
952 // If all analyses were preserved, then there is nothing to do here.
953 if (pa.isAll())
954 return;
955
956 // Invalidate the analyses for the current operation directly.
957 analyses.invalidate(pa);
958
959 // If no analyses were preserved, then just simply clear out the child
960 // analysis results.
961 if (pa.isNone()) {
962 childAnalyses.clear();
963 return;
964 }
965
966 // Otherwise, invalidate each child analysis map.
967 SmallVector<NestedAnalysisMap *, 8> mapsToInvalidate(1, this);
968 while (!mapsToInvalidate.empty()) {
969 auto *map = mapsToInvalidate.pop_back_val();
970 for (auto &analysisPair : map->childAnalyses) {
971 analysisPair.second->invalidate(pa);
972 if (!analysisPair.second->childAnalyses.empty())
973 mapsToInvalidate.push_back(Elt: analysisPair.second.get());
974 }
975 }
976}
977
978//===----------------------------------------------------------------------===//
979// PassInstrumentation
980//===----------------------------------------------------------------------===//
981
982PassInstrumentation::~PassInstrumentation() = default;
983
984void PassInstrumentation::runBeforePipeline(
985 std::optional<OperationName> name, const PipelineParentInfo &parentInfo) {}
986
987void PassInstrumentation::runAfterPipeline(
988 std::optional<OperationName> name, const PipelineParentInfo &parentInfo) {}
989
990//===----------------------------------------------------------------------===//
991// PassInstrumentor
992//===----------------------------------------------------------------------===//
993
994namespace mlir {
995namespace detail {
996struct PassInstrumentorImpl {
997 /// Mutex to keep instrumentation access thread-safe.
998 llvm::sys::SmartMutex<true> mutex;
999
1000 /// Set of registered instrumentations.
1001 std::vector<std::unique_ptr<PassInstrumentation>> instrumentations;
1002};
1003} // namespace detail
1004} // namespace mlir
1005
1006PassInstrumentor::PassInstrumentor() : impl(new PassInstrumentorImpl()) {}
1007PassInstrumentor::~PassInstrumentor() = default;
1008
1009/// See PassInstrumentation::runBeforePipeline for details.
1010void PassInstrumentor::runBeforePipeline(
1011 std::optional<OperationName> name,
1012 const PassInstrumentation::PipelineParentInfo &parentInfo) {
1013 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1014 for (auto &instr : impl->instrumentations)
1015 instr->runBeforePipeline(name, parentInfo);
1016}
1017
1018/// See PassInstrumentation::runAfterPipeline for details.
1019void PassInstrumentor::runAfterPipeline(
1020 std::optional<OperationName> name,
1021 const PassInstrumentation::PipelineParentInfo &parentInfo) {
1022 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1023 for (auto &instr : llvm::reverse(C&: impl->instrumentations))
1024 instr->runAfterPipeline(name, parentInfo);
1025}
1026
1027/// See PassInstrumentation::runBeforePass for details.
1028void PassInstrumentor::runBeforePass(Pass *pass, Operation *op) {
1029 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1030 for (auto &instr : impl->instrumentations)
1031 instr->runBeforePass(pass, op);
1032}
1033
1034/// See PassInstrumentation::runAfterPass for details.
1035void PassInstrumentor::runAfterPass(Pass *pass, Operation *op) {
1036 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1037 for (auto &instr : llvm::reverse(C&: impl->instrumentations))
1038 instr->runAfterPass(pass, op);
1039}
1040
1041/// See PassInstrumentation::runAfterPassFailed for details.
1042void PassInstrumentor::runAfterPassFailed(Pass *pass, Operation *op) {
1043 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1044 for (auto &instr : llvm::reverse(C&: impl->instrumentations))
1045 instr->runAfterPassFailed(pass, op);
1046}
1047
1048/// See PassInstrumentation::runBeforeAnalysis for details.
1049void PassInstrumentor::runBeforeAnalysis(StringRef name, TypeID id,
1050 Operation *op) {
1051 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1052 for (auto &instr : impl->instrumentations)
1053 instr->runBeforeAnalysis(name, id, op);
1054}
1055
1056/// See PassInstrumentation::runAfterAnalysis for details.
1057void PassInstrumentor::runAfterAnalysis(StringRef name, TypeID id,
1058 Operation *op) {
1059 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1060 for (auto &instr : llvm::reverse(C&: impl->instrumentations))
1061 instr->runAfterAnalysis(name, id, op);
1062}
1063
1064/// Add the given instrumentation to the collection.
1065void PassInstrumentor::addInstrumentation(
1066 std::unique_ptr<PassInstrumentation> pi) {
1067 llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
1068 impl->instrumentations.emplace_back(args: std::move(pi));
1069}
1070

source code of mlir/lib/Pass/Pass.cpp