1 | //===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"

#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>
32 | |
33 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc" |
34 | |
35 | using namespace mlir; |
36 | using namespace mlir::cf; |
37 | |
38 | //===----------------------------------------------------------------------===// |
39 | // ControlFlowDialect Interfaces |
40 | //===----------------------------------------------------------------------===// |
41 | namespace { |
42 | /// This class defines the interface for handling inlining with control flow |
43 | /// operations. |
44 | struct ControlFlowInlinerInterface : public DialectInlinerInterface { |
45 | using DialectInlinerInterface::DialectInlinerInterface; |
46 | ~ControlFlowInlinerInterface() override = default; |
47 | |
48 | /// All control flow operations can be inlined. |
49 | bool isLegalToInline(Operation *call, Operation *callable, |
50 | bool wouldBeCloned) const final { |
51 | return true; |
52 | } |
53 | bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { |
54 | return true; |
55 | } |
56 | |
57 | /// ControlFlow terminator operations don't really need any special handing. |
58 | void handleTerminator(Operation *op, Block *newDest) const final {} |
59 | }; |
60 | } // namespace |
61 | |
62 | //===----------------------------------------------------------------------===// |
63 | // ControlFlowDialect |
64 | //===----------------------------------------------------------------------===// |
65 | |
66 | void ControlFlowDialect::initialize() { |
67 | addOperations< |
68 | #define GET_OP_LIST |
69 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc" |
70 | >(); |
71 | addInterfaces<ControlFlowInlinerInterface>(); |
72 | declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>(); |
73 | declarePromisedInterfaces<bufferization::BufferizableOpInterface, BranchOp, |
74 | CondBranchOp>(); |
75 | declarePromisedInterface<bufferization::BufferDeallocationOpInterface, |
76 | CondBranchOp>(); |
77 | } |
78 | |
79 | //===----------------------------------------------------------------------===// |
80 | // AssertOp |
81 | //===----------------------------------------------------------------------===// |
82 | |
83 | LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) { |
84 | // Erase assertion if argument is constant true. |
85 | if (matchPattern(op.getArg(), m_One())) { |
86 | rewriter.eraseOp(op); |
87 | return success(); |
88 | } |
89 | return failure(); |
90 | } |
91 | |
92 | // This side effect models "program termination". |
93 | void AssertOp::getEffects( |
94 | SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> |
95 | &effects) { |
96 | effects.emplace_back(MemoryEffects::Write::get()); |
97 | } |
98 | |
99 | //===----------------------------------------------------------------------===// |
100 | // BranchOp |
101 | //===----------------------------------------------------------------------===// |
102 | |
103 | /// Given a successor, try to collapse it to a new destination if it only |
104 | /// contains a passthrough unconditional branch. If the successor is |
105 | /// collapsable, `successor` and `successorOperands` are updated to reference |
106 | /// the new destination and values. `argStorage` is used as storage if operands |
107 | /// to the collapsed successor need to be remapped. It must outlive uses of |
108 | /// successorOperands. |
109 | static LogicalResult collapseBranch(Block *&successor, |
110 | ValueRange &successorOperands, |
111 | SmallVectorImpl<Value> &argStorage) { |
112 | // Check that the successor only contains a unconditional branch. |
113 | if (std::next(successor->begin()) != successor->end()) |
114 | return failure(); |
115 | // Check that the terminator is an unconditional branch. |
116 | BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator()); |
117 | if (!successorBranch) |
118 | return failure(); |
119 | // Check that the arguments are only used within the terminator. |
120 | for (BlockArgument arg : successor->getArguments()) { |
121 | for (Operation *user : arg.getUsers()) |
122 | if (user != successorBranch) |
123 | return failure(); |
124 | } |
125 | // Don't try to collapse branches to infinite loops. |
126 | Block *successorDest = successorBranch.getDest(); |
127 | if (successorDest == successor) |
128 | return failure(); |
129 | |
130 | // Update the operands to the successor. If the branch parent has no |
131 | // arguments, we can use the branch operands directly. |
132 | OperandRange operands = successorBranch.getOperands(); |
133 | if (successor->args_empty()) { |
134 | successor = successorDest; |
135 | successorOperands = operands; |
136 | return success(); |
137 | } |
138 | |
139 | // Otherwise, we need to remap any argument operands. |
140 | for (Value operand : operands) { |
141 | BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand); |
142 | if (argOperand && argOperand.getOwner() == successor) |
143 | argStorage.push_back(successorOperands[argOperand.getArgNumber()]); |
144 | else |
145 | argStorage.push_back(operand); |
146 | } |
147 | successor = successorDest; |
148 | successorOperands = argStorage; |
149 | return success(); |
150 | } |
151 | |
152 | /// Simplify a branch to a block that has a single predecessor. This effectively |
153 | /// merges the two blocks. |
154 | static LogicalResult |
155 | simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) { |
156 | // Check that the successor block has a single predecessor. |
157 | Block *succ = op.getDest(); |
158 | Block *opParent = op->getBlock(); |
159 | if (succ == opParent || !llvm::hasSingleElement(C: succ->getPredecessors())) |
160 | return failure(); |
161 | |
162 | // Merge the successor into the current block and erase the branch. |
163 | SmallVector<Value> brOperands(op.getOperands()); |
164 | rewriter.eraseOp(op: op); |
165 | rewriter.mergeBlocks(source: succ, dest: opParent, argValues: brOperands); |
166 | return success(); |
167 | } |
168 | |
169 | /// br ^bb1 |
170 | /// ^bb1 |
171 | /// br ^bbN(...) |
172 | /// |
173 | /// -> br ^bbN(...) |
174 | /// |
175 | static LogicalResult simplifyPassThroughBr(BranchOp op, |
176 | PatternRewriter &rewriter) { |
177 | Block *dest = op.getDest(); |
178 | ValueRange destOperands = op.getOperands(); |
179 | SmallVector<Value, 4> destOperandStorage; |
180 | |
181 | // Try to collapse the successor if it points somewhere other than this |
182 | // block. |
183 | if (dest == op->getBlock() || |
184 | failed(Result: collapseBranch(successor&: dest, successorOperands&: destOperands, argStorage&: destOperandStorage))) |
185 | return failure(); |
186 | |
187 | // Create a new branch with the collapsed successor. |
188 | rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands); |
189 | return success(); |
190 | } |
191 | |
192 | LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) { |
193 | return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) || |
194 | succeeded(simplifyPassThroughBr(op, rewriter))); |
195 | } |
196 | |
197 | void BranchOp::setDest(Block *block) { return setSuccessor(block); } |
198 | |
199 | void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); } |
200 | |
201 | SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) { |
202 | assert(index == 0 && "invalid successor index"); |
203 | return SuccessorOperands(getDestOperandsMutable()); |
204 | } |
205 | |
206 | Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { |
207 | return getDest(); |
208 | } |
209 | |
210 | //===----------------------------------------------------------------------===// |
211 | // CondBranchOp |
212 | //===----------------------------------------------------------------------===// |
213 | |
214 | namespace { |
215 | /// cf.cond_br true, ^bb1, ^bb2 |
216 | /// -> br ^bb1 |
217 | /// cf.cond_br false, ^bb1, ^bb2 |
218 | /// -> br ^bb2 |
219 | /// |
220 | struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> { |
221 | using OpRewritePattern<CondBranchOp>::OpRewritePattern; |
222 | |
223 | LogicalResult matchAndRewrite(CondBranchOp condbr, |
224 | PatternRewriter &rewriter) const override { |
225 | if (matchPattern(condbr.getCondition(), m_NonZero())) { |
226 | // True branch taken. |
227 | rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(), |
228 | condbr.getTrueOperands()); |
229 | return success(); |
230 | } |
231 | if (matchPattern(condbr.getCondition(), m_Zero())) { |
232 | // False branch taken. |
233 | rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(), |
234 | condbr.getFalseOperands()); |
235 | return success(); |
236 | } |
237 | return failure(); |
238 | } |
239 | }; |
240 | |
241 | /// cf.cond_br %cond, ^bb1, ^bb2 |
242 | /// ^bb1 |
243 | /// br ^bbN(...) |
244 | /// ^bb2 |
245 | /// br ^bbK(...) |
246 | /// |
247 | /// -> cf.cond_br %cond, ^bbN(...), ^bbK(...) |
248 | /// |
249 | struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> { |
250 | using OpRewritePattern<CondBranchOp>::OpRewritePattern; |
251 | |
252 | LogicalResult matchAndRewrite(CondBranchOp condbr, |
253 | PatternRewriter &rewriter) const override { |
254 | Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest(); |
255 | ValueRange trueDestOperands = condbr.getTrueOperands(); |
256 | ValueRange falseDestOperands = condbr.getFalseOperands(); |
257 | SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage; |
258 | |
259 | // Try to collapse one of the current successors. |
260 | LogicalResult collapsedTrue = |
261 | collapseBranch(successor&: trueDest, successorOperands&: trueDestOperands, argStorage&: trueDestOperandStorage); |
262 | LogicalResult collapsedFalse = |
263 | collapseBranch(successor&: falseDest, successorOperands&: falseDestOperands, argStorage&: falseDestOperandStorage); |
264 | if (failed(Result: collapsedTrue) && failed(Result: collapsedFalse)) |
265 | return failure(); |
266 | |
267 | // Create a new branch with the collapsed successors. |
268 | rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(), |
269 | trueDest, trueDestOperands, |
270 | falseDest, falseDestOperands); |
271 | return success(); |
272 | } |
273 | }; |
274 | |
275 | /// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N) |
276 | /// -> br ^bb1(A, ..., N) |
277 | /// |
278 | /// cf.cond_br %cond, ^bb1(A), ^bb1(B) |
279 | /// -> %select = arith.select %cond, A, B |
280 | /// br ^bb1(%select) |
281 | /// |
282 | struct SimplifyCondBranchIdenticalSuccessors |
283 | : public OpRewritePattern<CondBranchOp> { |
284 | using OpRewritePattern<CondBranchOp>::OpRewritePattern; |
285 | |
286 | LogicalResult matchAndRewrite(CondBranchOp condbr, |
287 | PatternRewriter &rewriter) const override { |
288 | // Check that the true and false destinations are the same and have the same |
289 | // operands. |
290 | Block *trueDest = condbr.getTrueDest(); |
291 | if (trueDest != condbr.getFalseDest()) |
292 | return failure(); |
293 | |
294 | // If all of the operands match, no selects need to be generated. |
295 | OperandRange trueOperands = condbr.getTrueOperands(); |
296 | OperandRange falseOperands = condbr.getFalseOperands(); |
297 | if (trueOperands == falseOperands) { |
298 | rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands); |
299 | return success(); |
300 | } |
301 | |
302 | // Otherwise, if the current block is the only predecessor insert selects |
303 | // for any mismatched branch operands. |
304 | if (trueDest->getUniquePredecessor() != condbr->getBlock()) |
305 | return failure(); |
306 | |
307 | // Generate a select for any operands that differ between the two. |
308 | SmallVector<Value, 8> mergedOperands; |
309 | mergedOperands.reserve(N: trueOperands.size()); |
310 | Value condition = condbr.getCondition(); |
311 | for (auto it : llvm::zip(trueOperands, falseOperands)) { |
312 | if (std::get<0>(it) == std::get<1>(it)) |
313 | mergedOperands.push_back(std::get<0>(it)); |
314 | else |
315 | mergedOperands.push_back(rewriter.create<arith::SelectOp>( |
316 | condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it))); |
317 | } |
318 | |
319 | rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands); |
320 | return success(); |
321 | } |
322 | }; |
323 | |
324 | /// ... |
325 | /// cf.cond_br %cond, ^bb1(...), ^bb2(...) |
326 | /// ... |
327 | /// ^bb1: // has single predecessor |
328 | /// ... |
329 | /// cf.cond_br %cond, ^bb3(...), ^bb4(...) |
330 | /// |
331 | /// -> |
332 | /// |
333 | /// ... |
334 | /// cf.cond_br %cond, ^bb1(...), ^bb2(...) |
335 | /// ... |
336 | /// ^bb1: // has single predecessor |
337 | /// ... |
338 | /// br ^bb3(...) |
339 | /// |
340 | struct SimplifyCondBranchFromCondBranchOnSameCondition |
341 | : public OpRewritePattern<CondBranchOp> { |
342 | using OpRewritePattern<CondBranchOp>::OpRewritePattern; |
343 | |
344 | LogicalResult matchAndRewrite(CondBranchOp condbr, |
345 | PatternRewriter &rewriter) const override { |
346 | // Check that we have a single distinct predecessor. |
347 | Block *currentBlock = condbr->getBlock(); |
348 | Block *predecessor = currentBlock->getSinglePredecessor(); |
349 | if (!predecessor) |
350 | return failure(); |
351 | |
352 | // Check that the predecessor terminates with a conditional branch to this |
353 | // block and that it branches on the same condition. |
354 | auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator()); |
355 | if (!predBranch || condbr.getCondition() != predBranch.getCondition()) |
356 | return failure(); |
357 | |
358 | // Fold this branch to an unconditional branch. |
359 | if (currentBlock == predBranch.getTrueDest()) |
360 | rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(), |
361 | condbr.getTrueDestOperands()); |
362 | else |
363 | rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(), |
364 | condbr.getFalseDestOperands()); |
365 | return success(); |
366 | } |
367 | }; |
368 | |
369 | /// cf.cond_br %arg0, ^trueB, ^falseB |
370 | /// |
371 | /// ^trueB: |
372 | /// "test.consumer1"(%arg0) : (i1) -> () |
373 | /// ... |
374 | /// |
375 | /// ^falseB: |
376 | /// "test.consumer2"(%arg0) : (i1) -> () |
377 | /// ... |
378 | /// |
379 | /// -> |
380 | /// |
381 | /// cf.cond_br %arg0, ^trueB, ^falseB |
382 | /// ^trueB: |
383 | /// "test.consumer1"(%true) : (i1) -> () |
384 | /// ... |
385 | /// |
386 | /// ^falseB: |
387 | /// "test.consumer2"(%false) : (i1) -> () |
388 | /// ... |
389 | struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> { |
390 | using OpRewritePattern<CondBranchOp>::OpRewritePattern; |
391 | |
392 | LogicalResult matchAndRewrite(CondBranchOp condbr, |
393 | PatternRewriter &rewriter) const override { |
394 | // Check that we have a single distinct predecessor. |
395 | bool replaced = false; |
396 | Type ty = rewriter.getI1Type(); |
397 | |
398 | // These variables serve to prevent creating duplicate constants |
399 | // and hold constant true or false values. |
400 | Value constantTrue = nullptr; |
401 | Value constantFalse = nullptr; |
402 | |
403 | // TODO These checks can be expanded to encompas any use with only |
404 | // either the true of false edge as a predecessor. For now, we fall |
405 | // back to checking the single predecessor is given by the true/fasle |
406 | // destination, thereby ensuring that only that edge can reach the |
407 | // op. |
408 | if (condbr.getTrueDest()->getSinglePredecessor()) { |
409 | for (OpOperand &use : |
410 | llvm::make_early_inc_range(condbr.getCondition().getUses())) { |
411 | if (use.getOwner()->getBlock() == condbr.getTrueDest()) { |
412 | replaced = true; |
413 | |
414 | if (!constantTrue) |
415 | constantTrue = rewriter.create<arith::ConstantOp>( |
416 | condbr.getLoc(), ty, rewriter.getBoolAttr(true)); |
417 | |
418 | rewriter.modifyOpInPlace(use.getOwner(), |
419 | [&] { use.set(constantTrue); }); |
420 | } |
421 | } |
422 | } |
423 | if (condbr.getFalseDest()->getSinglePredecessor()) { |
424 | for (OpOperand &use : |
425 | llvm::make_early_inc_range(condbr.getCondition().getUses())) { |
426 | if (use.getOwner()->getBlock() == condbr.getFalseDest()) { |
427 | replaced = true; |
428 | |
429 | if (!constantFalse) |
430 | constantFalse = rewriter.create<arith::ConstantOp>( |
431 | condbr.getLoc(), ty, rewriter.getBoolAttr(false)); |
432 | |
433 | rewriter.modifyOpInPlace(use.getOwner(), |
434 | [&] { use.set(constantFalse); }); |
435 | } |
436 | } |
437 | } |
438 | return success(IsSuccess: replaced); |
439 | } |
440 | }; |
441 | } // namespace |
442 | |
443 | void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results, |
444 | MLIRContext *context) { |
445 | results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch, |
446 | SimplifyCondBranchIdenticalSuccessors, |
447 | SimplifyCondBranchFromCondBranchOnSameCondition, |
448 | CondBranchTruthPropagation>(context); |
449 | } |
450 | |
451 | SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) { |
452 | assert(index < getNumSuccessors() && "invalid successor index"); |
453 | return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable() |
454 | : getFalseDestOperandsMutable()); |
455 | } |
456 | |
457 | Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { |
458 | if (IntegerAttr condAttr = |
459 | llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) |
460 | return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest(); |
461 | return nullptr; |
462 | } |
463 | |
464 | //===----------------------------------------------------------------------===// |
465 | // SwitchOp |
466 | //===----------------------------------------------------------------------===// |
467 | |
468 | void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, |
469 | Block *defaultDestination, ValueRange defaultOperands, |
470 | DenseIntElementsAttr caseValues, |
471 | BlockRange caseDestinations, |
472 | ArrayRef<ValueRange> caseOperands) { |
473 | build(builder, result, value, defaultOperands, caseOperands, caseValues, |
474 | defaultDestination, caseDestinations); |
475 | } |
476 | |
477 | void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, |
478 | Block *defaultDestination, ValueRange defaultOperands, |
479 | ArrayRef<APInt> caseValues, BlockRange caseDestinations, |
480 | ArrayRef<ValueRange> caseOperands) { |
481 | DenseIntElementsAttr caseValuesAttr; |
482 | if (!caseValues.empty()) { |
483 | ShapedType caseValueType = VectorType::get( |
484 | static_cast<int64_t>(caseValues.size()), value.getType()); |
485 | caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); |
486 | } |
487 | build(builder, result, value, defaultDestination, defaultOperands, |
488 | caseValuesAttr, caseDestinations, caseOperands); |
489 | } |
490 | |
491 | void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, |
492 | Block *defaultDestination, ValueRange defaultOperands, |
493 | ArrayRef<int32_t> caseValues, BlockRange caseDestinations, |
494 | ArrayRef<ValueRange> caseOperands) { |
495 | DenseIntElementsAttr caseValuesAttr; |
496 | if (!caseValues.empty()) { |
497 | ShapedType caseValueType = VectorType::get( |
498 | static_cast<int64_t>(caseValues.size()), value.getType()); |
499 | caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); |
500 | } |
501 | build(builder, result, value, defaultDestination, defaultOperands, |
502 | caseValuesAttr, caseDestinations, caseOperands); |
503 | } |
504 | |
505 | /// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)? |
506 | /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )* |
507 | static ParseResult parseSwitchOpCases( |
508 | OpAsmParser &parser, Type &flagType, Block *&defaultDestination, |
509 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands, |
510 | SmallVectorImpl<Type> &defaultOperandTypes, |
511 | DenseIntElementsAttr &caseValues, |
512 | SmallVectorImpl<Block *> &caseDestinations, |
513 | SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands, |
514 | SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) { |
515 | if (parser.parseKeyword(keyword: "default") || parser.parseColon() || |
516 | parser.parseSuccessor(dest&: defaultDestination)) |
517 | return failure(); |
518 | if (succeeded(Result: parser.parseOptionalLParen())) { |
519 | if (parser.parseOperandList(result&: defaultOperands, delimiter: OpAsmParser::Delimiter::None, |
520 | /*allowResultNumber=*/false) || |
521 | parser.parseColonTypeList(result&: defaultOperandTypes) || parser.parseRParen()) |
522 | return failure(); |
523 | } |
524 | |
525 | SmallVector<APInt> values; |
526 | unsigned bitWidth = flagType.getIntOrFloatBitWidth(); |
527 | while (succeeded(Result: parser.parseOptionalComma())) { |
528 | int64_t value = 0; |
529 | if (failed(Result: parser.parseInteger(result&: value))) |
530 | return failure(); |
531 | values.push_back(Elt: APInt(bitWidth, value, /*isSigned=*/true)); |
532 | |
533 | Block *destination; |
534 | SmallVector<OpAsmParser::UnresolvedOperand> operands; |
535 | SmallVector<Type> operandTypes; |
536 | if (failed(Result: parser.parseColon()) || |
537 | failed(Result: parser.parseSuccessor(dest&: destination))) |
538 | return failure(); |
539 | if (succeeded(Result: parser.parseOptionalLParen())) { |
540 | if (failed(Result: parser.parseOperandList(result&: operands, |
541 | delimiter: OpAsmParser::Delimiter::None)) || |
542 | failed(Result: parser.parseColonTypeList(result&: operandTypes)) || |
543 | failed(Result: parser.parseRParen())) |
544 | return failure(); |
545 | } |
546 | caseDestinations.push_back(Elt: destination); |
547 | caseOperands.emplace_back(Args&: operands); |
548 | caseOperandTypes.emplace_back(Args&: operandTypes); |
549 | } |
550 | |
551 | if (!values.empty()) { |
552 | ShapedType caseValueType = |
553 | VectorType::get(static_cast<int64_t>(values.size()), flagType); |
554 | caseValues = DenseIntElementsAttr::get(caseValueType, values); |
555 | } |
556 | return success(); |
557 | } |
558 | |
559 | static void printSwitchOpCases( |
560 | OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, |
561 | OperandRange defaultOperands, TypeRange defaultOperandTypes, |
562 | DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, |
563 | OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) { |
564 | p << " default: "; |
565 | p.printSuccessorAndUseList(successor: defaultDestination, succOperands: defaultOperands); |
566 | |
567 | if (!caseValues) |
568 | return; |
569 | |
570 | for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) { |
571 | p << ','; |
572 | p.printNewline(); |
573 | p << " "; |
574 | p << it.value().getLimitedValue(); |
575 | p << ": "; |
576 | p.printSuccessorAndUseList(caseDestinations[it.index()], |
577 | caseOperands[it.index()]); |
578 | } |
579 | p.printNewline(); |
580 | } |
581 | |
582 | LogicalResult SwitchOp::verify() { |
583 | auto caseValues = getCaseValues(); |
584 | auto caseDestinations = getCaseDestinations(); |
585 | |
586 | if (!caseValues && caseDestinations.empty()) |
587 | return success(); |
588 | |
589 | Type flagType = getFlag().getType(); |
590 | Type caseValueType = caseValues->getType().getElementType(); |
591 | if (caseValueType != flagType) |
592 | return emitOpError() << "'flag' type ("<< flagType |
593 | << ") should match case value type ("<< caseValueType |
594 | << ")"; |
595 | |
596 | if (caseValues && |
597 | caseValues->size() != static_cast<int64_t>(caseDestinations.size())) |
598 | return emitOpError() << "number of case values ("<< caseValues->size() |
599 | << ") should match number of " |
600 | "case destinations (" |
601 | << caseDestinations.size() << ")"; |
602 | return success(); |
603 | } |
604 | |
605 | SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) { |
606 | assert(index < getNumSuccessors() && "invalid successor index"); |
607 | return SuccessorOperands(index == 0 ? getDefaultOperandsMutable() |
608 | : getCaseOperandsMutable(index - 1)); |
609 | } |
610 | |
611 | Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { |
612 | std::optional<DenseIntElementsAttr> caseValues = getCaseValues(); |
613 | |
614 | if (!caseValues) |
615 | return getDefaultDestination(); |
616 | |
617 | SuccessorRange caseDests = getCaseDestinations(); |
618 | if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) { |
619 | for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) |
620 | if (it.value() == value.getValue()) |
621 | return caseDests[it.index()]; |
622 | return getDefaultDestination(); |
623 | } |
624 | return nullptr; |
625 | } |
626 | |
627 | /// switch %flag : i32, [ |
628 | /// default: ^bb1 |
629 | /// ] |
630 | /// -> br ^bb1 |
631 | static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, |
632 | PatternRewriter &rewriter) { |
633 | if (!op.getCaseDestinations().empty()) |
634 | return failure(); |
635 | |
636 | rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), |
637 | op.getDefaultOperands()); |
638 | return success(); |
639 | } |
640 | |
641 | /// switch %flag : i32, [ |
642 | /// default: ^bb1, |
643 | /// 42: ^bb1, |
644 | /// 43: ^bb2 |
645 | /// ] |
646 | /// -> |
647 | /// switch %flag : i32, [ |
648 | /// default: ^bb1, |
649 | /// 43: ^bb2 |
650 | /// ] |
651 | static LogicalResult |
652 | dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) { |
653 | SmallVector<Block *> newCaseDestinations; |
654 | SmallVector<ValueRange> newCaseOperands; |
655 | SmallVector<APInt> newCaseValues; |
656 | bool requiresChange = false; |
657 | auto caseValues = op.getCaseValues(); |
658 | auto caseDests = op.getCaseDestinations(); |
659 | |
660 | for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { |
661 | if (caseDests[it.index()] == op.getDefaultDestination() && |
662 | op.getCaseOperands(it.index()) == op.getDefaultOperands()) { |
663 | requiresChange = true; |
664 | continue; |
665 | } |
666 | newCaseDestinations.push_back(caseDests[it.index()]); |
667 | newCaseOperands.push_back(op.getCaseOperands(it.index())); |
668 | newCaseValues.push_back(it.value()); |
669 | } |
670 | |
671 | if (!requiresChange) |
672 | return failure(); |
673 | |
674 | rewriter.replaceOpWithNewOp<SwitchOp>( |
675 | op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(), |
676 | newCaseValues, newCaseDestinations, newCaseOperands); |
677 | return success(); |
678 | } |
679 | |
680 | /// Helper for folding a switch with a constant value. |
681 | /// switch %c_42 : i32, [ |
682 | /// default: ^bb1 , |
683 | /// 42: ^bb2, |
684 | /// 43: ^bb3 |
685 | /// ] |
686 | /// -> br ^bb2 |
687 | static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, |
688 | const APInt &caseValue) { |
689 | auto caseValues = op.getCaseValues(); |
690 | for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { |
691 | if (it.value() == caseValue) { |
692 | rewriter.replaceOpWithNewOp<BranchOp>( |
693 | op, op.getCaseDestinations()[it.index()], |
694 | op.getCaseOperands(it.index())); |
695 | return; |
696 | } |
697 | } |
698 | rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), |
699 | op.getDefaultOperands()); |
700 | } |
701 | |
702 | /// switch %c_42 : i32, [ |
703 | /// default: ^bb1, |
704 | /// 42: ^bb2, |
705 | /// 43: ^bb3 |
706 | /// ] |
707 | /// -> br ^bb2 |
708 | static LogicalResult simplifyConstSwitchValue(SwitchOp op, |
709 | PatternRewriter &rewriter) { |
710 | APInt caseValue; |
711 | if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue))) |
712 | return failure(); |
713 | |
714 | foldSwitch(op, rewriter, caseValue); |
715 | return success(); |
716 | } |
717 | |
718 | /// switch %c_42 : i32, [ |
719 | /// default: ^bb1, |
720 | /// 42: ^bb2, |
721 | /// ] |
722 | /// ^bb2: |
723 | /// br ^bb3 |
724 | /// -> |
725 | /// switch %c_42 : i32, [ |
726 | /// default: ^bb1, |
727 | /// 42: ^bb3, |
728 | /// ] |
729 | static LogicalResult simplifyPassThroughSwitch(SwitchOp op, |
730 | PatternRewriter &rewriter) { |
731 | SmallVector<Block *> newCaseDests; |
732 | SmallVector<ValueRange> newCaseOperands; |
733 | SmallVector<SmallVector<Value>> argStorage; |
734 | auto caseValues = op.getCaseValues(); |
735 | argStorage.reserve(N: caseValues->size() + 1); |
736 | auto caseDests = op.getCaseDestinations(); |
737 | bool requiresChange = false; |
738 | for (int64_t i = 0, size = caseValues->size(); i < size; ++i) { |
739 | Block *caseDest = caseDests[i]; |
740 | ValueRange caseOperands = op.getCaseOperands(i); |
741 | argStorage.emplace_back(); |
742 | if (succeeded(Result: collapseBranch(successor&: caseDest, successorOperands&: caseOperands, argStorage&: argStorage.back()))) |
743 | requiresChange = true; |
744 | |
745 | newCaseDests.push_back(Elt: caseDest); |
746 | newCaseOperands.push_back(Elt: caseOperands); |
747 | } |
748 | |
749 | Block *defaultDest = op.getDefaultDestination(); |
750 | ValueRange defaultOperands = op.getDefaultOperands(); |
751 | argStorage.emplace_back(); |
752 | |
753 | if (succeeded( |
754 | Result: collapseBranch(successor&: defaultDest, successorOperands&: defaultOperands, argStorage&: argStorage.back()))) |
755 | requiresChange = true; |
756 | |
757 | if (!requiresChange) |
758 | return failure(); |
759 | |
760 | rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest, |
761 | defaultOperands, *caseValues, |
762 | newCaseDests, newCaseOperands); |
763 | return success(); |
764 | } |
765 | |
766 | /// switch %flag : i32, [ |
767 | /// default: ^bb1, |
768 | /// 42: ^bb2, |
769 | /// ] |
770 | /// ^bb2: |
771 | /// switch %flag : i32, [ |
772 | /// default: ^bb3, |
773 | /// 42: ^bb4 |
774 | /// ] |
775 | /// -> |
776 | /// switch %flag : i32, [ |
777 | /// default: ^bb1, |
778 | /// 42: ^bb2, |
779 | /// ] |
780 | /// ^bb2: |
781 | /// br ^bb4 |
782 | /// |
783 | /// and |
784 | /// |
785 | /// switch %flag : i32, [ |
786 | /// default: ^bb1, |
787 | /// 42: ^bb2, |
788 | /// ] |
789 | /// ^bb2: |
790 | /// switch %flag : i32, [ |
791 | /// default: ^bb3, |
792 | /// 43: ^bb4 |
793 | /// ] |
794 | /// -> |
795 | /// switch %flag : i32, [ |
796 | /// default: ^bb1, |
797 | /// 42: ^bb2, |
798 | /// ] |
799 | /// ^bb2: |
800 | /// br ^bb3 |
801 | static LogicalResult |
802 | simplifySwitchFromSwitchOnSameCondition(SwitchOp op, |
803 | PatternRewriter &rewriter) { |
804 | // Check that we have a single distinct predecessor. |
805 | Block *currentBlock = op->getBlock(); |
806 | Block *predecessor = currentBlock->getSinglePredecessor(); |
807 | if (!predecessor) |
808 | return failure(); |
809 | |
810 | // Check that the predecessor terminates with a switch branch to this block |
811 | // and that it branches on the same condition and that this branch isn't the |
812 | // default destination. |
813 | auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator()); |
814 | if (!predSwitch || op.getFlag() != predSwitch.getFlag() || |
815 | predSwitch.getDefaultDestination() == currentBlock) |
816 | return failure(); |
817 | |
818 | // Fold this switch to an unconditional branch. |
819 | SuccessorRange predDests = predSwitch.getCaseDestinations(); |
820 | auto it = llvm::find(Range&: predDests, Val: currentBlock); |
821 | if (it != predDests.end()) { |
822 | std::optional<DenseIntElementsAttr> predCaseValues = |
823 | predSwitch.getCaseValues(); |
824 | foldSwitch(op, rewriter, |
825 | predCaseValues->getValues<APInt>()[it - predDests.begin()]); |
826 | } else { |
827 | rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(), |
828 | op.getDefaultOperands()); |
829 | } |
830 | return success(); |
831 | } |
832 | |
833 | /// switch %flag : i32, [ |
834 | /// default: ^bb1, |
835 | /// 42: ^bb2 |
836 | /// ] |
837 | /// ^bb1: |
838 | /// switch %flag : i32, [ |
839 | /// default: ^bb3, |
840 | /// 42: ^bb4, |
841 | /// 43: ^bb5 |
842 | /// ] |
843 | /// -> |
844 | /// switch %flag : i32, [ |
845 | /// default: ^bb1, |
846 | /// 42: ^bb2, |
847 | /// ] |
848 | /// ^bb1: |
849 | /// switch %flag : i32, [ |
850 | /// default: ^bb3, |
851 | /// 43: ^bb5 |
852 | /// ] |
853 | static LogicalResult |
854 | simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, |
855 | PatternRewriter &rewriter) { |
856 | // Check that we have a single distinct predecessor. |
857 | Block *currentBlock = op->getBlock(); |
858 | Block *predecessor = currentBlock->getSinglePredecessor(); |
859 | if (!predecessor) |
860 | return failure(); |
861 | |
862 | // Check that the predecessor terminates with a switch branch to this block |
863 | // and that it branches on the same condition and that this branch is the |
864 | // default destination. |
865 | auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator()); |
866 | if (!predSwitch || op.getFlag() != predSwitch.getFlag() || |
867 | predSwitch.getDefaultDestination() != currentBlock) |
868 | return failure(); |
869 | |
870 | // Delete case values that are not possible here. |
871 | DenseSet<APInt> caseValuesToRemove; |
872 | auto predDests = predSwitch.getCaseDestinations(); |
873 | auto predCaseValues = predSwitch.getCaseValues(); |
874 | for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) |
875 | if (currentBlock != predDests[i]) |
876 | caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]); |
877 | |
878 | SmallVector<Block *> newCaseDestinations; |
879 | SmallVector<ValueRange> newCaseOperands; |
880 | SmallVector<APInt> newCaseValues; |
881 | bool requiresChange = false; |
882 | |
883 | auto caseValues = op.getCaseValues(); |
884 | auto caseDests = op.getCaseDestinations(); |
885 | for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) { |
886 | if (caseValuesToRemove.contains(it.value())) { |
887 | requiresChange = true; |
888 | continue; |
889 | } |
890 | newCaseDestinations.push_back(caseDests[it.index()]); |
891 | newCaseOperands.push_back(op.getCaseOperands(it.index())); |
892 | newCaseValues.push_back(it.value()); |
893 | } |
894 | |
895 | if (!requiresChange) |
896 | return failure(); |
897 | |
898 | rewriter.replaceOpWithNewOp<SwitchOp>( |
899 | op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(), |
900 | newCaseValues, newCaseDestinations, newCaseOperands); |
901 | return success(); |
902 | } |
903 | |
904 | void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results, |
905 | MLIRContext *context) { |
906 | results.add(&simplifySwitchWithOnlyDefault) |
907 | .add(&dropSwitchCasesThatMatchDefault) |
908 | .add(&simplifyConstSwitchValue) |
909 | .add(&simplifyPassThroughSwitch) |
910 | .add(&simplifySwitchFromSwitchOnSameCondition) |
911 | .add(&simplifySwitchFromDefaultSwitchOnSameCondition); |
912 | } |
913 | |
914 | //===----------------------------------------------------------------------===// |
915 | // TableGen'd op method definitions |
916 | //===----------------------------------------------------------------------===// |
917 | |
918 | #define GET_OP_CLASSES |
919 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc" |
920 |
Definitions
- ControlFlowInlinerInterface
- ~ControlFlowInlinerInterface
- isLegalToInline
- isLegalToInline
- handleTerminator
- collapseBranch
- simplifyBrToBlockWithSinglePred
- simplifyPassThroughBr
- SimplifyConstCondBranchPred
- matchAndRewrite
- SimplifyPassThroughCondBranch
- matchAndRewrite
- SimplifyCondBranchIdenticalSuccessors
- matchAndRewrite
- SimplifyCondBranchFromCondBranchOnSameCondition
- matchAndRewrite
- CondBranchTruthPropagation
- matchAndRewrite
- parseSwitchOpCases
- printSwitchOpCases
- simplifySwitchWithOnlyDefault
- dropSwitchCasesThatMatchDefault
- foldSwitch
- simplifyConstSwitchValue
- simplifyPassThroughSwitch
- simplifySwitchFromSwitchOnSameCondition
Learn to use CMake with our Intro Training
Find out more