1//===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
10
11#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
12#include "mlir/Dialect/Arith/IR/Arith.h"
13#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
14#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
15#include "mlir/IR/AffineExpr.h"
16#include "mlir/IR/AffineMap.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/BuiltinOps.h"
19#include "mlir/IR/BuiltinTypes.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Matchers.h"
22#include "mlir/IR/OpImplementation.h"
23#include "mlir/IR/PatternMatch.h"
24#include "mlir/IR/TypeUtilities.h"
25#include "mlir/IR/Value.h"
26#include "mlir/Transforms/InliningUtils.h"
27#include "llvm/ADT/APFloat.h"
28#include "llvm/ADT/STLExtras.h"
29#include "llvm/Support/FormatVariadic.h"
30#include "llvm/Support/raw_ostream.h"
31#include <numeric>
32
33#include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
34
35using namespace mlir;
36using namespace mlir::cf;
37
38//===----------------------------------------------------------------------===//
39// ControlFlowDialect Interfaces
40//===----------------------------------------------------------------------===//
41namespace {
42/// This class defines the interface for handling inlining with control flow
43/// operations.
44struct ControlFlowInlinerInterface : public DialectInlinerInterface {
45 using DialectInlinerInterface::DialectInlinerInterface;
46 ~ControlFlowInlinerInterface() override = default;
47
48 /// All control flow operations can be inlined.
49 bool isLegalToInline(Operation *call, Operation *callable,
50 bool wouldBeCloned) const final {
51 return true;
52 }
53 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
54 return true;
55 }
56
57 /// ControlFlow terminator operations don't really need any special handing.
58 void handleTerminator(Operation *op, Block *newDest) const final {}
59};
60} // namespace
61
62//===----------------------------------------------------------------------===//
63// ControlFlowDialect
64//===----------------------------------------------------------------------===//
65
66void ControlFlowDialect::initialize() {
67 addOperations<
68#define GET_OP_LIST
69#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
70 >();
71 addInterfaces<ControlFlowInlinerInterface>();
72 declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>();
73 declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp,
74 CondBranchOp>();
75 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
76 CondBranchOp>();
77}
78
79//===----------------------------------------------------------------------===//
80// AssertOp
81//===----------------------------------------------------------------------===//
82
83LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
84 // Erase assertion if argument is constant true.
85 if (matchPattern(op.getArg(), m_One())) {
86 rewriter.eraseOp(op);
87 return success();
88 }
89 return failure();
90}
91
92// This side effect models "program termination".
93void AssertOp::getEffects(
94 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
95 &effects) {
96 effects.emplace_back(MemoryEffects::Write::get());
97}
98
99//===----------------------------------------------------------------------===//
100// BranchOp
101//===----------------------------------------------------------------------===//
102
103/// Given a successor, try to collapse it to a new destination if it only
104/// contains a passthrough unconditional branch. If the successor is
105/// collapsable, `successor` and `successorOperands` are updated to reference
106/// the new destination and values. `argStorage` is used as storage if operands
107/// to the collapsed successor need to be remapped. It must outlive uses of
108/// successorOperands.
109static LogicalResult collapseBranch(Block *&successor,
110 ValueRange &successorOperands,
111 SmallVectorImpl<Value> &argStorage) {
112 // Check that the successor only contains a unconditional branch.
113 if (std::next(successor->begin()) != successor->end())
114 return failure();
115 // Check that the terminator is an unconditional branch.
116 BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
117 if (!successorBranch)
118 return failure();
119 // Check that the arguments are only used within the terminator.
120 for (BlockArgument arg : successor->getArguments()) {
121 for (Operation *user : arg.getUsers())
122 if (user != successorBranch)
123 return failure();
124 }
125 // Don't try to collapse branches to infinite loops.
126 Block *successorDest = successorBranch.getDest();
127 if (successorDest == successor)
128 return failure();
129
130 // Update the operands to the successor. If the branch parent has no
131 // arguments, we can use the branch operands directly.
132 OperandRange operands = successorBranch.getOperands();
133 if (successor->args_empty()) {
134 successor = successorDest;
135 successorOperands = operands;
136 return success();
137 }
138
139 // Otherwise, we need to remap any argument operands.
140 for (Value operand : operands) {
141 BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
142 if (argOperand && argOperand.getOwner() == successor)
143 argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
144 else
145 argStorage.push_back(operand);
146 }
147 successor = successorDest;
148 successorOperands = argStorage;
149 return success();
150}
151
152/// Simplify a branch to a block that has a single predecessor. This effectively
153/// merges the two blocks.
154static LogicalResult
155simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
156 // Check that the successor block has a single predecessor.
157 Block *succ = op.getDest();
158 Block *opParent = op->getBlock();
159 if (succ == opParent || !llvm::hasSingleElement(C: succ->getPredecessors()))
160 return failure();
161
162 // Merge the successor into the current block and erase the branch.
163 SmallVector<Value> brOperands(op.getOperands());
164 rewriter.eraseOp(op: op);
165 rewriter.mergeBlocks(source: succ, dest: opParent, argValues: brOperands);
166 return success();
167}
168
169/// br ^bb1
170/// ^bb1
171/// br ^bbN(...)
172///
173/// -> br ^bbN(...)
174///
175static LogicalResult simplifyPassThroughBr(BranchOp op,
176 PatternRewriter &rewriter) {
177 Block *dest = op.getDest();
178 ValueRange destOperands = op.getOperands();
179 SmallVector<Value, 4> destOperandStorage;
180
181 // Try to collapse the successor if it points somewhere other than this
182 // block.
183 if (dest == op->getBlock() ||
184 failed(Result: collapseBranch(successor&: dest, successorOperands&: destOperands, argStorage&: destOperandStorage)))
185 return failure();
186
187 // Create a new branch with the collapsed successor.
188 rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
189 return success();
190}
191
192LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
193 return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
194 succeeded(simplifyPassThroughBr(op, rewriter)));
195}
196
197void BranchOp::setDest(Block *block) { return setSuccessor(block); }
198
199void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
200
201SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
202 assert(index == 0 && "invalid successor index");
203 return SuccessorOperands(getDestOperandsMutable());
204}
205
206Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
207 return getDest();
208}
209
210//===----------------------------------------------------------------------===//
211// CondBranchOp
212//===----------------------------------------------------------------------===//
213
214namespace {
215/// cf.cond_br true, ^bb1, ^bb2
216/// -> br ^bb1
217/// cf.cond_br false, ^bb1, ^bb2
218/// -> br ^bb2
219///
220struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
221 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
222
223 LogicalResult matchAndRewrite(CondBranchOp condbr,
224 PatternRewriter &rewriter) const override {
225 if (matchPattern(condbr.getCondition(), m_NonZero())) {
226 // True branch taken.
227 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
228 condbr.getTrueOperands());
229 return success();
230 }
231 if (matchPattern(condbr.getCondition(), m_Zero())) {
232 // False branch taken.
233 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
234 condbr.getFalseOperands());
235 return success();
236 }
237 return failure();
238 }
239};
240
241/// cf.cond_br %cond, ^bb1, ^bb2
242/// ^bb1
243/// br ^bbN(...)
244/// ^bb2
245/// br ^bbK(...)
246///
247/// -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
248///
249struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
250 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
251
252 LogicalResult matchAndRewrite(CondBranchOp condbr,
253 PatternRewriter &rewriter) const override {
254 Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
255 ValueRange trueDestOperands = condbr.getTrueOperands();
256 ValueRange falseDestOperands = condbr.getFalseOperands();
257 SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
258
259 // Try to collapse one of the current successors.
260 LogicalResult collapsedTrue =
261 collapseBranch(successor&: trueDest, successorOperands&: trueDestOperands, argStorage&: trueDestOperandStorage);
262 LogicalResult collapsedFalse =
263 collapseBranch(successor&: falseDest, successorOperands&: falseDestOperands, argStorage&: falseDestOperandStorage);
264 if (failed(Result: collapsedTrue) && failed(Result: collapsedFalse))
265 return failure();
266
267 // Create a new branch with the collapsed successors.
268 rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
269 trueDest, trueDestOperands,
270 falseDest, falseDestOperands);
271 return success();
272 }
273};
274
275/// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
276/// -> br ^bb1(A, ..., N)
277///
278/// cf.cond_br %cond, ^bb1(A), ^bb1(B)
279/// -> %select = arith.select %cond, A, B
280/// br ^bb1(%select)
281///
282struct SimplifyCondBranchIdenticalSuccessors
283 : public OpRewritePattern<CondBranchOp> {
284 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
285
286 LogicalResult matchAndRewrite(CondBranchOp condbr,
287 PatternRewriter &rewriter) const override {
288 // Check that the true and false destinations are the same and have the same
289 // operands.
290 Block *trueDest = condbr.getTrueDest();
291 if (trueDest != condbr.getFalseDest())
292 return failure();
293
294 // If all of the operands match, no selects need to be generated.
295 OperandRange trueOperands = condbr.getTrueOperands();
296 OperandRange falseOperands = condbr.getFalseOperands();
297 if (trueOperands == falseOperands) {
298 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
299 return success();
300 }
301
302 // Otherwise, if the current block is the only predecessor insert selects
303 // for any mismatched branch operands.
304 if (trueDest->getUniquePredecessor() != condbr->getBlock())
305 return failure();
306
307 // Generate a select for any operands that differ between the two.
308 SmallVector<Value, 8> mergedOperands;
309 mergedOperands.reserve(N: trueOperands.size());
310 Value condition = condbr.getCondition();
311 for (auto it : llvm::zip(trueOperands, falseOperands)) {
312 if (std::get<0>(it) == std::get<1>(it))
313 mergedOperands.push_back(std::get<0>(it));
314 else
315 mergedOperands.push_back(rewriter.create<arith::SelectOp>(
316 condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it)));
317 }
318
319 rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
320 return success();
321 }
322};
323
324/// ...
325/// cf.cond_br %cond, ^bb1(...), ^bb2(...)
326/// ...
327/// ^bb1: // has single predecessor
328/// ...
329/// cf.cond_br %cond, ^bb3(...), ^bb4(...)
330///
331/// ->
332///
333/// ...
334/// cf.cond_br %cond, ^bb1(...), ^bb2(...)
335/// ...
336/// ^bb1: // has single predecessor
337/// ...
338/// br ^bb3(...)
339///
340struct SimplifyCondBranchFromCondBranchOnSameCondition
341 : public OpRewritePattern<CondBranchOp> {
342 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
343
344 LogicalResult matchAndRewrite(CondBranchOp condbr,
345 PatternRewriter &rewriter) const override {
346 // Check that we have a single distinct predecessor.
347 Block *currentBlock = condbr->getBlock();
348 Block *predecessor = currentBlock->getSinglePredecessor();
349 if (!predecessor)
350 return failure();
351
352 // Check that the predecessor terminates with a conditional branch to this
353 // block and that it branches on the same condition.
354 auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
355 if (!predBranch || condbr.getCondition() != predBranch.getCondition())
356 return failure();
357
358 // Fold this branch to an unconditional branch.
359 if (currentBlock == predBranch.getTrueDest())
360 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
361 condbr.getTrueDestOperands());
362 else
363 rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
364 condbr.getFalseDestOperands());
365 return success();
366 }
367};
368
369/// cf.cond_br %arg0, ^trueB, ^falseB
370///
371/// ^trueB:
372/// "test.consumer1"(%arg0) : (i1) -> ()
373/// ...
374///
375/// ^falseB:
376/// "test.consumer2"(%arg0) : (i1) -> ()
377/// ...
378///
379/// ->
380///
381/// cf.cond_br %arg0, ^trueB, ^falseB
382/// ^trueB:
383/// "test.consumer1"(%true) : (i1) -> ()
384/// ...
385///
386/// ^falseB:
387/// "test.consumer2"(%false) : (i1) -> ()
388/// ...
389struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
390 using OpRewritePattern<CondBranchOp>::OpRewritePattern;
391
392 LogicalResult matchAndRewrite(CondBranchOp condbr,
393 PatternRewriter &rewriter) const override {
394 // Check that we have a single distinct predecessor.
395 bool replaced = false;
396 Type ty = rewriter.getI1Type();
397
398 // These variables serve to prevent creating duplicate constants
399 // and hold constant true or false values.
400 Value constantTrue = nullptr;
401 Value constantFalse = nullptr;
402
403 // TODO These checks can be expanded to encompas any use with only
404 // either the true of false edge as a predecessor. For now, we fall
405 // back to checking the single predecessor is given by the true/fasle
406 // destination, thereby ensuring that only that edge can reach the
407 // op.
408 if (condbr.getTrueDest()->getSinglePredecessor()) {
409 for (OpOperand &use :
410 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
411 if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
412 replaced = true;
413
414 if (!constantTrue)
415 constantTrue = rewriter.create<arith::ConstantOp>(
416 condbr.getLoc(), ty, rewriter.getBoolAttr(true));
417
418 rewriter.modifyOpInPlace(use.getOwner(),
419 [&] { use.set(constantTrue); });
420 }
421 }
422 }
423 if (condbr.getFalseDest()->getSinglePredecessor()) {
424 for (OpOperand &use :
425 llvm::make_early_inc_range(condbr.getCondition().getUses())) {
426 if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
427 replaced = true;
428
429 if (!constantFalse)
430 constantFalse = rewriter.create<arith::ConstantOp>(
431 condbr.getLoc(), ty, rewriter.getBoolAttr(false));
432
433 rewriter.modifyOpInPlace(use.getOwner(),
434 [&] { use.set(constantFalse); });
435 }
436 }
437 }
438 return success(IsSuccess: replaced);
439 }
440};
441} // namespace
442
443void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
444 MLIRContext *context) {
445 results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
446 SimplifyCondBranchIdenticalSuccessors,
447 SimplifyCondBranchFromCondBranchOnSameCondition,
448 CondBranchTruthPropagation>(context);
449}
450
451SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
452 assert(index < getNumSuccessors() && "invalid successor index");
453 return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable()
454 : getFalseDestOperandsMutable());
455}
456
457Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
458 if (IntegerAttr condAttr =
459 llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
460 return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
461 return nullptr;
462}
463
464//===----------------------------------------------------------------------===//
465// SwitchOp
466//===----------------------------------------------------------------------===//
467
468void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
469 Block *defaultDestination, ValueRange defaultOperands,
470 DenseIntElementsAttr caseValues,
471 BlockRange caseDestinations,
472 ArrayRef<ValueRange> caseOperands) {
473 build(builder, result, value, defaultOperands, caseOperands, caseValues,
474 defaultDestination, caseDestinations);
475}
476
477void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
478 Block *defaultDestination, ValueRange defaultOperands,
479 ArrayRef<APInt> caseValues, BlockRange caseDestinations,
480 ArrayRef<ValueRange> caseOperands) {
481 DenseIntElementsAttr caseValuesAttr;
482 if (!caseValues.empty()) {
483 ShapedType caseValueType = VectorType::get(
484 static_cast<int64_t>(caseValues.size()), value.getType());
485 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
486 }
487 build(builder, result, value, defaultDestination, defaultOperands,
488 caseValuesAttr, caseDestinations, caseOperands);
489}
490
491void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
492 Block *defaultDestination, ValueRange defaultOperands,
493 ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
494 ArrayRef<ValueRange> caseOperands) {
495 DenseIntElementsAttr caseValuesAttr;
496 if (!caseValues.empty()) {
497 ShapedType caseValueType = VectorType::get(
498 static_cast<int64_t>(caseValues.size()), value.getType());
499 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
500 }
501 build(builder, result, value, defaultDestination, defaultOperands,
502 caseValuesAttr, caseDestinations, caseOperands);
503}
504
505/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
506/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
507static ParseResult parseSwitchOpCases(
508 OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
509 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands,
510 SmallVectorImpl<Type> &defaultOperandTypes,
511 DenseIntElementsAttr &caseValues,
512 SmallVectorImpl<Block *> &caseDestinations,
513 SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
514 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
515 if (parser.parseKeyword(keyword: "default") || parser.parseColon() ||
516 parser.parseSuccessor(dest&: defaultDestination))
517 return failure();
518 if (succeeded(Result: parser.parseOptionalLParen())) {
519 if (parser.parseOperandList(result&: defaultOperands, delimiter: OpAsmParser::Delimiter::None,
520 /*allowResultNumber=*/false) ||
521 parser.parseColonTypeList(result&: defaultOperandTypes) || parser.parseRParen())
522 return failure();
523 }
524
525 SmallVector<APInt> values;
526 unsigned bitWidth = flagType.getIntOrFloatBitWidth();
527 while (succeeded(Result: parser.parseOptionalComma())) {
528 int64_t value = 0;
529 if (failed(Result: parser.parseInteger(result&: value)))
530 return failure();
531 values.push_back(Elt: APInt(bitWidth, value, /*isSigned=*/true));
532
533 Block *destination;
534 SmallVector<OpAsmParser::UnresolvedOperand> operands;
535 SmallVector<Type> operandTypes;
536 if (failed(Result: parser.parseColon()) ||
537 failed(Result: parser.parseSuccessor(dest&: destination)))
538 return failure();
539 if (succeeded(Result: parser.parseOptionalLParen())) {
540 if (failed(Result: parser.parseOperandList(result&: operands,
541 delimiter: OpAsmParser::Delimiter::None)) ||
542 failed(Result: parser.parseColonTypeList(result&: operandTypes)) ||
543 failed(Result: parser.parseRParen()))
544 return failure();
545 }
546 caseDestinations.push_back(Elt: destination);
547 caseOperands.emplace_back(Args&: operands);
548 caseOperandTypes.emplace_back(Args&: operandTypes);
549 }
550
551 if (!values.empty()) {
552 ShapedType caseValueType =
553 VectorType::get(static_cast<int64_t>(values.size()), flagType);
554 caseValues = DenseIntElementsAttr::get(caseValueType, values);
555 }
556 return success();
557}
558
559static void printSwitchOpCases(
560 OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
561 OperandRange defaultOperands, TypeRange defaultOperandTypes,
562 DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
563 OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
564 p << " default: ";
565 p.printSuccessorAndUseList(successor: defaultDestination, succOperands: defaultOperands);
566
567 if (!caseValues)
568 return;
569
570 for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
571 p << ',';
572 p.printNewline();
573 p << " ";
574 p << it.value().getLimitedValue();
575 p << ": ";
576 p.printSuccessorAndUseList(caseDestinations[it.index()],
577 caseOperands[it.index()]);
578 }
579 p.printNewline();
580}
581
582LogicalResult SwitchOp::verify() {
583 auto caseValues = getCaseValues();
584 auto caseDestinations = getCaseDestinations();
585
586 if (!caseValues && caseDestinations.empty())
587 return success();
588
589 Type flagType = getFlag().getType();
590 Type caseValueType = caseValues->getType().getElementType();
591 if (caseValueType != flagType)
592 return emitOpError() << "'flag' type (" << flagType
593 << ") should match case value type (" << caseValueType
594 << ")";
595
596 if (caseValues &&
597 caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
598 return emitOpError() << "number of case values (" << caseValues->size()
599 << ") should match number of "
600 "case destinations ("
601 << caseDestinations.size() << ")";
602 return success();
603}
604
605SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
606 assert(index < getNumSuccessors() && "invalid successor index");
607 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
608 : getCaseOperandsMutable(index - 1));
609}
610
611Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
612 std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
613
614 if (!caseValues)
615 return getDefaultDestination();
616
617 SuccessorRange caseDests = getCaseDestinations();
618 if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
619 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
620 if (it.value() == value.getValue())
621 return caseDests[it.index()];
622 return getDefaultDestination();
623 }
624 return nullptr;
625}
626
627/// switch %flag : i32, [
628/// default: ^bb1
629/// ]
630/// -> br ^bb1
631static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
632 PatternRewriter &rewriter) {
633 if (!op.getCaseDestinations().empty())
634 return failure();
635
636 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
637 op.getDefaultOperands());
638 return success();
639}
640
641/// switch %flag : i32, [
642/// default: ^bb1,
643/// 42: ^bb1,
644/// 43: ^bb2
645/// ]
646/// ->
647/// switch %flag : i32, [
648/// default: ^bb1,
649/// 43: ^bb2
650/// ]
651static LogicalResult
652dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
653 SmallVector<Block *> newCaseDestinations;
654 SmallVector<ValueRange> newCaseOperands;
655 SmallVector<APInt> newCaseValues;
656 bool requiresChange = false;
657 auto caseValues = op.getCaseValues();
658 auto caseDests = op.getCaseDestinations();
659
660 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
661 if (caseDests[it.index()] == op.getDefaultDestination() &&
662 op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
663 requiresChange = true;
664 continue;
665 }
666 newCaseDestinations.push_back(caseDests[it.index()]);
667 newCaseOperands.push_back(op.getCaseOperands(it.index()));
668 newCaseValues.push_back(it.value());
669 }
670
671 if (!requiresChange)
672 return failure();
673
674 rewriter.replaceOpWithNewOp<SwitchOp>(
675 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
676 newCaseValues, newCaseDestinations, newCaseOperands);
677 return success();
678}
679
680/// Helper for folding a switch with a constant value.
681/// switch %c_42 : i32, [
682/// default: ^bb1 ,
683/// 42: ^bb2,
684/// 43: ^bb3
685/// ]
686/// -> br ^bb2
687static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
688 const APInt &caseValue) {
689 auto caseValues = op.getCaseValues();
690 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
691 if (it.value() == caseValue) {
692 rewriter.replaceOpWithNewOp<BranchOp>(
693 op, op.getCaseDestinations()[it.index()],
694 op.getCaseOperands(it.index()));
695 return;
696 }
697 }
698 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
699 op.getDefaultOperands());
700}
701
702/// switch %c_42 : i32, [
703/// default: ^bb1,
704/// 42: ^bb2,
705/// 43: ^bb3
706/// ]
707/// -> br ^bb2
708static LogicalResult simplifyConstSwitchValue(SwitchOp op,
709 PatternRewriter &rewriter) {
710 APInt caseValue;
711 if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
712 return failure();
713
714 foldSwitch(op, rewriter, caseValue);
715 return success();
716}
717
718/// switch %c_42 : i32, [
719/// default: ^bb1,
720/// 42: ^bb2,
721/// ]
722/// ^bb2:
723/// br ^bb3
724/// ->
725/// switch %c_42 : i32, [
726/// default: ^bb1,
727/// 42: ^bb3,
728/// ]
729static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
730 PatternRewriter &rewriter) {
731 SmallVector<Block *> newCaseDests;
732 SmallVector<ValueRange> newCaseOperands;
733 SmallVector<SmallVector<Value>> argStorage;
734 auto caseValues = op.getCaseValues();
735 argStorage.reserve(N: caseValues->size() + 1);
736 auto caseDests = op.getCaseDestinations();
737 bool requiresChange = false;
738 for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
739 Block *caseDest = caseDests[i];
740 ValueRange caseOperands = op.getCaseOperands(i);
741 argStorage.emplace_back();
742 if (succeeded(Result: collapseBranch(successor&: caseDest, successorOperands&: caseOperands, argStorage&: argStorage.back())))
743 requiresChange = true;
744
745 newCaseDests.push_back(Elt: caseDest);
746 newCaseOperands.push_back(Elt: caseOperands);
747 }
748
749 Block *defaultDest = op.getDefaultDestination();
750 ValueRange defaultOperands = op.getDefaultOperands();
751 argStorage.emplace_back();
752
753 if (succeeded(
754 Result: collapseBranch(successor&: defaultDest, successorOperands&: defaultOperands, argStorage&: argStorage.back())))
755 requiresChange = true;
756
757 if (!requiresChange)
758 return failure();
759
760 rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
761 defaultOperands, *caseValues,
762 newCaseDests, newCaseOperands);
763 return success();
764}
765
766/// switch %flag : i32, [
767/// default: ^bb1,
768/// 42: ^bb2,
769/// ]
770/// ^bb2:
771/// switch %flag : i32, [
772/// default: ^bb3,
773/// 42: ^bb4
774/// ]
775/// ->
776/// switch %flag : i32, [
777/// default: ^bb1,
778/// 42: ^bb2,
779/// ]
780/// ^bb2:
781/// br ^bb4
782///
783/// and
784///
785/// switch %flag : i32, [
786/// default: ^bb1,
787/// 42: ^bb2,
788/// ]
789/// ^bb2:
790/// switch %flag : i32, [
791/// default: ^bb3,
792/// 43: ^bb4
793/// ]
794/// ->
795/// switch %flag : i32, [
796/// default: ^bb1,
797/// 42: ^bb2,
798/// ]
799/// ^bb2:
800/// br ^bb3
801static LogicalResult
802simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
803 PatternRewriter &rewriter) {
804 // Check that we have a single distinct predecessor.
805 Block *currentBlock = op->getBlock();
806 Block *predecessor = currentBlock->getSinglePredecessor();
807 if (!predecessor)
808 return failure();
809
810 // Check that the predecessor terminates with a switch branch to this block
811 // and that it branches on the same condition and that this branch isn't the
812 // default destination.
813 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
814 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
815 predSwitch.getDefaultDestination() == currentBlock)
816 return failure();
817
818 // Fold this switch to an unconditional branch.
819 SuccessorRange predDests = predSwitch.getCaseDestinations();
820 auto it = llvm::find(Range&: predDests, Val: currentBlock);
821 if (it != predDests.end()) {
822 std::optional<DenseIntElementsAttr> predCaseValues =
823 predSwitch.getCaseValues();
824 foldSwitch(op, rewriter,
825 predCaseValues->getValues<APInt>()[it - predDests.begin()]);
826 } else {
827 rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
828 op.getDefaultOperands());
829 }
830 return success();
831}
832
833/// switch %flag : i32, [
834/// default: ^bb1,
835/// 42: ^bb2
836/// ]
837/// ^bb1:
838/// switch %flag : i32, [
839/// default: ^bb3,
840/// 42: ^bb4,
841/// 43: ^bb5
842/// ]
843/// ->
844/// switch %flag : i32, [
845/// default: ^bb1,
846/// 42: ^bb2,
847/// ]
848/// ^bb1:
849/// switch %flag : i32, [
850/// default: ^bb3,
851/// 43: ^bb5
852/// ]
853static LogicalResult
854simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
855 PatternRewriter &rewriter) {
856 // Check that we have a single distinct predecessor.
857 Block *currentBlock = op->getBlock();
858 Block *predecessor = currentBlock->getSinglePredecessor();
859 if (!predecessor)
860 return failure();
861
862 // Check that the predecessor terminates with a switch branch to this block
863 // and that it branches on the same condition and that this branch is the
864 // default destination.
865 auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
866 if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
867 predSwitch.getDefaultDestination() != currentBlock)
868 return failure();
869
870 // Delete case values that are not possible here.
871 DenseSet<APInt> caseValuesToRemove;
872 auto predDests = predSwitch.getCaseDestinations();
873 auto predCaseValues = predSwitch.getCaseValues();
874 for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
875 if (currentBlock != predDests[i])
876 caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
877
878 SmallVector<Block *> newCaseDestinations;
879 SmallVector<ValueRange> newCaseOperands;
880 SmallVector<APInt> newCaseValues;
881 bool requiresChange = false;
882
883 auto caseValues = op.getCaseValues();
884 auto caseDests = op.getCaseDestinations();
885 for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
886 if (caseValuesToRemove.contains(it.value())) {
887 requiresChange = true;
888 continue;
889 }
890 newCaseDestinations.push_back(caseDests[it.index()]);
891 newCaseOperands.push_back(op.getCaseOperands(it.index()));
892 newCaseValues.push_back(it.value());
893 }
894
895 if (!requiresChange)
896 return failure();
897
898 rewriter.replaceOpWithNewOp<SwitchOp>(
899 op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
900 newCaseValues, newCaseDestinations, newCaseOperands);
901 return success();
902}
903
904void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
905 MLIRContext *context) {
906 results.add(&simplifySwitchWithOnlyDefault)
907 .add(&dropSwitchCasesThatMatchDefault)
908 .add(&simplifyConstSwitchValue)
909 .add(&simplifyPassThroughSwitch)
910 .add(&simplifySwitchFromSwitchOnSameCondition)
911 .add(&simplifySwitchFromDefaultSwitchOnSameCondition);
912}
913
914//===----------------------------------------------------------------------===//
915// TableGen'd op method definitions
916//===----------------------------------------------------------------------===//
917
918#define GET_OP_CLASSES
919#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
920

// Code-browser page footer captured together with this source:
// Provided by KDAB
//
// Privacy Policy
// Learn to use CMake with our Intro Training
// Find out more
//
// source code of mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp