1 | //===----------------------------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "PassDetail.h" |
10 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
11 | #include "mlir/IR/Block.h" |
12 | #include "mlir/IR/Operation.h" |
13 | #include "mlir/IR/PatternMatch.h" |
14 | #include "mlir/IR/Region.h" |
15 | #include "mlir/Support/LogicalResult.h" |
16 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
17 | #include "clang/CIR/Dialect/IR/CIRDialect.h" |
18 | #include "clang/CIR/Dialect/Passes.h" |
19 | #include "llvm/ADT/SmallVector.h" |
20 | |
21 | using namespace mlir; |
22 | using namespace cir; |
23 | |
24 | //===----------------------------------------------------------------------===// |
25 | // Rewrite patterns |
26 | //===----------------------------------------------------------------------===// |
27 | |
28 | namespace { |
29 | |
30 | /// Simplify suitable ternary operations into select operations. |
31 | /// |
32 | /// For now we only simplify those ternary operations whose true and false |
33 | /// branches directly yield a value or a constant. That is, both of the true and |
34 | /// the false branch must either contain a cir.yield operation as the only |
35 | /// operation in the branch, or contain a cir.const operation followed by a |
36 | /// cir.yield operation that yields the constant value. |
37 | /// |
38 | /// For example, we will simplify the following ternary operation: |
39 | /// |
40 | /// %0 = ... |
41 | /// %1 = cir.ternary (%condition, true { |
42 | /// %2 = cir.const ... |
43 | /// cir.yield %2 |
44 | /// } false { |
45 | /// cir.yield %0 |
46 | /// |
47 | /// into the following sequence of operations: |
48 | /// |
49 | /// %1 = cir.const ... |
50 | /// %0 = cir.select if %condition then %1 else %2 |
51 | struct SimplifyTernary final : public OpRewritePattern<TernaryOp> { |
52 | using OpRewritePattern<TernaryOp>::OpRewritePattern; |
53 | |
54 | LogicalResult matchAndRewrite(TernaryOp op, |
55 | PatternRewriter &rewriter) const override { |
56 | if (op->getNumResults() != 1) |
57 | return mlir::failure(); |
58 | |
59 | if (!isSimpleTernaryBranch(region&: op.getTrueRegion()) || |
60 | !isSimpleTernaryBranch(region&: op.getFalseRegion())) |
61 | return mlir::failure(); |
62 | |
63 | cir::YieldOp trueBranchYieldOp = |
64 | mlir::cast<cir::YieldOp>(op.getTrueRegion().front().getTerminator()); |
65 | cir::YieldOp falseBranchYieldOp = |
66 | mlir::cast<cir::YieldOp>(op.getFalseRegion().front().getTerminator()); |
67 | mlir::Value trueValue = trueBranchYieldOp.getArgs()[0]; |
68 | mlir::Value falseValue = falseBranchYieldOp.getArgs()[0]; |
69 | |
70 | rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op); |
71 | rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op); |
72 | rewriter.eraseOp(op: trueBranchYieldOp); |
73 | rewriter.eraseOp(op: falseBranchYieldOp); |
74 | rewriter.replaceOpWithNewOp<cir::SelectOp>(op, op.getCond(), trueValue, |
75 | falseValue); |
76 | |
77 | return mlir::success(); |
78 | } |
79 | |
80 | private: |
81 | bool isSimpleTernaryBranch(mlir::Region ®ion) const { |
82 | if (!region.hasOneBlock()) |
83 | return false; |
84 | |
85 | mlir::Block &onlyBlock = region.front(); |
86 | mlir::Block::OpListType &ops = onlyBlock.getOperations(); |
87 | |
88 | // The region/block could only contain at most 2 operations. |
89 | if (ops.size() > 2) |
90 | return false; |
91 | |
92 | if (ops.size() == 1) { |
93 | // The region/block only contain a cir.yield operation. |
94 | return true; |
95 | } |
96 | |
97 | // Check whether the region/block contains a cir.const followed by a |
98 | // cir.yield that yields the value. |
99 | auto yieldOp = mlir::cast<cir::YieldOp>(onlyBlock.getTerminator()); |
100 | auto yieldValueDefOp = mlir::dyn_cast_if_present<cir::ConstantOp>( |
101 | yieldOp.getArgs()[0].getDefiningOp()); |
102 | return yieldValueDefOp && yieldValueDefOp->getBlock() == &onlyBlock; |
103 | } |
104 | }; |
105 | |
106 | /// Simplify select operations with boolean constants into simpler forms. |
107 | /// |
108 | /// This pattern simplifies select operations where both true and false values |
109 | /// are boolean constants. Two specific cases are handled: |
110 | /// |
111 | /// 1. When selecting between true and false based on a condition, |
112 | /// the operation simplifies to just the condition itself: |
113 | /// |
114 | /// %0 = cir.select if %condition then true else false |
115 | /// -> |
116 | /// (replaced with %condition directly) |
117 | /// |
118 | /// 2. When selecting between false and true based on a condition, |
119 | /// the operation simplifies to the logical negation of the condition: |
120 | /// |
121 | /// %0 = cir.select if %condition then false else true |
122 | /// -> |
123 | /// %0 = cir.unary not %condition |
124 | struct SimplifySelect : public OpRewritePattern<SelectOp> { |
125 | using OpRewritePattern<SelectOp>::OpRewritePattern; |
126 | |
127 | LogicalResult matchAndRewrite(SelectOp op, |
128 | PatternRewriter &rewriter) const final { |
129 | mlir::Operation *trueValueOp = op.getTrueValue().getDefiningOp(); |
130 | mlir::Operation *falseValueOp = op.getFalseValue().getDefiningOp(); |
131 | auto trueValueConstOp = |
132 | mlir::dyn_cast_if_present<cir::ConstantOp>(trueValueOp); |
133 | auto falseValueConstOp = |
134 | mlir::dyn_cast_if_present<cir::ConstantOp>(falseValueOp); |
135 | if (!trueValueConstOp || !falseValueConstOp) |
136 | return mlir::failure(); |
137 | |
138 | auto trueValue = mlir::dyn_cast<cir::BoolAttr>(trueValueConstOp.getValue()); |
139 | auto falseValue = |
140 | mlir::dyn_cast<cir::BoolAttr>(falseValueConstOp.getValue()); |
141 | if (!trueValue || !falseValue) |
142 | return mlir::failure(); |
143 | |
144 | // cir.select if %0 then #true else #false -> %0 |
145 | if (trueValue.getValue() && !falseValue.getValue()) { |
146 | rewriter.replaceAllUsesWith(op, op.getCondition()); |
147 | rewriter.eraseOp(op: op); |
148 | return mlir::success(); |
149 | } |
150 | |
151 | // cir.select if %0 then #false else #true -> cir.unary not %0 |
152 | if (!trueValue.getValue() && falseValue.getValue()) { |
153 | rewriter.replaceOpWithNewOp<cir::UnaryOp>(op, cir::UnaryOpKind::Not, |
154 | op.getCondition()); |
155 | return mlir::success(); |
156 | } |
157 | |
158 | return mlir::failure(); |
159 | } |
160 | }; |
161 | |
162 | /// Simplify `cir.switch` operations by folding cascading cases |
163 | /// into a single `cir.case` with the `anyof` kind. |
164 | /// |
165 | /// This pattern identifies cascading cases within a `cir.switch` operation. |
166 | /// Cascading cases are defined as consecutive `cir.case` operations of kind |
167 | /// `equal`, each containing a single `cir.yield` operation in their body. |
168 | /// |
169 | /// The pattern merges these cascading cases into a single `cir.case` operation |
170 | /// with kind `anyof`, aggregating all the case values. |
171 | /// |
172 | /// The merging process continues until a `cir.case` with a different body |
173 | /// (e.g., containing `cir.break` or compound stmt) is encountered, which |
174 | /// breaks the chain. |
175 | /// |
176 | /// Example: |
177 | /// |
178 | /// Before: |
179 | /// cir.case equal, [#cir.int<0> : !s32i] { |
180 | /// cir.yield |
181 | /// } |
182 | /// cir.case equal, [#cir.int<1> : !s32i] { |
183 | /// cir.yield |
184 | /// } |
185 | /// cir.case equal, [#cir.int<2> : !s32i] { |
186 | /// cir.break |
187 | /// } |
188 | /// |
189 | /// After applying SimplifySwitch: |
190 | /// cir.case anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : |
191 | /// !s32i] { |
192 | /// cir.break |
193 | /// } |
194 | struct SimplifySwitch : public OpRewritePattern<SwitchOp> { |
195 | using OpRewritePattern<SwitchOp>::OpRewritePattern; |
196 | LogicalResult matchAndRewrite(SwitchOp op, |
197 | PatternRewriter &rewriter) const override { |
198 | |
199 | LogicalResult changed = mlir::failure(); |
200 | SmallVector<CaseOp, 8> cases; |
201 | SmallVector<CaseOp, 4> cascadingCases; |
202 | SmallVector<mlir::Attribute, 4> cascadingCaseValues; |
203 | |
204 | op.collectCases(cases); |
205 | if (cases.empty()) |
206 | return mlir::failure(); |
207 | |
208 | auto flushMergedOps = [&]() { |
209 | for (CaseOp &c : cascadingCases) |
210 | rewriter.eraseOp(c); |
211 | cascadingCases.clear(); |
212 | cascadingCaseValues.clear(); |
213 | }; |
214 | |
215 | auto mergeCascadingInto = [&](CaseOp &target) { |
216 | rewriter.modifyOpInPlace(target, [&]() { |
217 | target.setValueAttr(rewriter.getArrayAttr(cascadingCaseValues)); |
218 | target.setKind(CaseOpKind::Anyof); |
219 | }); |
220 | changed = mlir::success(); |
221 | }; |
222 | |
223 | for (CaseOp c : cases) { |
224 | cir::CaseOpKind kind = c.getKind(); |
225 | if (kind == cir::CaseOpKind::Equal && |
226 | isa<YieldOp>(c.getCaseRegion().front().front())) { |
227 | // If the case contains only a YieldOp, collect it for cascading merge |
228 | cascadingCases.push_back(c); |
229 | cascadingCaseValues.push_back(c.getValue()[0]); |
230 | } else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) { |
231 | // merge previously collected cascading cases |
232 | cascadingCaseValues.push_back(c.getValue()[0]); |
233 | mergeCascadingInto(c); |
234 | flushMergedOps(); |
235 | } else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) { |
236 | // If a Default, Anyof or Range case is found and there are previous |
237 | // cascading cases, merge all of them into the last cascading case. |
238 | // We don't currently fold case range statements with other case |
239 | // statements. |
240 | assert(!cir::MissingFeatures::foldRangeCase()); |
241 | CaseOp lastCascadingCase = cascadingCases.back(); |
242 | mergeCascadingInto(lastCascadingCase); |
243 | cascadingCases.pop_back(); |
244 | flushMergedOps(); |
245 | } else { |
246 | cascadingCases.clear(); |
247 | cascadingCaseValues.clear(); |
248 | } |
249 | } |
250 | |
251 | // Edge case: all cases are simple cascading cases |
252 | if (cascadingCases.size() == cases.size()) { |
253 | CaseOp lastCascadingCase = cascadingCases.back(); |
254 | mergeCascadingInto(lastCascadingCase); |
255 | cascadingCases.pop_back(); |
256 | flushMergedOps(); |
257 | } |
258 | |
259 | return changed; |
260 | } |
261 | }; |
262 | |
263 | //===----------------------------------------------------------------------===// |
264 | // CIRSimplifyPass |
265 | //===----------------------------------------------------------------------===// |
266 | |
267 | struct CIRSimplifyPass : public CIRSimplifyBase<CIRSimplifyPass> { |
268 | using CIRSimplifyBase::CIRSimplifyBase; |
269 | |
270 | void runOnOperation() override; |
271 | }; |
272 | |
273 | void populateMergeCleanupPatterns(RewritePatternSet &patterns) { |
274 | // clang-format off |
275 | patterns.add< |
276 | SimplifyTernary, |
277 | SimplifySelect, |
278 | SimplifySwitch |
279 | >(arg: patterns.getContext()); |
280 | // clang-format on |
281 | } |
282 | |
283 | void CIRSimplifyPass::runOnOperation() { |
284 | // Collect rewrite patterns. |
285 | RewritePatternSet patterns(&getContext()); |
286 | populateMergeCleanupPatterns(patterns); |
287 | |
288 | // Collect operations to apply patterns. |
289 | llvm::SmallVector<Operation *, 16> ops; |
290 | getOperation()->walk([&](Operation *op) { |
291 | if (isa<TernaryOp, SelectOp, SwitchOp>(op)) |
292 | ops.push_back(Elt: op); |
293 | }); |
294 | |
295 | // Apply patterns. |
296 | if (applyOpPatternsGreedily(ops, std::move(patterns)).failed()) |
297 | signalPassFailure(); |
298 | } |
299 | |
300 | } // namespace |
301 | |
302 | std::unique_ptr<Pass> mlir::createCIRSimplifyPass() { |
303 | return std::make_unique<CIRSimplifyPass>(); |
304 | } |
305 | |