| 1 | //===- WrapInZeroTripCheck.cpp - Loop transforms to add zero-trip-check ---===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 10 | #include "mlir/Dialect/SCF/Transforms/Transforms.h" |
| 11 | #include "mlir/IR/IRMapping.h" |
| 12 | #include "mlir/IR/PatternMatch.h" |
| 13 | |
| 14 | using namespace mlir; |
| 15 | |
| 16 | /// Create zero-trip-check around a `while` op and return the new loop op in the |
| 17 | /// check. The while loop is rotated to avoid evaluating the condition twice. |
| 18 | /// |
| 19 | /// Given an example below: |
| 20 | /// |
| 21 | /// scf.while (%arg0 = %init) : (i32) -> i64 { |
| 22 | /// %val = .., %arg0 : i64 |
| 23 | /// %cond = arith.cmpi .., %arg0 : i32 |
| 24 | /// scf.condition(%cond) %val : i64 |
| 25 | /// } do { |
| 26 | /// ^bb0(%arg1: i64): |
| 27 | /// %next = .., %arg1 : i32 |
| 28 | /// scf.yield %next : i32 |
| 29 | /// } |
| 30 | /// |
| 31 | /// First clone before block to the front of the loop: |
| 32 | /// |
| 33 | /// %pre_val = .., %init : i64 |
| 34 | /// %pre_cond = arith.cmpi .., %init : i32 |
| 35 | /// scf.while (%arg0 = %init) : (i32) -> i64 { |
| 36 | /// %val = .., %arg0 : i64 |
| 37 | /// %cond = arith.cmpi .., %arg0 : i32 |
| 38 | /// scf.condition(%cond) %val : i64 |
| 39 | /// } do { |
| 40 | /// ^bb0(%arg1: i64): |
| 41 | /// %next = .., %arg1 : i32 |
| 42 | /// scf.yield %next : i32 |
| 43 | /// } |
| 44 | /// |
| 45 | /// Create `if` op with the condition, rotate and move the loop into the else |
| 46 | /// branch: |
| 47 | /// |
| 48 | /// %pre_val = .., %init : i64 |
| 49 | /// %pre_cond = arith.cmpi .., %init : i32 |
| 50 | /// scf.if %pre_cond -> i64 { |
| 51 | /// %res = scf.while (%arg1 = %va0) : (i64) -> i64 { |
| 52 | /// // Original after block |
| 53 | /// %next = .., %arg1 : i32 |
| 54 | /// // Original before block |
| 55 | /// %val = .., %next : i64 |
| 56 | /// %cond = arith.cmpi .., %next : i32 |
| 57 | /// scf.condition(%cond) %val : i64 |
| 58 | /// } do { |
| 59 | /// ^bb0(%arg2: i64): |
| 60 | /// %scf.yield %arg2 : i32 |
| 61 | /// } |
| 62 | /// scf.yield %res : i64 |
| 63 | /// } else { |
| 64 | /// scf.yield %pre_val : i64 |
| 65 | /// } |
| 66 | FailureOr<scf::WhileOp> mlir::scf::wrapWhileLoopInZeroTripCheck( |
| 67 | scf::WhileOp whileOp, RewriterBase &rewriter, bool forceCreateCheck) { |
| 68 | // If the loop is in do-while form (after block only passes through values), |
| 69 | // there is no need to create a zero-trip-check as before block is always run. |
| 70 | if (!forceCreateCheck && isa<scf::YieldOp>(whileOp.getAfterBody()->front())) { |
| 71 | return whileOp; |
| 72 | } |
| 73 | |
| 74 | OpBuilder::InsertionGuard insertion_guard(rewriter); |
| 75 | |
| 76 | IRMapping mapper; |
| 77 | Block *beforeBlock = whileOp.getBeforeBody(); |
| 78 | // Clone before block before the loop for zero-trip-check. |
| 79 | for (auto [arg, init] : |
| 80 | llvm::zip_equal(beforeBlock->getArguments(), whileOp.getInits())) { |
| 81 | mapper.map(arg, init); |
| 82 | } |
| 83 | rewriter.setInsertionPoint(whileOp); |
| 84 | for (auto &op : *beforeBlock) { |
| 85 | if (isa<scf::ConditionOp>(op)) { |
| 86 | break; |
| 87 | } |
| 88 | // Safe to clone everything as in a single block all defs have been cloned |
| 89 | // and added to mapper in order. |
| 90 | rewriter.insert(op.clone(mapper)); |
| 91 | } |
| 92 | |
| 93 | scf::ConditionOp condOp = whileOp.getConditionOp(); |
| 94 | Value clonedCondition = mapper.lookupOrDefault(condOp.getCondition()); |
| 95 | SmallVector<Value> clonedCondArgs = llvm::map_to_vector( |
| 96 | condOp.getArgs(), [&](Value arg) { return mapper.lookupOrDefault(from: arg); }); |
| 97 | |
| 98 | // Create rotated while loop. |
| 99 | auto newLoopOp = rewriter.create<scf::WhileOp>( |
| 100 | whileOp.getLoc(), whileOp.getResultTypes(), clonedCondArgs, |
| 101 | [&](OpBuilder &builder, Location loc, ValueRange args) { |
| 102 | // Rotate and move the loop body into before block. |
| 103 | auto newBlock = builder.getBlock(); |
| 104 | rewriter.mergeBlocks(source: whileOp.getAfterBody(), dest: newBlock, argValues: args); |
| 105 | auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator()); |
| 106 | rewriter.mergeBlocks(source: whileOp.getBeforeBody(), dest: newBlock, |
| 107 | argValues: yieldOp.getResults()); |
| 108 | rewriter.eraseOp(op: yieldOp); |
| 109 | }, |
| 110 | [&](OpBuilder &builder, Location loc, ValueRange args) { |
| 111 | // Pass through values. |
| 112 | builder.create<scf::YieldOp>(loc, args); |
| 113 | }); |
| 114 | |
| 115 | // Create zero-trip-check and move the while loop in. |
| 116 | auto ifOp = rewriter.create<scf::IfOp>( |
| 117 | whileOp.getLoc(), clonedCondition, |
| 118 | [&](OpBuilder &builder, Location loc) { |
| 119 | // Then runs the while loop. |
| 120 | rewriter.moveOpBefore(newLoopOp, builder.getInsertionBlock(), |
| 121 | builder.getInsertionPoint()); |
| 122 | builder.create<scf::YieldOp>(loc, newLoopOp.getResults()); |
| 123 | }, |
| 124 | [&](OpBuilder &builder, Location loc) { |
| 125 | // Else returns the results from precondition. |
| 126 | builder.create<scf::YieldOp>(loc, clonedCondArgs); |
| 127 | }); |
| 128 | |
| 129 | rewriter.replaceOp(whileOp, ifOp); |
| 130 | |
| 131 | return newLoopOp; |
| 132 | } |
| 133 | |