1 | //===----------------------------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements a pass that canonicalizes CIR operations, eliminating |
10 | // redundant branches, empty scopes, and other unnecessary operations. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "PassDetail.h" |
15 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
16 | #include "mlir/IR/Block.h" |
17 | #include "mlir/IR/Operation.h" |
18 | #include "mlir/IR/PatternMatch.h" |
19 | #include "mlir/IR/Region.h" |
20 | #include "mlir/Support/LogicalResult.h" |
21 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
22 | #include "clang/CIR/Dialect/IR/CIRDialect.h" |
23 | #include "clang/CIR/Dialect/Passes.h" |
24 | #include "clang/CIR/MissingFeatures.h" |
25 | |
26 | using namespace mlir; |
27 | using namespace cir; |
28 | |
29 | namespace { |
30 | |
31 | /// Removes branches between two blocks if it is the only branch. |
32 | /// |
33 | /// From: |
34 | /// ^bb0: |
35 | /// cir.br ^bb1 |
36 | /// ^bb1: // pred: ^bb0 |
37 | /// cir.return |
38 | /// |
39 | /// To: |
40 | /// ^bb0: |
41 | /// cir.return |
42 | struct RemoveRedundantBranches : public OpRewritePattern<BrOp> { |
43 | using OpRewritePattern<BrOp>::OpRewritePattern; |
44 | |
45 | LogicalResult matchAndRewrite(BrOp op, |
46 | PatternRewriter &rewriter) const final { |
47 | Block *block = op.getOperation()->getBlock(); |
48 | Block *dest = op.getDest(); |
49 | |
50 | assert(!cir::MissingFeatures::labelOp()); |
51 | |
52 | // Single edge between blocks: merge it. |
53 | if (block->getNumSuccessors() == 1 && |
54 | dest->getSinglePredecessor() == block) { |
55 | rewriter.eraseOp(op: op); |
56 | rewriter.mergeBlocks(source: dest, dest: block); |
57 | return success(); |
58 | } |
59 | |
60 | return failure(); |
61 | } |
62 | }; |
63 | |
64 | struct RemoveEmptyScope : public OpRewritePattern<ScopeOp> { |
65 | using OpRewritePattern<ScopeOp>::OpRewritePattern; |
66 | |
67 | LogicalResult matchAndRewrite(ScopeOp op, |
68 | PatternRewriter &rewriter) const final { |
69 | // TODO: Remove this logic once CIR uses MLIR infrastructure to remove |
70 | // trivially dead operations |
71 | if (op.isEmpty()) { |
72 | rewriter.eraseOp(op: op); |
73 | return success(); |
74 | } |
75 | |
76 | Region ®ion = op.getScopeRegion(); |
77 | if (region.getBlocks().front().getOperations().size() == 1 && |
78 | isa<YieldOp>(region.getBlocks().front().front())) { |
79 | rewriter.eraseOp(op: op); |
80 | return success(); |
81 | } |
82 | |
83 | return failure(); |
84 | } |
85 | }; |
86 | |
87 | struct RemoveEmptySwitch : public OpRewritePattern<SwitchOp> { |
88 | using OpRewritePattern<SwitchOp>::OpRewritePattern; |
89 | |
90 | LogicalResult matchAndRewrite(SwitchOp op, |
91 | PatternRewriter &rewriter) const final { |
92 | if (!(op.getBody().empty() || isa<YieldOp>(op.getBody().front().front()))) |
93 | return failure(); |
94 | |
95 | rewriter.eraseOp(op: op); |
96 | return success(); |
97 | } |
98 | }; |
99 | |
100 | //===----------------------------------------------------------------------===// |
101 | // CIRCanonicalizePass |
102 | //===----------------------------------------------------------------------===// |
103 | |
104 | struct CIRCanonicalizePass : public CIRCanonicalizeBase<CIRCanonicalizePass> { |
105 | using CIRCanonicalizeBase::CIRCanonicalizeBase; |
106 | |
107 | // The same operation rewriting done here could have been performed |
108 | // by CanonicalizerPass (adding hasCanonicalizer for target Ops and |
109 | // implementing the same from above in CIRDialects.cpp). However, it's |
110 | // currently too aggressive for static analysis purposes, since it might |
111 | // remove things where a diagnostic can be generated. |
112 | // |
113 | // FIXME: perhaps we can add one more mode to GreedyRewriteConfig to |
114 | // disable this behavior. |
115 | void runOnOperation() override; |
116 | }; |
117 | |
118 | void populateCIRCanonicalizePatterns(RewritePatternSet &patterns) { |
119 | // clang-format off |
120 | patterns.add< |
121 | RemoveRedundantBranches, |
122 | RemoveEmptyScope |
123 | >(arg: patterns.getContext()); |
124 | // clang-format on |
125 | } |
126 | |
127 | void CIRCanonicalizePass::runOnOperation() { |
128 | // Collect rewrite patterns. |
129 | RewritePatternSet patterns(&getContext()); |
130 | populateCIRCanonicalizePatterns(patterns); |
131 | |
132 | // Collect operations to apply patterns. |
133 | llvm::SmallVector<Operation *, 16> ops; |
134 | getOperation()->walk([&](Operation *op) { |
135 | assert(!cir::MissingFeatures::switchOp()); |
136 | assert(!cir::MissingFeatures::tryOp()); |
137 | assert(!cir::MissingFeatures::complexCreateOp()); |
138 | assert(!cir::MissingFeatures::complexRealOp()); |
139 | assert(!cir::MissingFeatures::complexImagOp()); |
140 | assert(!cir::MissingFeatures::callOp()); |
141 | |
142 | // Many operations are here to perform a manual `fold` in |
143 | // applyOpPatternsGreedily. |
144 | if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp, |
145 | VecExtractOp, VecShuffleOp, VecShuffleDynamicOp, VecTernaryOp>(op)) |
146 | ops.push_back(Elt: op); |
147 | }); |
148 | |
149 | // Apply patterns. |
150 | if (applyOpPatternsGreedily(ops, std::move(patterns)).failed()) |
151 | signalPassFailure(); |
152 | } |
153 | |
154 | } // namespace |
155 | |
156 | std::unique_ptr<Pass> mlir::createCIRCanonicalizePass() { |
157 | return std::make_unique<CIRCanonicalizePass>(); |
158 | } |
159 | |