1 | //===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
10 | |
11 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
12 | #include "mlir/Dialect/Arith/IR/Arith.h" |
13 | #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" |
14 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
15 | #include "mlir/IR/AffineExpr.h" |
16 | #include "mlir/IR/AffineMap.h" |
17 | #include "mlir/IR/Builders.h" |
18 | #include "mlir/IR/BuiltinOps.h" |
19 | #include "mlir/IR/BuiltinTypes.h" |
20 | #include "mlir/IR/IRMapping.h" |
21 | #include "mlir/IR/Matchers.h" |
22 | #include "mlir/IR/OpImplementation.h" |
23 | #include "mlir/IR/PatternMatch.h" |
24 | #include "mlir/IR/TypeUtilities.h" |
25 | #include "mlir/IR/Value.h" |
26 | #include "mlir/Support/MathExtras.h" |
27 | #include "mlir/Transforms/InliningUtils.h" |
28 | #include "llvm/ADT/APFloat.h" |
29 | #include "llvm/ADT/STLExtras.h" |
30 | #include "llvm/Support/FormatVariadic.h" |
31 | #include "llvm/Support/raw_ostream.h" |
32 | #include <numeric> |
33 | |
34 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc" |
35 | |
36 | using namespace mlir; |
37 | using namespace mlir::cf; |
38 | |
39 | //===----------------------------------------------------------------------===// |
40 | // ControlFlowDialect Interfaces |
41 | //===----------------------------------------------------------------------===// |
42 | namespace { |
43 | /// This class defines the interface for handling inlining with control flow |
44 | /// operations. |
45 | struct ControlFlowInlinerInterface : public DialectInlinerInterface { |
46 | using DialectInlinerInterface::DialectInlinerInterface; |
47 | ~ControlFlowInlinerInterface() override = default; |
48 | |
49 | /// All control flow operations can be inlined. |
50 | bool isLegalToInline(Operation *call, Operation *callable, |
51 | bool wouldBeCloned) const final { |
52 | return true; |
53 | } |
54 | bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { |
55 | return true; |
56 | } |
57 | |
58 | /// ControlFlow terminator operations don't really need any special handing. |
59 | void handleTerminator(Operation *op, Block *newDest) const final {} |
60 | }; |
61 | } // namespace |
62 | |
63 | //===----------------------------------------------------------------------===// |
64 | // ControlFlowDialect |
65 | //===----------------------------------------------------------------------===// |
66 | |
67 | void ControlFlowDialect::initialize() { |
68 | addOperations< |
69 | #define GET_OP_LIST |
70 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc" |
71 | >(); |
72 | addInterfaces<ControlFlowInlinerInterface>(); |
73 | declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>(); |
74 | declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp, |
75 | CondBranchOp>(); |
76 | declarePromisedInterface<bufferization::BufferDeallocationOpInterface, |
77 | CondBranchOp>(); |
78 | } |
79 | |
80 | //===----------------------------------------------------------------------===// |
81 | // AssertOp |
82 | //===----------------------------------------------------------------------===// |
83 | |
84 | LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) { |
85 | // Erase assertion if argument is constant true. |
86 | if (matchPattern(op.getArg(), m_One())) { |
87 | rewriter.eraseOp(op); |
88 | return success(); |
89 | } |
90 | return failure(); |
91 | } |
92 | |
93 | //===----------------------------------------------------------------------===// |
94 | // BranchOp |
95 | //===----------------------------------------------------------------------===// |
96 | |
97 | /// Given a successor, try to collapse it to a new destination if it only |
98 | /// contains a passthrough unconditional branch. If the successor is |
99 | /// collapsable, `successor` and `successorOperands` are updated to reference |
100 | /// the new destination and values. `argStorage` is used as storage if operands |
101 | /// to the collapsed successor need to be remapped. It must outlive uses of |
102 | /// successorOperands. |
103 | static LogicalResult collapseBranch(Block *&successor, |
104 | ValueRange &successorOperands, |
105 | SmallVectorImpl<Value> &argStorage) { |
106 | // Check that the successor only contains a unconditional branch. |
107 | if (std::next(successor->begin()) != successor->end()) |
108 | return failure(); |
109 | // Check that the terminator is an unconditional branch. |
110 | BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator()); |
111 | if (!successorBranch) |
112 | return failure(); |
113 | // Check that the arguments are only used within the terminator. |
114 | for (BlockArgument arg : successor->getArguments()) { |
115 | for (Operation *user : arg.getUsers()) |
116 | if (user != successorBranch) |
117 | return failure(); |
118 | } |
119 | // Don't try to collapse branches to infinite loops. |
120 | Block *successorDest = successorBranch.getDest(); |
121 | if (successorDest == successor) |
122 | return failure(); |
123 | |
124 | // Update the operands to the successor. If the branch parent has no |
125 | // arguments, we can use the branch operands directly. |
126 | OperandRange operands = successorBranch.getOperands(); |
127 | if (successor->args_empty()) { |
128 | successor = successorDest; |
129 | successorOperands = operands; |
130 | return success(); |
131 | } |
132 | |
133 | // Otherwise, we need to remap any argument operands. |
134 | for (Value operand : operands) { |
135 | BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand); |
136 | if (argOperand && argOperand.getOwner() == successor) |
137 | argStorage.push_back(successorOperands[argOperand.getArgNumber()]); |
138 | else |
139 | argStorage.push_back(operand); |
140 | } |
141 | successor = successorDest; |
142 | successorOperands = argStorage; |
143 | return success(); |
144 | } |
145 | |
146 | /// Simplify a branch to a block that has a single predecessor. This effectively |
147 | /// merges the two blocks. |
148 | static LogicalResult |
149 | simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) { |
150 | // Check that the successor block has a single predecessor. |
151 | Block *succ = op.getDest(); |
152 | Block *opParent = op->getBlock(); |
153 | if (succ == opParent || !llvm::hasSingleElement(C: succ->getPredecessors())) |
154 | return failure(); |
155 | |
156 | // Merge the successor into the current block and erase the branch. |
157 | SmallVector<Value> brOperands(op.getOperands()); |
158 | rewriter.eraseOp(op: op); |
159 | rewriter.mergeBlocks(source: succ, dest: opParent, argValues: brOperands); |
160 | return success(); |
161 | } |
162 | |
163 | /// br ^bb1 |
164 | /// ^bb1 |
165 | /// br ^bbN(...) |
166 | /// |
167 | /// -> br ^bbN(...) |
168 | /// |
169 | static LogicalResult simplifyPassThroughBr(BranchOp op, |
170 | PatternRewriter &rewriter) { |
171 | Block *dest = op.getDest(); |
172 | ValueRange destOperands = op.getOperands(); |
173 | SmallVector<Value, 4> destOperandStorage; |
174 | |
175 | // Try to collapse the successor if it points somewhere other than this |
176 | // block. |
177 | if (dest == op->getBlock() || |
178 | failed(result: collapseBranch(successor&: dest, successorOperands&: destOperands, argStorage&: destOperandStorage))) |
179 | return failure(); |
180 | |
181 | // Create a new branch with the collapsed successor. |
182 | rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands); |
183 | return success(); |
184 | } |
185 | |
186 | LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) { |
187 | return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) || |
188 | succeeded(simplifyPassThroughBr(op, rewriter))); |
189 | } |
190 | |
191 | void BranchOp::setDest(Block *block) { return setSuccessor(block); } |
192 | |
193 | void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); } |
194 | |
195 | SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) { |
196 | assert(index == 0 && "invalid successor index" ); |
197 | return SuccessorOperands(getDestOperandsMutable()); |
198 | } |
199 | |
200 | Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { |
201 | return getDest(); |
202 | } |
203 | |
204 | //===----------------------------------------------------------------------===// |
205 | // CondBranchOp |
206 | //===----------------------------------------------------------------------===// |
207 | |
208 | namespace { |
209 | /// cf.cond_br true, ^bb1, ^bb2 |
210 | /// -> br ^bb1 |
211 | /// cf.cond_br false, ^bb1, ^bb2 |
212 | /// -> br ^bb2 |
213 | /// |
214 | struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> { |
215 | using OpRewritePattern<CondBranchOp>::OpRewritePattern; |
216 | |
217 | LogicalResult matchAndRewrite(CondBranchOp condbr, |
218 | PatternRewriter &rewriter) const override { |
219 | if (matchPattern(condbr.getCondition(), m_NonZero())) { |
220 | // True branch taken. |
221 | rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(), |
222 | condbr.getTrueOperands()); |
223 | return success(); |
224 | } |
225 | if (matchPattern(condbr.getCondition(), m_Zero())) { |
226 | // False branch taken. |
227 | rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(), |
228 | condbr.getFalseOperands()); |
229 | return success(); |
230 | } |
231 | return failure(); |
232 | } |
233 | }; |
234 | |
235 | /// cf.cond_br %cond, ^bb1, ^bb2 |
236 | /// ^bb1 |
237 | /// br ^bbN(...) |
238 | /// ^bb2 |
239 | /// br ^bbK(...) |
240 | /// |
241 | /// -> cf.cond_br %cond, ^bbN(...), ^bbK(...) |
242 | /// |
243 | struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> { |
244 | using OpRewritePattern<CondBranchOp>::OpRewritePattern; |
245 | |
246 | LogicalResult matchAndRewrite(CondBranchOp condbr, |
247 | PatternRewriter &rewriter) const override { |
248 | Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest(); |
249 | ValueRange trueDestOperands = condbr.getTrueOperands(); |
250 | ValueRange falseDestOperands = condbr.getFalseOperands(); |
251 | SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage; |
252 | |
253 | // Try to collapse one of the current successors. |
254 | LogicalResult collapsedTrue = |
255 | collapseBranch(successor&: trueDest, successorOperands&: trueDestOperands, argStorage&: trueDestOperandStorage); |
256 | LogicalResult collapsedFalse = |
257 | collapseBranch(successor&: falseDest, successorOperands&: falseDestOperands, argStorage&: falseDestOperandStorage); |
258 | if (failed(result: collapsedTrue) && failed(result: collapsedFalse)) |
259 | return failure(); |
260 | |
261 | // Create a new branch with the collapsed successors. |
262 | rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(), |
263 | trueDest, trueDestOperands, |
264 | falseDest, falseDestOperands); |
265 | return success(); |
266 | } |
267 | }; |
268 | |
269 | /// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N) |
270 | /// -> br ^bb1(A, ..., N) |
271 | /// |
272 | /// cf.cond_br %cond, ^bb1(A), ^bb1(B) |
273 | /// -> %select = arith.select %cond, A, B |
274 | /// br ^bb1(%select) |
275 | /// |
276 | struct SimplifyCondBranchIdenticalSuccessors |
277 | : public OpRewritePattern<CondBranchOp> { |
278 | using OpRewritePattern<CondBranchOp>::OpRewritePattern; |
279 | |
280 | LogicalResult matchAndRewrite(CondBranchOp condbr, |
281 | PatternRewriter &rewriter) const override { |
282 | // Check that the true and false destinations are the same and have the same |
283 | // operands. |
284 | Block *trueDest = condbr.getTrueDest(); |
285 | if (trueDest != condbr.getFalseDest()) |
286 | return failure(); |
287 | |
288 | // If all of the operands match, no selects need to be generated. |
289 | OperandRange trueOperands = condbr.getTrueOperands(); |
290 | OperandRange falseOperands = condbr.getFalseOperands(); |
291 | if (trueOperands == falseOperands) { |
292 | rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands); |
293 | return success(); |
294 | } |
295 | |
296 | // Otherwise, if the current block is the only predecessor insert selects |
297 | // for any mismatched branch operands. |
298 | if (trueDest->getUniquePredecessor() != condbr->getBlock()) |
299 | return failure(); |
300 | |
301 | // Generate a select for any operands that differ between the two. |
302 | SmallVector<Value, 8> mergedOperands; |
303 | mergedOperands.reserve(N: trueOperands.size()); |
304 | Value condition = condbr.getCondition(); |
305 | for (auto it : llvm::zip(trueOperands, falseOperands)) { |
306 | if (std::get<0>(it) == std::get<1>(it)) |
307 | mergedOperands.push_back(std::get<0>(it)); |
308 | else |
309 | mergedOperands.push_back(rewriter.create<arith::SelectOp>( |
310 | condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it))); |
311 | } |
312 | |
313 | rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands); |
314 | return success(); |
315 | } |
316 | }; |
317 | |
318 | /// ... |
319 | /// cf.cond_br %cond, ^bb1(...), ^bb2(...) |
320 | /// ... |
321 | /// ^bb1: // has single predecessor |
322 | /// ... |
323 | /// cf.cond_br %cond, ^bb3(...), ^bb4(...) |
324 | /// |
325 | /// -> |
326 | /// |
327 | /// ... |
328 | /// cf.cond_br %cond, ^bb1(...), ^bb2(...) |
329 | /// ... |
330 | /// ^bb1: // has single predecessor |
331 | /// ... |
332 | /// br ^bb3(...) |
333 | /// |
334 | struct SimplifyCondBranchFromCondBranchOnSameCondition |
335 | : public OpRewritePattern<CondBranchOp> { |
336 | using OpRewritePattern<CondBranchOp>::OpRewritePattern; |
337 | |
338 | LogicalResult matchAndRewrite(CondBranchOp condbr, |
339 | PatternRewriter &rewriter) const override { |
340 | // Check that we have a single distinct predecessor. |
341 | Block *currentBlock = condbr->getBlock(); |
342 | Block *predecessor = currentBlock->getSinglePredecessor(); |
343 | if (!predecessor) |
344 | return failure(); |
345 | |
346 | // Check that the predecessor terminates with a conditional branch to this |
347 | // block and that it branches on the same condition. |
348 | auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator()); |
349 | if (!predBranch || condbr.getCondition() != predBranch.getCondition()) |
350 | return failure(); |
351 | |
352 | // Fold this branch to an unconditional branch. |
353 | if (currentBlock == predBranch.getTrueDest()) |
354 | rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(), |
355 | condbr.getTrueDestOperands()); |
356 | else |
357 | rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(), |
358 | condbr.getFalseDestOperands()); |
359 | return success(); |
360 | } |
361 | }; |
362 | |
363 | /// cf.cond_br %arg0, ^trueB, ^falseB |
364 | /// |
365 | /// ^trueB: |
366 | /// "test.consumer1"(%arg0) : (i1) -> () |
367 | /// ... |
368 | /// |
369 | /// ^falseB: |
370 | /// "test.consumer2"(%arg0) : (i1) -> () |
371 | /// ... |
372 | /// |
373 | /// -> |
374 | /// |
375 | /// cf.cond_br %arg0, ^trueB, ^falseB |
376 | /// ^trueB: |
377 | /// "test.consumer1"(%true) : (i1) -> () |
378 | /// ... |
379 | /// |
380 | /// ^falseB: |
381 | /// "test.consumer2"(%false) : (i1) -> () |
382 | /// ... |
383 | struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> { |
384 | using OpRewritePattern<CondBranchOp>::OpRewritePattern; |
385 | |
386 | LogicalResult matchAndRewrite(CondBranchOp condbr, |
387 | PatternRewriter &rewriter) const override { |
388 | // Check that we have a single distinct predecessor. |
389 | bool replaced = false; |
390 | Type ty = rewriter.getI1Type(); |
391 | |
392 | // These variables serve to prevent creating duplicate constants |
393 | // and hold constant true or false values. |
394 | Value constantTrue = nullptr; |
395 | Value constantFalse = nullptr; |
396 | |
397 | // TODO These checks can be expanded to encompas any use with only |
398 | // either the true of false edge as a predecessor. For now, we fall |
399 | // back to checking the single predecessor is given by the true/fasle |
400 | // destination, thereby ensuring that only that edge can reach the |
401 | // op. |
402 | if (condbr.getTrueDest()->getSinglePredecessor()) { |
403 | for (OpOperand &use : |
404 | llvm::make_early_inc_range(condbr.getCondition().getUses())) { |
405 | if (use.getOwner()->getBlock() == condbr.getTrueDest()) { |
406 | replaced = true; |
407 | |
408 | if (!constantTrue) |
409 | constantTrue = rewriter.create<arith::ConstantOp>( |
410 | condbr.getLoc(), ty, rewriter.getBoolAttr(true)); |
411 | |
412 | rewriter.modifyOpInPlace(use.getOwner(), |
413 | [&] { use.set(constantTrue); }); |
414 | } |
415 | } |
416 | } |
417 | if (condbr.getFalseDest()->getSinglePredecessor()) { |
418 | for (OpOperand &use : |
419 | llvm::make_early_inc_range(condbr.getCondition().getUses())) { |
420 | if (use.getOwner()->getBlock() == condbr.getFalseDest()) { |
421 | replaced = true; |
422 | |
423 | if (!constantFalse) |
424 | constantFalse = rewriter.create<arith::ConstantOp>( |
425 | condbr.getLoc(), ty, rewriter.getBoolAttr(false)); |
426 | |
427 | rewriter.modifyOpInPlace(use.getOwner(), |
428 | [&] { use.set(constantFalse); }); |
429 | } |
430 | } |
431 | } |
432 | return success(isSuccess: replaced); |
433 | } |
434 | }; |
435 | } // namespace |
436 | |
437 | void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results, |
438 | MLIRContext *context) { |
439 | results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch, |
440 | SimplifyCondBranchIdenticalSuccessors, |
441 | SimplifyCondBranchFromCondBranchOnSameCondition, |
442 | CondBranchTruthPropagation>(context); |
443 | } |
444 | |
445 | SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) { |
446 | assert(index < getNumSuccessors() && "invalid successor index" ); |
447 | return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable() |
448 | : getFalseDestOperandsMutable()); |
449 | } |
450 | |
451 | Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { |
452 | if (IntegerAttr condAttr = |
453 | llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) |
454 | return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest(); |
455 | return nullptr; |
456 | } |
457 | |
458 | //===----------------------------------------------------------------------===// |
459 | // SwitchOp |
460 | //===----------------------------------------------------------------------===// |
461 | |
462 | void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, |
463 | Block *defaultDestination, ValueRange defaultOperands, |
464 | DenseIntElementsAttr caseValues, |
465 | BlockRange caseDestinations, |
466 | ArrayRef<ValueRange> caseOperands) { |
467 | build(builder, result, value, defaultOperands, caseOperands, caseValues, |
468 | defaultDestination, caseDestinations); |
469 | } |
470 | |
471 | void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, |
472 | Block *defaultDestination, ValueRange defaultOperands, |
473 | ArrayRef<APInt> caseValues, BlockRange caseDestinations, |
474 | ArrayRef<ValueRange> caseOperands) { |
475 | DenseIntElementsAttr caseValuesAttr; |
476 | if (!caseValues.empty()) { |
477 | ShapedType caseValueType = VectorType::get( |
478 | static_cast<int64_t>(caseValues.size()), value.getType()); |
479 | caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); |
480 | } |
481 | build(builder, result, value, defaultDestination, defaultOperands, |
482 | caseValuesAttr, caseDestinations, caseOperands); |
483 | } |
484 | |
485 | void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, |
486 | Block *defaultDestination, ValueRange defaultOperands, |
487 | ArrayRef<int32_t> caseValues, BlockRange caseDestinations, |
488 | ArrayRef<ValueRange> caseOperands) { |
489 | DenseIntElementsAttr caseValuesAttr; |
490 | if (!caseValues.empty()) { |
491 | ShapedType caseValueType = VectorType::get( |
492 | static_cast<int64_t>(caseValues.size()), value.getType()); |
493 | caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); |
494 | } |
495 | build(builder, result, value, defaultDestination, defaultOperands, |
496 | caseValuesAttr, caseDestinations, caseOperands); |
497 | } |
498 | |
499 | /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)? |
500 | /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )* |
501 | static ParseResult parseSwitchOpCases( |
502 | OpAsmParser &parser, Type &flagType, Block *&defaultDestination, |
503 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands, |
504 | SmallVectorImpl<Type> &defaultOperandTypes, |
505 | DenseIntElementsAttr &caseValues, |
506 | SmallVectorImpl<Block *> &caseDestinations, |
507 | SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands, |
508 | SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) { |
509 | if (parser.parseKeyword(keyword: "default" ) || parser.parseColon() || |
510 | parser.parseSuccessor(dest&: defaultDestination)) |
511 | return failure(); |
512 | if (succeeded(result: parser.parseOptionalLParen())) { |
513 | if (parser.parseOperandList(result&: defaultOperands, delimiter: OpAsmParser::Delimiter::None, |
514 | /*allowResultNumber=*/false) || |
515 | parser.parseColonTypeList(result&: defaultOperandTypes) || parser.parseRParen()) |
516 | return failure(); |
517 | } |
518 | |
519 | SmallVector<APInt> values; |
520 | unsigned bitWidth = flagType.getIntOrFloatBitWidth(); |
521 | while (succeeded(result: parser.parseOptionalComma())) { |
522 | int64_t value = 0; |
523 | if (failed(result: parser.parseInteger(result&: value))) |
524 | return failure(); |
525 | values.push_back(Elt: APInt(bitWidth, value)); |
526 | |
527 | Block *destination; |
528 | SmallVector<OpAsmParser::UnresolvedOperand> operands; |
529 | SmallVector<Type> operandTypes; |
530 | if (failed(result: parser.parseColon()) || |
531 | failed(result: parser.parseSuccessor(dest&: destination))) |
532 | return failure(); |
533 | if (succeeded(result: parser.parseOptionalLParen())) { |
534 | if (failed(result: parser.parseOperandList(result&: operands, |
535 | delimiter: OpAsmParser::Delimiter::None)) || |
536 | failed(result: parser.parseColonTypeList(result&: operandTypes)) || |
537 | failed(result: parser.parseRParen())) |
538 | return failure(); |
539 | } |
540 | caseDestinations.push_back(Elt: destination); |
541 | caseOperands.emplace_back(Args&: operands); |
542 | caseOperandTypes.emplace_back(Args&: operandTypes); |
543 | } |
544 | |
545 | if (!values.empty()) { |
546 | ShapedType caseValueType = |
547 | VectorType::get(static_cast<int64_t>(values.size()), flagType); |
548 | caseValues = DenseIntElementsAttr::get(caseValueType, values); |
549 | } |
550 | return success(); |
551 | } |
552 | |
553 | static void printSwitchOpCases( |
554 | OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, |
555 | OperandRange defaultOperands, TypeRange defaultOperandTypes, |
556 | DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, |
557 | OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) { |
558 | p << " default: " ; |
559 | p.printSuccessorAndUseList(successor: defaultDestination, succOperands: defaultOperands); |
560 | |
561 | if (!caseValues) |
562 | return; |
563 | |
564 | for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) { |
565 | p << ','; |
566 | p.printNewline(); |
567 | p << " " ; |
568 | p << it.value().getLimitedValue(); |
569 | p << ": " ; |
570 | p.printSuccessorAndUseList(caseDestinations[it.index()], |
571 | caseOperands[it.index()]); |
572 | } |
573 | p.printNewline(); |
574 | } |
575 | |
576 | LogicalResult SwitchOp::verify() { |
577 | auto caseValues = getCaseValues(); |
578 | auto caseDestinations = getCaseDestinations(); |
579 | |
580 | if (!caseValues && caseDestinations.empty()) |
581 | return success(); |
582 | |
583 | Type flagType = getFlag().getType(); |
584 | Type caseValueType = caseValues->getType().getElementType(); |
585 | if (caseValueType != flagType) |
586 | return emitOpError() << "'flag' type (" << flagType |
587 | << ") should match case value type (" << caseValueType |
588 | << ")" ; |
589 | |
590 | if (caseValues && |
591 | caseValues->size() != static_cast<int64_t>(caseDestinations.size())) |
592 | return emitOpError() << "number of case values (" << caseValues->size() |
593 | << ") should match number of " |
594 | "case destinations (" |
595 | << caseDestinations.size() << ")" ; |
596 | return success(); |
597 | } |
598 | |
599 | SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) { |
600 | assert(index < getNumSuccessors() && "invalid successor index" ); |
601 | return SuccessorOperands(index == 0 ? getDefaultOperandsMutable() |
602 | : getCaseOperandsMutable(index - 1)); |
603 | } |
604 | |
605 | Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { |
606 | std::optional<DenseIntElementsAttr> caseValues = getCaseValues(); |
607 | |
608 | if (!caseValues) |
609 | return getDefaultDestination(); |
610 | |
611 | SuccessorRange caseDests = getCaseDestinations(); |
612 | if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) { |
613 | for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) |
614 | if (it.value() == value.getValue()) |
615 | return caseDests[it.index()]; |
616 | return getDefaultDestination(); |
617 | } |
618 | return nullptr; |
619 | } |
620 | |
621 | /// switch %flag : i32, [ |
622 | /// default: ^bb1 |
623 | /// ] |
624 | /// -> br ^bb1 |
625 | static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, |
626 | PatternRewriter &rewriter) { |
627 | if (!op.getCaseDestinations().empty()) |
628 | return failure(); |
629 | |
630 | rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), |
631 | op.getDefaultOperands()); |
632 | return success(); |
633 | } |
634 | |
635 | /// switch %flag : i32, [ |
636 | /// default: ^bb1, |
637 | /// 42: ^bb1, |
638 | /// 43: ^bb2 |
639 | /// ] |
640 | /// -> |
641 | /// switch %flag : i32, [ |
642 | /// default: ^bb1, |
643 | /// 43: ^bb2 |
644 | /// ] |
645 | static LogicalResult |
646 | dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) { |
647 | SmallVector<Block *> newCaseDestinations; |
648 | SmallVector<ValueRange> newCaseOperands; |
649 | SmallVector<APInt> newCaseValues; |
650 | bool requiresChange = false; |
651 | auto caseValues = op.getCaseValues(); |
652 | auto caseDests = op.getCaseDestinations(); |
653 | |
654 | for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { |
655 | if (caseDests[it.index()] == op.getDefaultDestination() && |
656 | op.getCaseOperands(it.index()) == op.getDefaultOperands()) { |
657 | requiresChange = true; |
658 | continue; |
659 | } |
660 | newCaseDestinations.push_back(caseDests[it.index()]); |
661 | newCaseOperands.push_back(op.getCaseOperands(it.index())); |
662 | newCaseValues.push_back(it.value()); |
663 | } |
664 | |
665 | if (!requiresChange) |
666 | return failure(); |
667 | |
668 | rewriter.replaceOpWithNewOp<SwitchOp>( |
669 | op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(), |
670 | newCaseValues, newCaseDestinations, newCaseOperands); |
671 | return success(); |
672 | } |
673 | |
674 | /// Helper for folding a switch with a constant value. |
675 | /// switch %c_42 : i32, [ |
676 | /// default: ^bb1 , |
677 | /// 42: ^bb2, |
678 | /// 43: ^bb3 |
679 | /// ] |
680 | /// -> br ^bb2 |
681 | static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, |
682 | const APInt &caseValue) { |
683 | auto caseValues = op.getCaseValues(); |
684 | for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { |
685 | if (it.value() == caseValue) { |
686 | rewriter.replaceOpWithNewOp<BranchOp>( |
687 | op, op.getCaseDestinations()[it.index()], |
688 | op.getCaseOperands(it.index())); |
689 | return; |
690 | } |
691 | } |
692 | rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), |
693 | op.getDefaultOperands()); |
694 | } |
695 | |
696 | /// switch %c_42 : i32, [ |
697 | /// default: ^bb1, |
698 | /// 42: ^bb2, |
699 | /// 43: ^bb3 |
700 | /// ] |
701 | /// -> br ^bb2 |
702 | static LogicalResult simplifyConstSwitchValue(SwitchOp op, |
703 | PatternRewriter &rewriter) { |
704 | APInt caseValue; |
705 | if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue))) |
706 | return failure(); |
707 | |
708 | foldSwitch(op, rewriter, caseValue); |
709 | return success(); |
710 | } |
711 | |
712 | /// switch %c_42 : i32, [ |
713 | /// default: ^bb1, |
714 | /// 42: ^bb2, |
715 | /// ] |
716 | /// ^bb2: |
717 | /// br ^bb3 |
718 | /// -> |
719 | /// switch %c_42 : i32, [ |
720 | /// default: ^bb1, |
721 | /// 42: ^bb3, |
722 | /// ] |
723 | static LogicalResult simplifyPassThroughSwitch(SwitchOp op, |
724 | PatternRewriter &rewriter) { |
725 | SmallVector<Block *> newCaseDests; |
726 | SmallVector<ValueRange> newCaseOperands; |
727 | SmallVector<SmallVector<Value>> argStorage; |
728 | auto caseValues = op.getCaseValues(); |
729 | argStorage.reserve(N: caseValues->size() + 1); |
730 | auto caseDests = op.getCaseDestinations(); |
731 | bool requiresChange = false; |
732 | for (int64_t i = 0, size = caseValues->size(); i < size; ++i) { |
733 | Block *caseDest = caseDests[i]; |
734 | ValueRange caseOperands = op.getCaseOperands(i); |
735 | argStorage.emplace_back(); |
736 | if (succeeded(result: collapseBranch(successor&: caseDest, successorOperands&: caseOperands, argStorage&: argStorage.back()))) |
737 | requiresChange = true; |
738 | |
739 | newCaseDests.push_back(Elt: caseDest); |
740 | newCaseOperands.push_back(Elt: caseOperands); |
741 | } |
742 | |
743 | Block *defaultDest = op.getDefaultDestination(); |
744 | ValueRange defaultOperands = op.getDefaultOperands(); |
745 | argStorage.emplace_back(); |
746 | |
747 | if (succeeded( |
748 | result: collapseBranch(successor&: defaultDest, successorOperands&: defaultOperands, argStorage&: argStorage.back()))) |
749 | requiresChange = true; |
750 | |
751 | if (!requiresChange) |
752 | return failure(); |
753 | |
754 | rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest, |
755 | defaultOperands, *caseValues, |
756 | newCaseDests, newCaseOperands); |
757 | return success(); |
758 | } |
759 | |
760 | /// switch %flag : i32, [ |
761 | /// default: ^bb1, |
762 | /// 42: ^bb2, |
763 | /// ] |
764 | /// ^bb2: |
765 | /// switch %flag : i32, [ |
766 | /// default: ^bb3, |
767 | /// 42: ^bb4 |
768 | /// ] |
769 | /// -> |
770 | /// switch %flag : i32, [ |
771 | /// default: ^bb1, |
772 | /// 42: ^bb2, |
773 | /// ] |
774 | /// ^bb2: |
775 | /// br ^bb4 |
776 | /// |
777 | /// and |
778 | /// |
779 | /// switch %flag : i32, [ |
780 | /// default: ^bb1, |
781 | /// 42: ^bb2, |
782 | /// ] |
783 | /// ^bb2: |
784 | /// switch %flag : i32, [ |
785 | /// default: ^bb3, |
786 | /// 43: ^bb4 |
787 | /// ] |
788 | /// -> |
789 | /// switch %flag : i32, [ |
790 | /// default: ^bb1, |
791 | /// 42: ^bb2, |
792 | /// ] |
793 | /// ^bb2: |
794 | /// br ^bb3 |
795 | static LogicalResult |
796 | simplifySwitchFromSwitchOnSameCondition(SwitchOp op, |
797 | PatternRewriter &rewriter) { |
798 | // Check that we have a single distinct predecessor. |
799 | Block *currentBlock = op->getBlock(); |
800 | Block *predecessor = currentBlock->getSinglePredecessor(); |
801 | if (!predecessor) |
802 | return failure(); |
803 | |
804 | // Check that the predecessor terminates with a switch branch to this block |
805 | // and that it branches on the same condition and that this branch isn't the |
806 | // default destination. |
807 | auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator()); |
808 | if (!predSwitch || op.getFlag() != predSwitch.getFlag() || |
809 | predSwitch.getDefaultDestination() == currentBlock) |
810 | return failure(); |
811 | |
812 | // Fold this switch to an unconditional branch. |
813 | SuccessorRange predDests = predSwitch.getCaseDestinations(); |
814 | auto it = llvm::find(Range&: predDests, Val: currentBlock); |
815 | if (it != predDests.end()) { |
816 | std::optional<DenseIntElementsAttr> predCaseValues = |
817 | predSwitch.getCaseValues(); |
818 | foldSwitch(op, rewriter, |
819 | predCaseValues->getValues<APInt>()[it - predDests.begin()]); |
820 | } else { |
821 | rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), |
822 | op.getDefaultOperands()); |
823 | } |
824 | return success(); |
825 | } |
826 | |
827 | /// switch %flag : i32, [ |
828 | /// default: ^bb1, |
829 | /// 42: ^bb2 |
830 | /// ] |
831 | /// ^bb1: |
832 | /// switch %flag : i32, [ |
833 | /// default: ^bb3, |
834 | /// 42: ^bb4, |
835 | /// 43: ^bb5 |
836 | /// ] |
837 | /// -> |
838 | /// switch %flag : i32, [ |
839 | /// default: ^bb1, |
840 | /// 42: ^bb2, |
841 | /// ] |
842 | /// ^bb1: |
843 | /// switch %flag : i32, [ |
844 | /// default: ^bb3, |
845 | /// 43: ^bb5 |
846 | /// ] |
847 | static LogicalResult |
848 | simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, |
849 | PatternRewriter &rewriter) { |
850 | // Check that we have a single distinct predecessor. |
851 | Block *currentBlock = op->getBlock(); |
852 | Block *predecessor = currentBlock->getSinglePredecessor(); |
853 | if (!predecessor) |
854 | return failure(); |
855 | |
856 | // Check that the predecessor terminates with a switch branch to this block |
857 | // and that it branches on the same condition and that this branch is the |
858 | // default destination. |
859 | auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator()); |
860 | if (!predSwitch || op.getFlag() != predSwitch.getFlag() || |
861 | predSwitch.getDefaultDestination() != currentBlock) |
862 | return failure(); |
863 | |
864 | // Delete case values that are not possible here. |
865 | DenseSet<APInt> caseValuesToRemove; |
866 | auto predDests = predSwitch.getCaseDestinations(); |
867 | auto predCaseValues = predSwitch.getCaseValues(); |
868 | for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) |
869 | if (currentBlock != predDests[i]) |
870 | caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]); |
871 | |
872 | SmallVector<Block *> newCaseDestinations; |
873 | SmallVector<ValueRange> newCaseOperands; |
874 | SmallVector<APInt> newCaseValues; |
875 | bool requiresChange = false; |
876 | |
877 | auto caseValues = op.getCaseValues(); |
878 | auto caseDests = op.getCaseDestinations(); |
879 | for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { |
880 | if (caseValuesToRemove.contains(it.value())) { |
881 | requiresChange = true; |
882 | continue; |
883 | } |
884 | newCaseDestinations.push_back(caseDests[it.index()]); |
885 | newCaseOperands.push_back(op.getCaseOperands(it.index())); |
886 | newCaseValues.push_back(it.value()); |
887 | } |
888 | |
889 | if (!requiresChange) |
890 | return failure(); |
891 | |
892 | rewriter.replaceOpWithNewOp<SwitchOp>( |
893 | op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(), |
894 | newCaseValues, newCaseDestinations, newCaseOperands); |
895 | return success(); |
896 | } |
897 | |
898 | void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results, |
899 | MLIRContext *context) { |
900 | results.add(&simplifySwitchWithOnlyDefault) |
901 | .add(&dropSwitchCasesThatMatchDefault) |
902 | .add(&simplifyConstSwitchValue) |
903 | .add(&simplifyPassThroughSwitch) |
904 | .add(&simplifySwitchFromSwitchOnSameCondition) |
905 | .add(&simplifySwitchFromDefaultSwitchOnSameCondition); |
906 | } |
907 | |
908 | //===----------------------------------------------------------------------===// |
909 | // TableGen'd op method definitions |
910 | //===----------------------------------------------------------------------===// |
911 | |
912 | #define GET_OP_CLASSES |
913 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc" |
914 | |