1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements a pass that canonicalizes CIR operations, eliminating
// redundant branches, empty scopes, and other unnecessary operations.
11//
12//===----------------------------------------------------------------------===//
13
14#include "PassDetail.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/IR/Block.h"
17#include "mlir/IR/Operation.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/IR/Region.h"
20#include "mlir/Support/LogicalResult.h"
21#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22#include "clang/CIR/Dialect/IR/CIRDialect.h"
23#include "clang/CIR/Dialect/Passes.h"
24#include "clang/CIR/MissingFeatures.h"
25
26using namespace mlir;
27using namespace cir;
28
29namespace {
30
31/// Removes branches between two blocks if it is the only branch.
32///
33/// From:
34/// ^bb0:
35/// cir.br ^bb1
36/// ^bb1: // pred: ^bb0
37/// cir.return
38///
39/// To:
40/// ^bb0:
41/// cir.return
42struct RemoveRedundantBranches : public OpRewritePattern<BrOp> {
43 using OpRewritePattern<BrOp>::OpRewritePattern;
44
45 LogicalResult matchAndRewrite(BrOp op,
46 PatternRewriter &rewriter) const final {
47 Block *block = op.getOperation()->getBlock();
48 Block *dest = op.getDest();
49
50 assert(!cir::MissingFeatures::labelOp());
51
52 // Single edge between blocks: merge it.
53 if (block->getNumSuccessors() == 1 &&
54 dest->getSinglePredecessor() == block) {
55 rewriter.eraseOp(op: op);
56 rewriter.mergeBlocks(source: dest, dest: block);
57 return success();
58 }
59
60 return failure();
61 }
62};
63
64struct RemoveEmptyScope : public OpRewritePattern<ScopeOp> {
65 using OpRewritePattern<ScopeOp>::OpRewritePattern;
66
67 LogicalResult matchAndRewrite(ScopeOp op,
68 PatternRewriter &rewriter) const final {
69 // TODO: Remove this logic once CIR uses MLIR infrastructure to remove
70 // trivially dead operations
71 if (op.isEmpty()) {
72 rewriter.eraseOp(op: op);
73 return success();
74 }
75
76 Region &region = op.getScopeRegion();
77 if (region.getBlocks().front().getOperations().size() == 1 &&
78 isa<YieldOp>(region.getBlocks().front().front())) {
79 rewriter.eraseOp(op: op);
80 return success();
81 }
82
83 return failure();
84 }
85};
86
87struct RemoveEmptySwitch : public OpRewritePattern<SwitchOp> {
88 using OpRewritePattern<SwitchOp>::OpRewritePattern;
89
90 LogicalResult matchAndRewrite(SwitchOp op,
91 PatternRewriter &rewriter) const final {
92 if (!(op.getBody().empty() || isa<YieldOp>(op.getBody().front().front())))
93 return failure();
94
95 rewriter.eraseOp(op: op);
96 return success();
97 }
98};
99
100//===----------------------------------------------------------------------===//
101// CIRCanonicalizePass
102//===----------------------------------------------------------------------===//
103
104struct CIRCanonicalizePass : public CIRCanonicalizeBase<CIRCanonicalizePass> {
105 using CIRCanonicalizeBase::CIRCanonicalizeBase;
106
107 // The same operation rewriting done here could have been performed
108 // by CanonicalizerPass (adding hasCanonicalizer for target Ops and
109 // implementing the same from above in CIRDialects.cpp). However, it's
110 // currently too aggressive for static analysis purposes, since it might
111 // remove things where a diagnostic can be generated.
112 //
113 // FIXME: perhaps we can add one more mode to GreedyRewriteConfig to
114 // disable this behavior.
115 void runOnOperation() override;
116};
117
118void populateCIRCanonicalizePatterns(RewritePatternSet &patterns) {
119 // clang-format off
120 patterns.add<
121 RemoveRedundantBranches,
122 RemoveEmptyScope
123 >(arg: patterns.getContext());
124 // clang-format on
125}
126
127void CIRCanonicalizePass::runOnOperation() {
128 // Collect rewrite patterns.
129 RewritePatternSet patterns(&getContext());
130 populateCIRCanonicalizePatterns(patterns);
131
132 // Collect operations to apply patterns.
133 llvm::SmallVector<Operation *, 16> ops;
134 getOperation()->walk([&](Operation *op) {
135 assert(!cir::MissingFeatures::switchOp());
136 assert(!cir::MissingFeatures::tryOp());
137 assert(!cir::MissingFeatures::complexCreateOp());
138 assert(!cir::MissingFeatures::complexRealOp());
139 assert(!cir::MissingFeatures::complexImagOp());
140 assert(!cir::MissingFeatures::callOp());
141
142 // Many operations are here to perform a manual `fold` in
143 // applyOpPatternsGreedily.
144 if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp,
145 VecExtractOp, VecShuffleOp, VecShuffleDynamicOp, VecTernaryOp>(op))
146 ops.push_back(Elt: op);
147 });
148
149 // Apply patterns.
150 if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
151 signalPassFailure();
152}
153
154} // namespace
155
156std::unique_ptr<Pass> mlir::createCIRCanonicalizePass() {
157 return std::make_unique<CIRCanonicalizePass>();
158}
159

source code of clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp