1//===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
10
11#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
12#include "mlir/Dialect/Arith/IR/Arith.h"
13#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
14#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
15#include "mlir/IR/AffineExpr.h"
16#include "mlir/IR/AffineMap.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/BuiltinOps.h"
19#include "mlir/IR/BuiltinTypes.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Matchers.h"
22#include "mlir/IR/OpImplementation.h"
23#include "mlir/IR/PatternMatch.h"
24#include "mlir/IR/TypeUtilities.h"
25#include "mlir/IR/Value.h"
26#include "mlir/Support/MathExtras.h"
27#include "mlir/Transforms/InliningUtils.h"
28#include "llvm/ADT/APFloat.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/Support/FormatVariadic.h"
31#include "llvm/Support/raw_ostream.h"
32#include <numeric>
33
34#include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
35
36using namespace mlir;
37using namespace mlir::cf;
38
39//===----------------------------------------------------------------------===//
40// ControlFlowDialect Interfaces
41//===----------------------------------------------------------------------===//
42namespace {
43/// This class defines the interface for handling inlining with control flow
44/// operations.
45struct ControlFlowInlinerInterface : public DialectInlinerInterface {
46 using DialectInlinerInterface::DialectInlinerInterface;
47 ~ControlFlowInlinerInterface() override = default;
48
49 /// All control flow operations can be inlined.
50 bool isLegalToInline(Operation *call, Operation *callable,
51 bool wouldBeCloned) const final {
52 return true;
53 }
54 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
55 return true;
56 }
57
58 /// ControlFlow terminator operations don't really need any special handing.
59 void handleTerminator(Operation *op, Block *newDest) const final {}
60};
61} // namespace
62
63//===----------------------------------------------------------------------===//
64// ControlFlowDialect
65//===----------------------------------------------------------------------===//
66
67void ControlFlowDialect::initialize() {
68 addOperations<
69#define GET_OP_LIST
70#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
71 >();
72 addInterfaces<ControlFlowInlinerInterface>();
73 declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>();
74 declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp,
75 CondBranchOp>();
76 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
77 CondBranchOp>();
78}
79
80//===----------------------------------------------------------------------===//
81// AssertOp
82//===----------------------------------------------------------------------===//
83
84LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
85 // Erase assertion if argument is constant true.
86 if (matchPattern(op.getArg(), m_One())) {
87 rewriter.eraseOp(op);
88 return success();
89 }
90 return failure();
91}
92
93//===----------------------------------------------------------------------===//
94// BranchOp
95//===----------------------------------------------------------------------===//
96
97/// Given a successor, try to collapse it to a new destination if it only
98/// contains a passthrough unconditional branch. If the successor is
99/// collapsable, `successor` and `successorOperands` are updated to reference
100/// the new destination and values. `argStorage` is used as storage if operands
101/// to the collapsed successor need to be remapped. It must outlive uses of
102/// successorOperands.
103static LogicalResult collapseBranch(Block *&successor,
104 ValueRange &successorOperands,
105 SmallVectorImpl<Value> &argStorage) {
106 // Check that the successor only contains a unconditional branch.
107 if (std::next(successor->begin()) != successor->end())
108 return failure();
109 // Check that the terminator is an unconditional branch.
110 BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
111 if (!successorBranch)
112 return failure();
113 // Check that the arguments are only used within the terminator.
114 for (BlockArgument arg : successor->getArguments()) {
115 for (Operation *user : arg.getUsers())
116 if (user != successorBranch)
117 return failure();
118 }
119 // Don't try to collapse branches to infinite loops.
120 Block *successorDest = successorBranch.getDest();
121 if (successorDest == successor)
122 return failure();
123
124 // Update the operands to the successor. If the branch parent has no
125 // arguments, we can use the branch operands directly.
126 OperandRange operands = successorBranch.getOperands();
127 if (successor->args_empty()) {
128 successor = successorDest;
129 successorOperands = operands;
130 return success();
131 }
132
133 // Otherwise, we need to remap any argument operands.
134 for (Value operand : operands) {
135 BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
136 if (argOperand && argOperand.getOwner() == successor)
137 argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
138 else
139 argStorage.push_back(operand);
140 }
141 successor = successorDest;
142 successorOperands = argStorage;
143 return success();
144}
145
146/// Simplify a branch to a block that has a single predecessor. This effectively
147/// merges the two blocks.
148static LogicalResult
149simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
150 // Check that the successor block has a single predecessor.
151 Block *succ = op.getDest();
152 Block *opParent = op->getBlock();
153 if (succ == opParent || !llvm::hasSingleElement(C: succ->getPredecessors()))
154 return failure();
155
156 // Merge the successor into the current block and erase the branch.
157 SmallVector<Value> brOperands(op.getOperands());
158 rewriter.eraseOp(op: op);
159 rewriter.mergeBlocks(source: succ, dest: opParent, argValues: brOperands);
160 return success();
161}
162
163/// br ^bb1
164/// ^bb1
165/// br ^bbN(...)
166///
167/// -> br ^bbN(...)
168///
169static LogicalResult simplifyPassThroughBr(BranchOp op,
170 PatternRewriter &rewriter) {
171 Block *dest = op.getDest();
172 ValueRange destOperands = op.getOperands();
173 SmallVector<Value, 4> destOperandStorage;
174
175 // Try to collapse the successor if it points somewhere other than this
176 // block.
177 if (dest == op->getBlock() ||
178 failed(result: collapseBranch(successor&: dest, successorOperands&: destOperands, argStorage&: destOperandStorage)))
179 return failure();
180
181 // Create a new branch with the collapsed successor.
182 rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
183 return success();
184}
185
186LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
187 return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
188 succeeded(simplifyPassThroughBr(op, rewriter)));
189}
190
191void BranchOp::setDest(Block *block) { return setSuccessor(block); }
192
193void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
194
195SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
196 assert(index == 0 && "invalid successor index");
197 return SuccessorOperands(getDestOperandsMutable());
198}
199
200Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
201 return getDest();
202}
203
204//===----------------------------------------------------------------------===//
205// CondBranchOp
206//===----------------------------------------------------------------------===//
207
208namespace {
209/// cf.cond_br true, ^bb1, ^bb2
210/// -> br ^bb1
211/// cf.cond_br false, ^bb1, ^bb2
212/// -> br ^bb2
213///
214struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
215 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
216
217 LogicalResult matchAndRewrite(CondBranchOp condbr,
218 PatternRewriter &rewriter) const override {
219 if (matchPattern(condbr.getCondition(), m_NonZero())) {
220 // True branch taken.
221 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
222 condbr.getTrueOperands());
223 return success();
224 }
225 if (matchPattern(condbr.getCondition(), m_Zero())) {
226 // False branch taken.
227 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
228 condbr.getFalseOperands());
229 return success();
230 }
231 return failure();
232 }
233};
234
235/// cf.cond_br %cond, ^bb1, ^bb2
236/// ^bb1
237/// br ^bbN(...)
238/// ^bb2
239/// br ^bbK(...)
240///
241/// -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
242///
243struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
244 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
245
246 LogicalResult matchAndRewrite(CondBranchOp condbr,
247 PatternRewriter &rewriter) const override {
248 Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
249 ValueRange trueDestOperands = condbr.getTrueOperands();
250 ValueRange falseDestOperands = condbr.getFalseOperands();
251 SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
252
253 // Try to collapse one of the current successors.
254 LogicalResult collapsedTrue =
255 collapseBranch(successor&: trueDest, successorOperands&: trueDestOperands, argStorage&: trueDestOperandStorage);
256 LogicalResult collapsedFalse =
257 collapseBranch(successor&: falseDest, successorOperands&: falseDestOperands, argStorage&: falseDestOperandStorage);
258 if (failed(result: collapsedTrue) && failed(result: collapsedFalse))
259 return failure();
260
261 // Create a new branch with the collapsed successors.
262 rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
263 trueDest, trueDestOperands,
264 falseDest, falseDestOperands);
265 return success();
266 }
267};
268
269/// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
270/// -> br ^bb1(A, ..., N)
271///
272/// cf.cond_br %cond, ^bb1(A), ^bb1(B)
273/// -> %select = arith.select %cond, A, B
274/// br ^bb1(%select)
275///
276struct SimplifyCondBranchIdenticalSuccessors
277 : public OpRewritePattern<CondBranchOp> {
278 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
279
280 LogicalResult matchAndRewrite(CondBranchOp condbr,
281 PatternRewriter &rewriter) const override {
282 // Check that the true and false destinations are the same and have the same
283 // operands.
284 Block *trueDest = condbr.getTrueDest();
285 if (trueDest != condbr.getFalseDest())
286 return failure();
287
288 // If all of the operands match, no selects need to be generated.
289 OperandRange trueOperands = condbr.getTrueOperands();
290 OperandRange falseOperands = condbr.getFalseOperands();
291 if (trueOperands == falseOperands) {
292 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
293 return success();
294 }
295
296 // Otherwise, if the current block is the only predecessor insert selects
297 // for any mismatched branch operands.
298 if (trueDest->getUniquePredecessor() != condbr->getBlock())
299 return failure();
300
301 // Generate a select for any operands that differ between the two.
302 SmallVector<Value, 8> mergedOperands;
303 mergedOperands.reserve(N: trueOperands.size());
304 Value condition = condbr.getCondition();
305 for (auto it : llvm::zip(trueOperands, falseOperands)) {
306 if (std::get<0>(it) == std::get<1>(it))
307 mergedOperands.push_back(std::get<0>(it));
308 else
309 mergedOperands.push_back(rewriter.create<arith::SelectOp>(
310 condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
311 }
312
313 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
314 return success();
315 }
316};
317
318/// ...
319/// cf.cond_br %cond, ^bb1(...), ^bb2(...)
320/// ...
321/// ^bb1: // has single predecessor
322/// ...
323/// cf.cond_br %cond, ^bb3(...), ^bb4(...)
324///
325/// ->
326///
327/// ...
328/// cf.cond_br %cond, ^bb1(...), ^bb2(...)
329/// ...
330/// ^bb1: // has single predecessor
331/// ...
332/// br ^bb3(...)
333///
334struct SimplifyCondBranchFromCondBranchOnSameCondition
335 : public OpRewritePattern<CondBranchOp> {
336 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
337
338 LogicalResult matchAndRewrite(CondBranchOp condbr,
339 PatternRewriter &rewriter) const override {
340 // Check that we have a single distinct predecessor.
341 Block *currentBlock = condbr->getBlock();
342 Block *predecessor = currentBlock->getSinglePredecessor();
343 if (!predecessor)
344 return failure();
345
346 // Check that the predecessor terminates with a conditional branch to this
347 // block and that it branches on the same condition.
348 auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
349 if (!predBranch || condbr.getCondition() != predBranch.getCondition())
350 return failure();
351
352 // Fold this branch to an unconditional branch.
353 if (currentBlock == predBranch.getTrueDest())
354 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
355 condbr.getTrueDestOperands());
356 else
357 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
358 condbr.getFalseDestOperands());
359 return success();
360 }
361};
362
363/// cf.cond_br %arg0, ^trueB, ^falseB
364///
365/// ^trueB:
366/// "test.consumer1"(%arg0) : (i1) -> ()
367/// ...
368///
369/// ^falseB:
370/// "test.consumer2"(%arg0) : (i1) -> ()
371/// ...
372///
373/// ->
374///
375/// cf.cond_br %arg0, ^trueB, ^falseB
376/// ^trueB:
377/// "test.consumer1"(%true) : (i1) -> ()
378/// ...
379///
380/// ^falseB:
381/// "test.consumer2"(%false) : (i1) -> ()
382/// ...
383struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
384 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
385
386 LogicalResult matchAndRewrite(CondBranchOp condbr,
387 PatternRewriter &rewriter) const override {
388 // Check that we have a single distinct predecessor.
389 bool replaced = false;
390 Type ty = rewriter.getI1Type();
391
392 // These variables serve to prevent creating duplicate constants
393 // and hold constant true or false values.
394 Value constantTrue = nullptr;
395 Value constantFalse = nullptr;
396
397 // TODO These checks can be expanded to encompas any use with only
398 // either the true of false edge as a predecessor. For now, we fall
399 // back to checking the single predecessor is given by the true/fasle
400 // destination, thereby ensuring that only that edge can reach the
401 // op.
402 if (condbr.getTrueDest()->getSinglePredecessor()) {
403 for (OpOperand &use :
404 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
405 if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
406 replaced = true;
407
408 if (!constantTrue)
409 constantTrue = rewriter.create<arith::ConstantOp>(
410 condbr.getLoc(), ty, rewriter.getBoolAttr(true));
411
412 rewriter.modifyOpInPlace(use.getOwner(),
413 [&] { use.set(constantTrue); });
414 }
415 }
416 }
417 if (condbr.getFalseDest()->getSinglePredecessor()) {
418 for (OpOperand &use :
419 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
420 if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
421 replaced = true;
422
423 if (!constantFalse)
424 constantFalse = rewriter.create<arith::ConstantOp>(
425 condbr.getLoc(), ty, rewriter.getBoolAttr(false));
426
427 rewriter.modifyOpInPlace(use.getOwner(),
428 [&] { use.set(constantFalse); });
429 }
430 }
431 }
432 return success(isSuccess: replaced);
433 }
434};
435} // namespace
436
437void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
438 MLIRContext *context) {
439 results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
440 SimplifyCondBranchIdenticalSuccessors,
441 SimplifyCondBranchFromCondBranchOnSameCondition,
442 CondBranchTruthPropagation>(context);
443}
444
445SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
446 assert(index < getNumSuccessors() && "invalid successor index");
447 return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable()
448 : getFalseDestOperandsMutable());
449}
450
451Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
452 if (IntegerAttr condAttr =
453 llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
454 return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
455 return nullptr;
456}
457
458//===----------------------------------------------------------------------===//
459// SwitchOp
460//===----------------------------------------------------------------------===//
461
462void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
463 Block *defaultDestination, ValueRange defaultOperands,
464 DenseIntElementsAttr caseValues,
465 BlockRange caseDestinations,
466 ArrayRef<ValueRange> caseOperands) {
467 build(builder, result, value, defaultOperands, caseOperands, caseValues,
468 defaultDestination, caseDestinations);
469}
470
471void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
472 Block *defaultDestination, ValueRange defaultOperands,
473 ArrayRef<APInt> caseValues, BlockRange caseDestinations,
474 ArrayRef<ValueRange> caseOperands) {
475 DenseIntElementsAttr caseValuesAttr;
476 if (!caseValues.empty()) {
477 ShapedType caseValueType = VectorType::get(
478 static_cast<int64_t>(caseValues.size()), value.getType());
479 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
480 }
481 build(builder, result, value, defaultDestination, defaultOperands,
482 caseValuesAttr, caseDestinations, caseOperands);
483}
484
485void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
486 Block *defaultDestination, ValueRange defaultOperands,
487 ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
488 ArrayRef<ValueRange> caseOperands) {
489 DenseIntElementsAttr caseValuesAttr;
490 if (!caseValues.empty()) {
491 ShapedType caseValueType = VectorType::get(
492 static_cast<int64_t>(caseValues.size()), value.getType());
493 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
494 }
495 build(builder, result, value, defaultDestination, defaultOperands,
496 caseValuesAttr, caseDestinations, caseOperands);
497}
498
499/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
500/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
501static ParseResult parseSwitchOpCases(
502 OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
503 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands,
504 SmallVectorImpl<Type> &defaultOperandTypes,
505 DenseIntElementsAttr &caseValues,
506 SmallVectorImpl<Block *> &caseDestinations,
507 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
508 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
509 if (parser.parseKeyword(keyword: "default") || parser.parseColon() ||
510 parser.parseSuccessor(dest&: defaultDestination))
511 return failure();
512 if (succeeded(result: parser.parseOptionalLParen())) {
513 if (parser.parseOperandList(result&: defaultOperands, delimiter: OpAsmParser::Delimiter::None,
514 /*allowResultNumber=*/false) ||
515 parser.parseColonTypeList(result&: defaultOperandTypes) || parser.parseRParen())
516 return failure();
517 }
518
519 SmallVector<APInt> values;
520 unsigned bitWidth = flagType.getIntOrFloatBitWidth();
521 while (succeeded(result: parser.parseOptionalComma())) {
522 int64_t value = 0;
523 if (failed(result: parser.parseInteger(result&: value)))
524 return failure();
525 values.push_back(Elt: APInt(bitWidth, value));
526
527 Block *destination;
528 SmallVector<OpAsmParser::UnresolvedOperand> operands;
529 SmallVector<Type> operandTypes;
530 if (failed(result: parser.parseColon()) ||
531 failed(result: parser.parseSuccessor(dest&: destination)))
532 return failure();
533 if (succeeded(result: parser.parseOptionalLParen())) {
534 if (failed(result: parser.parseOperandList(result&: operands,
535 delimiter: OpAsmParser::Delimiter::None)) ||
536 failed(result: parser.parseColonTypeList(result&: operandTypes)) ||
537 failed(result: parser.parseRParen()))
538 return failure();
539 }
540 caseDestinations.push_back(Elt: destination);
541 caseOperands.emplace_back(Args&: operands);
542 caseOperandTypes.emplace_back(Args&: operandTypes);
543 }
544
545 if (!values.empty()) {
546 ShapedType caseValueType =
547 VectorType::get(static_cast<int64_t>(values.size()), flagType);
548 caseValues = DenseIntElementsAttr::get(caseValueType, values);
549 }
550 return success();
551}
552
553static void printSwitchOpCases(
554 OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
555 OperandRange defaultOperands, TypeRange defaultOperandTypes,
556 DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
557 OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
558 p << " default: ";
559 p.printSuccessorAndUseList(successor: defaultDestination, succOperands: defaultOperands);
560
561 if (!caseValues)
562 return;
563
564 for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
565 p << ',';
566 p.printNewline();
567 p << " ";
568 p << it.value().getLimitedValue();
569 p << ": ";
570 p.printSuccessorAndUseList(caseDestinations[it.index()],
571 caseOperands[it.index()]);
572 }
573 p.printNewline();
574}
575
576LogicalResult SwitchOp::verify() {
577 auto caseValues = getCaseValues();
578 auto caseDestinations = getCaseDestinations();
579
580 if (!caseValues && caseDestinations.empty())
581 return success();
582
583 Type flagType = getFlag().getType();
584 Type caseValueType = caseValues->getType().getElementType();
585 if (caseValueType != flagType)
586 return emitOpError() << "'flag' type (" << flagType
587 << ") should match case value type (" << caseValueType
588 << ")";
589
590 if (caseValues &&
591 caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
592 return emitOpError() << "number of case values (" << caseValues->size()
593 << ") should match number of "
594 "case destinations ("
595 << caseDestinations.size() << ")";
596 return success();
597}
598
599SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
600 assert(index < getNumSuccessors() && "invalid successor index");
601 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
602 : getCaseOperandsMutable(index - 1));
603}
604
605Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
606 std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
607
608 if (!caseValues)
609 return getDefaultDestination();
610
611 SuccessorRange caseDests = getCaseDestinations();
612 if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
613 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
614 if (it.value() == value.getValue())
615 return caseDests[it.index()];
616 return getDefaultDestination();
617 }
618 return nullptr;
619}
620
621/// switch %flag : i32, [
622/// default: ^bb1
623/// ]
624/// -> br ^bb1
625static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
626 PatternRewriter &rewriter) {
627 if (!op.getCaseDestinations().empty())
628 return failure();
629
630 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
631 op.getDefaultOperands());
632 return success();
633}
634
635/// switch %flag : i32, [
636/// default: ^bb1,
637/// 42: ^bb1,
638/// 43: ^bb2
639/// ]
640/// ->
641/// switch %flag : i32, [
642/// default: ^bb1,
643/// 43: ^bb2
644/// ]
645static LogicalResult
646dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
647 SmallVector<Block *> newCaseDestinations;
648 SmallVector<ValueRange> newCaseOperands;
649 SmallVector<APInt> newCaseValues;
650 bool requiresChange = false;
651 auto caseValues = op.getCaseValues();
652 auto caseDests = op.getCaseDestinations();
653
654 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
655 if (caseDests[it.index()] == op.getDefaultDestination() &&
656 op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
657 requiresChange = true;
658 continue;
659 }
660 newCaseDestinations.push_back(caseDests[it.index()]);
661 newCaseOperands.push_back(op.getCaseOperands(it.index()));
662 newCaseValues.push_back(it.value());
663 }
664
665 if (!requiresChange)
666 return failure();
667
668 rewriter.replaceOpWithNewOp<SwitchOp>(
669 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
670 newCaseValues, newCaseDestinations, newCaseOperands);
671 return success();
672}
673
674/// Helper for folding a switch with a constant value.
675/// switch %c_42 : i32, [
676/// default: ^bb1 ,
677/// 42: ^bb2,
678/// 43: ^bb3
679/// ]
680/// -> br ^bb2
681static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
682 const APInt &caseValue) {
683 auto caseValues = op.getCaseValues();
684 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
685 if (it.value() == caseValue) {
686 rewriter.replaceOpWithNewOp<BranchOp>(
687 op, op.getCaseDestinations()[it.index()],
688 op.getCaseOperands(it.index()));
689 return;
690 }
691 }
692 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
693 op.getDefaultOperands());
694}
695
696/// switch %c_42 : i32, [
697/// default: ^bb1,
698/// 42: ^bb2,
699/// 43: ^bb3
700/// ]
701/// -> br ^bb2
702static LogicalResult simplifyConstSwitchValue(SwitchOp op,
703 PatternRewriter &rewriter) {
704 APInt caseValue;
705 if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
706 return failure();
707
708 foldSwitch(op, rewriter, caseValue);
709 return success();
710}
711
712/// switch %c_42 : i32, [
713/// default: ^bb1,
714/// 42: ^bb2,
715/// ]
716/// ^bb2:
717/// br ^bb3
718/// ->
719/// switch %c_42 : i32, [
720/// default: ^bb1,
721/// 42: ^bb3,
722/// ]
723static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
724 PatternRewriter &rewriter) {
725 SmallVector<Block *> newCaseDests;
726 SmallVector<ValueRange> newCaseOperands;
727 SmallVector<SmallVector<Value>> argStorage;
728 auto caseValues = op.getCaseValues();
729 argStorage.reserve(N: caseValues->size() + 1);
730 auto caseDests = op.getCaseDestinations();
731 bool requiresChange = false;
732 for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
733 Block *caseDest = caseDests[i];
734 ValueRange caseOperands = op.getCaseOperands(i);
735 argStorage.emplace_back();
736 if (succeeded(result: collapseBranch(successor&: caseDest, successorOperands&: caseOperands, argStorage&: argStorage.back())))
737 requiresChange = true;
738
739 newCaseDests.push_back(Elt: caseDest);
740 newCaseOperands.push_back(Elt: caseOperands);
741 }
742
743 Block *defaultDest = op.getDefaultDestination();
744 ValueRange defaultOperands = op.getDefaultOperands();
745 argStorage.emplace_back();
746
747 if (succeeded(
748 result: collapseBranch(successor&: defaultDest, successorOperands&: defaultOperands, argStorage&: argStorage.back())))
749 requiresChange = true;
750
751 if (!requiresChange)
752 return failure();
753
754 rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
755 defaultOperands, *caseValues,
756 newCaseDests, newCaseOperands);
757 return success();
758}
759
760/// switch %flag : i32, [
761/// default: ^bb1,
762/// 42: ^bb2,
763/// ]
764/// ^bb2:
765/// switch %flag : i32, [
766/// default: ^bb3,
767/// 42: ^bb4
768/// ]
769/// ->
770/// switch %flag : i32, [
771/// default: ^bb1,
772/// 42: ^bb2,
773/// ]
774/// ^bb2:
775/// br ^bb4
776///
777/// and
778///
779/// switch %flag : i32, [
780/// default: ^bb1,
781/// 42: ^bb2,
782/// ]
783/// ^bb2:
784/// switch %flag : i32, [
785/// default: ^bb3,
786/// 43: ^bb4
787/// ]
788/// ->
789/// switch %flag : i32, [
790/// default: ^bb1,
791/// 42: ^bb2,
792/// ]
793/// ^bb2:
794/// br ^bb3
795static LogicalResult
796simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
797 PatternRewriter &rewriter) {
798 // Check that we have a single distinct predecessor.
799 Block *currentBlock = op->getBlock();
800 Block *predecessor = currentBlock->getSinglePredecessor();
801 if (!predecessor)
802 return failure();
803
804 // Check that the predecessor terminates with a switch branch to this block
805 // and that it branches on the same condition and that this branch isn't the
806 // default destination.
807 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
808 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
809 predSwitch.getDefaultDestination() == currentBlock)
810 return failure();
811
812 // Fold this switch to an unconditional branch.
813 SuccessorRange predDests = predSwitch.getCaseDestinations();
814 auto it = llvm::find(Range&: predDests, Val: currentBlock);
815 if (it != predDests.end()) {
816 std::optional<DenseIntElementsAttr> predCaseValues =
817 predSwitch.getCaseValues();
818 foldSwitch(op, rewriter,
819 predCaseValues->getValues<APInt>()[it - predDests.begin()]);
820 } else {
821 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
822 op.getDefaultOperands());
823 }
824 return success();
825}
826
827/// switch %flag : i32, [
828/// default: ^bb1,
829/// 42: ^bb2
830/// ]
831/// ^bb1:
832/// switch %flag : i32, [
833/// default: ^bb3,
834/// 42: ^bb4,
835/// 43: ^bb5
836/// ]
837/// ->
838/// switch %flag : i32, [
839/// default: ^bb1,
840/// 42: ^bb2,
841/// ]
842/// ^bb1:
843/// switch %flag : i32, [
844/// default: ^bb3,
845/// 43: ^bb5
846/// ]
847static LogicalResult
848simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
849 PatternRewriter &rewriter) {
850 // Check that we have a single distinct predecessor.
851 Block *currentBlock = op->getBlock();
852 Block *predecessor = currentBlock->getSinglePredecessor();
853 if (!predecessor)
854 return failure();
855
856 // Check that the predecessor terminates with a switch branch to this block
857 // and that it branches on the same condition and that this branch is the
858 // default destination.
859 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
860 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
861 predSwitch.getDefaultDestination() != currentBlock)
862 return failure();
863
864 // Delete case values that are not possible here.
865 DenseSet<APInt> caseValuesToRemove;
866 auto predDests = predSwitch.getCaseDestinations();
867 auto predCaseValues = predSwitch.getCaseValues();
868 for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
869 if (currentBlock != predDests[i])
870 caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
871
872 SmallVector<Block *> newCaseDestinations;
873 SmallVector<ValueRange> newCaseOperands;
874 SmallVector<APInt> newCaseValues;
875 bool requiresChange = false;
876
877 auto caseValues = op.getCaseValues();
878 auto caseDests = op.getCaseDestinations();
879 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
880 if (caseValuesToRemove.contains(it.value())) {
881 requiresChange = true;
882 continue;
883 }
884 newCaseDestinations.push_back(caseDests[it.index()]);
885 newCaseOperands.push_back(op.getCaseOperands(it.index()));
886 newCaseValues.push_back(it.value());
887 }
888
889 if (!requiresChange)
890 return failure();
891
892 rewriter.replaceOpWithNewOp<SwitchOp>(
893 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
894 newCaseValues, newCaseDestinations, newCaseOperands);
895 return success();
896}
897
898void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
899 MLIRContext *context) {
900 results.add(&simplifySwitchWithOnlyDefault)
901 .add(&dropSwitchCasesThatMatchDefault)
902 .add(&simplifyConstSwitchValue)
903 .add(&simplifyPassThroughSwitch)
904 .add(&simplifySwitchFromSwitchOnSameCondition)
905 .add(&simplifySwitchFromDefaultSwitchOnSameCondition);
906}
907
908//===----------------------------------------------------------------------===//
909// TableGen'd op method definitions
910//===----------------------------------------------------------------------===//
911
912#define GET_OP_CLASSES
913#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
914

source code of mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp