1//===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implements mlir::walkAndApplyPatterns.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Transforms/WalkPatternRewriteDriver.h"
14
15#include "mlir/IR/MLIRContext.h"
16#include "mlir/IR/OperationSupport.h"
17#include "mlir/IR/PatternMatch.h"
18#include "mlir/IR/Verifier.h"
19#include "mlir/IR/Visitors.h"
20#include "mlir/Rewrite/PatternApplicator.h"
21#include "llvm/Support/Debug.h"
22#include "llvm/Support/ErrorHandling.h"
23
24#define DEBUG_TYPE "walk-rewriter"
25
26namespace mlir {
27
28namespace {
29struct WalkAndApplyPatternsAction final
30 : tracing::ActionImpl<WalkAndApplyPatternsAction> {
31 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WalkAndApplyPatternsAction)
32 using ActionImpl::ActionImpl;
33 static constexpr StringLiteral tag = "walk-and-apply-patterns";
34 void print(raw_ostream &os) const override { os << tag; }
35};
36
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Listener placed in front of the caller-provided listener that aborts when a
// pattern erases IR outside of the operation currently being visited. Because
// patterns are applied during a walk, erasing an op/block that the walk has yet
// to reach (e.g., a user of the visited op) would be invalid; only the visited
// op and the IR nested inside it may be erased. Compiled only when expensive
// pattern API checks are enabled.
struct ErasedOpsListener final : RewriterBase::ForwardingListener {
  using RewriterBase::ForwardingListener::ForwardingListener;

  /// Validate the erasure of `op`, then forward the notification unchanged.
  void notifyOperationErased(Operation *op) override {
    checkErasure(op);
    ForwardingListener::notifyOperationErased(op);
  }

  /// A block erasure is validated against the op that owns the block.
  void notifyBlockErased(Block *block) override {
    Operation *owner = block->getParentOp();
    checkErasure(owner);
    ForwardingListener::notifyBlockErased(block);
  }

  /// Climb the parent chain starting at `op` (inclusive). The erasure is
  /// accepted once `visitedOp` is reached; running off the top of the chain
  /// (a null op) is fatal. If `visitedOp` is null, every erasure is accepted.
  void checkErasure(Operation *op) const {
    for (Operation *current = op; current != visitedOp;
         current = current->getParentOp()) {
      // report_fatal_error does not return, so `current` is never
      // dereferenced while null.
      if (!current)
        llvm::report_fatal_error(
            "unsupported erasure in WalkPatternRewriter; "
            "erasure is only supported for matched ops and their descendants");
    }
  }

  /// The op the walk is currently visiting; the driver updates it before each
  /// match attempt.
  Operation *visitedOp = nullptr;
};
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
69} // namespace
70
71void walkAndApplyPatterns(Operation *op,
72 const FrozenRewritePatternSet &patterns,
73 RewriterBase::Listener *listener) {
74#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
75 if (failed(verify(op)))
76 llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");
77#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
78
79 MLIRContext *ctx = op->getContext();
80 PatternRewriter rewriter(ctx);
81#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
82 ErasedOpsListener erasedListener(listener);
83 rewriter.setListener(&erasedListener);
84#else
85 rewriter.setListener(listener);
86#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
87
88 PatternApplicator applicator(patterns);
89 applicator.applyDefaultCostModel();
90
91 ctx->executeAction<WalkAndApplyPatternsAction>(
92 actionFn: [&] {
93 for (Region &region : op->getRegions()) {
94 region.walk(callback: [&](Operation *visitedOp) {
95 LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
96 llvm::dbgs(), OpPrintingFlags().skipRegions());
97 llvm::dbgs() << "\n";);
98#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
99 erasedListener.visitedOp = visitedOp;
100#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
101 if (succeeded(Result: applicator.matchAndRewrite(op: visitedOp, rewriter))) {
102 LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
103 }
104 });
105 }
106 },
107 irUnits: {op});
108
109#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
110 if (failed(verify(op)))
111 llvm::report_fatal_error(
112 "walk pattern rewriter result IR failed to verify");
113#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
114}
115
116} // namespace mlir
117

source code of mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp