| 1 | //===- UpliftWhileToFor.cpp - scf.while to scf.for loop uplifting ---------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // Transforms SCF.WhileOp's into SCF.ForOp's. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Dialect/SCF/Transforms/Passes.h" |
| 14 | |
| 15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 16 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 17 | #include "mlir/Dialect/SCF/Transforms/Patterns.h" |
| 18 | #include "mlir/IR/Dominance.h" |
| 19 | #include "mlir/IR/PatternMatch.h" |
| 20 | |
| 21 | using namespace mlir; |
| 22 | |
| 23 | namespace { |
| 24 | struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> { |
| 25 | using OpRewritePattern::OpRewritePattern; |
| 26 | |
| 27 | LogicalResult matchAndRewrite(scf::WhileOp loop, |
| 28 | PatternRewriter &rewriter) const override { |
| 29 | return upliftWhileToForLoop(rewriter, loop); |
| 30 | } |
| 31 | }; |
| 32 | } // namespace |
| 33 | |
| 34 | FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter, |
| 35 | scf::WhileOp loop) { |
| 36 | Block *beforeBody = loop.getBeforeBody(); |
| 37 | if (!llvm::hasSingleElement(C: beforeBody->without_terminator())) |
| 38 | return rewriter.notifyMatchFailure(loop, "Loop body must have single op" ); |
| 39 | |
| 40 | auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->front()); |
| 41 | if (!cmp) |
| 42 | return rewriter.notifyMatchFailure(loop, |
| 43 | "Loop body must have single cmp op" ); |
| 44 | |
| 45 | scf::ConditionOp beforeTerm = loop.getConditionOp(); |
| 46 | if (!cmp->hasOneUse() || beforeTerm.getCondition() != cmp.getResult()) |
| 47 | return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) { |
| 48 | diag << "Expected single condition use: " << *cmp; |
| 49 | }); |
| 50 | |
| 51 | // If all 'before' arguments are forwarded but the order is different from |
| 52 | // 'after' arguments, here is the mapping from the 'after' argument index to |
| 53 | // the 'before' argument index. |
| 54 | std::optional<SmallVector<unsigned>> argReorder; |
| 55 | // All `before` block args must be directly forwarded to ConditionOp. |
| 56 | // They will be converted to `scf.for` `iter_vars` except induction var. |
| 57 | if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs()) { |
| 58 | auto getArgReordering = |
| 59 | [](Block *beforeBody, |
| 60 | scf::ConditionOp cond) -> std::optional<SmallVector<unsigned>> { |
| 61 | // Skip further checking if their sizes mismatch. |
| 62 | if (beforeBody->getNumArguments() != cond.getArgs().size()) |
| 63 | return std::nullopt; |
| 64 | // Bitset on which 'before' argument is forwarded. |
| 65 | llvm::SmallBitVector forwarded(beforeBody->getNumArguments(), false); |
| 66 | // The forwarding order of 'before' arguments. |
| 67 | SmallVector<unsigned> order; |
| 68 | for (Value a : cond.getArgs()) { |
| 69 | BlockArgument arg = dyn_cast<BlockArgument>(a); |
| 70 | // Skip if 'arg' is not a 'before' argument. |
| 71 | if (!arg || arg.getOwner() != beforeBody) |
| 72 | return std::nullopt; |
| 73 | unsigned idx = arg.getArgNumber(); |
| 74 | // Skip if 'arg' is already forwarded in another place. |
| 75 | if (forwarded[idx]) |
| 76 | return std::nullopt; |
| 77 | // Record the presence of 'arg' and its order. |
| 78 | forwarded[idx] = true; |
| 79 | order.push_back(idx); |
| 80 | } |
| 81 | // Skip if not all 'before' arguments are forwarded. |
| 82 | if (!forwarded.all()) |
| 83 | return std::nullopt; |
| 84 | return order; |
| 85 | }; |
| 86 | // Check if 'before' arguments are all forwarded but just reordered. |
| 87 | argReorder = getArgReordering(beforeBody, beforeTerm); |
| 88 | if (!argReorder) |
| 89 | return rewriter.notifyMatchFailure(loop, "Invalid args order" ); |
| 90 | } |
| 91 | |
| 92 | using Pred = arith::CmpIPredicate; |
| 93 | Pred predicate = cmp.getPredicate(); |
| 94 | if (predicate != Pred::slt && predicate != Pred::sgt) |
| 95 | return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) { |
| 96 | diag << "Expected 'slt' or 'sgt' predicate: " << *cmp; |
| 97 | }); |
| 98 | |
| 99 | BlockArgument inductionVar; |
| 100 | Value ub; |
| 101 | DominanceInfo dom; |
| 102 | |
| 103 | // Check if cmp has a suitable form. One of the arguments must be a `before` |
| 104 | // block arg, other must be defined outside `scf.while` and will be treated |
| 105 | // as upper bound. |
| 106 | for (bool reverse : {false, true}) { |
| 107 | auto expectedPred = reverse ? Pred::sgt : Pred::slt; |
| 108 | if (cmp.getPredicate() != expectedPred) |
| 109 | continue; |
| 110 | |
| 111 | auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs(); |
| 112 | auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs(); |
| 113 | |
| 114 | auto blockArg = dyn_cast<BlockArgument>(arg1); |
| 115 | if (!blockArg || blockArg.getOwner() != beforeBody) |
| 116 | continue; |
| 117 | |
| 118 | if (!dom.properlyDominates(arg2, loop)) |
| 119 | continue; |
| 120 | |
| 121 | inductionVar = blockArg; |
| 122 | ub = arg2; |
| 123 | break; |
| 124 | } |
| 125 | |
| 126 | if (!inductionVar) |
| 127 | return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) { |
| 128 | diag << "Unrecognized cmp form: " << *cmp; |
| 129 | }); |
| 130 | |
| 131 | // inductionVar must have 2 uses: one is in `cmp` and other is `condition` |
| 132 | // arg. |
| 133 | if (!llvm::hasNItems(C: inductionVar.getUses(), N: 2)) |
| 134 | return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) { |
| 135 | diag << "Unrecognized induction var: " << inductionVar; |
| 136 | }); |
| 137 | |
| 138 | Block *afterBody = loop.getAfterBody(); |
| 139 | scf::YieldOp afterTerm = loop.getYieldOp(); |
| 140 | unsigned argNumber = inductionVar.getArgNumber(); |
| 141 | Value afterTermIndArg = afterTerm.getResults()[argNumber]; |
| 142 | |
| 143 | auto findAfterArgNo = [](ArrayRef<unsigned> indices, unsigned beforeArgNo) { |
| 144 | return std::distance(first: indices.begin(), |
| 145 | last: llvm::find_if(Range&: indices, P: [beforeArgNo](unsigned n) { |
| 146 | return n == beforeArgNo; |
| 147 | })); |
| 148 | }; |
| 149 | Value inductionVarAfter = afterBody->getArgument( |
| 150 | i: argReorder ? findAfterArgNo(*argReorder, argNumber) : argNumber); |
| 151 | |
| 152 | // Find suitable `addi` op inside `after` block, one of the args must be an |
| 153 | // Induction var passed from `before` block and second arg must be defined |
| 154 | // outside of the loop and will be considered step value. |
| 155 | // TODO: Add `subi` support? |
| 156 | auto addOp = afterTermIndArg.getDefiningOp<arith::AddIOp>(); |
| 157 | if (!addOp) |
| 158 | return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op" ); |
| 159 | |
| 160 | Value step; |
| 161 | if (addOp.getLhs() == inductionVarAfter) { |
| 162 | step = addOp.getRhs(); |
| 163 | } else if (addOp.getRhs() == inductionVarAfter) { |
| 164 | step = addOp.getLhs(); |
| 165 | } |
| 166 | |
| 167 | if (!step || !dom.properlyDominates(step, loop)) |
| 168 | return rewriter.notifyMatchFailure(loop, "Invalid 'addi' form" ); |
| 169 | |
| 170 | Value lb = loop.getInits()[argNumber]; |
| 171 | |
| 172 | assert(lb.getType().isIntOrIndex()); |
| 173 | assert(lb.getType() == ub.getType()); |
| 174 | assert(lb.getType() == step.getType()); |
| 175 | |
| 176 | SmallVector<Value> newArgs; |
| 177 | |
| 178 | // Populate inits for new `scf.for`, skip induction var. |
| 179 | newArgs.reserve(N: loop.getInits().size()); |
| 180 | for (auto &&[i, init] : llvm::enumerate(loop.getInits())) { |
| 181 | if (i == argNumber) |
| 182 | continue; |
| 183 | |
| 184 | newArgs.emplace_back(init); |
| 185 | } |
| 186 | |
| 187 | Location loc = loop.getLoc(); |
| 188 | |
| 189 | // With `builder == nullptr`, ForOp::build will try to insert terminator at |
| 190 | // the end of newly created block and we don't want it. Provide empty |
| 191 | // dummy builder instead. |
| 192 | auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {}; |
| 193 | auto newLoop = |
| 194 | rewriter.create<scf::ForOp>(loc, lb, ub, step, newArgs, emptyBuilder); |
| 195 | |
| 196 | Block *newBody = newLoop.getBody(); |
| 197 | |
| 198 | // Populate block args for `scf.for` body, move induction var to the front. |
| 199 | newArgs.clear(); |
| 200 | ValueRange newBodyArgs = newBody->getArguments(); |
| 201 | for (auto i : llvm::seq<size_t>(0, newBodyArgs.size())) { |
| 202 | if (i < argNumber) { |
| 203 | newArgs.emplace_back(newBodyArgs[i + 1]); |
| 204 | } else if (i == argNumber) { |
| 205 | newArgs.emplace_back(newBodyArgs.front()); |
| 206 | } else { |
| 207 | newArgs.emplace_back(newBodyArgs[i]); |
| 208 | } |
| 209 | } |
| 210 | if (argReorder) { |
| 211 | // Reorder arguments following the 'after' argument order from the original |
| 212 | // 'while' loop. |
| 213 | SmallVector<Value> args; |
| 214 | for (unsigned order : *argReorder) |
| 215 | args.push_back(Elt: newArgs[order]); |
| 216 | newArgs = args; |
| 217 | } |
| 218 | |
| 219 | rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(), |
| 220 | newArgs); |
| 221 | |
| 222 | auto term = cast<scf::YieldOp>(newBody->getTerminator()); |
| 223 | |
| 224 | // Populate new yield args, skipping the induction var. |
| 225 | newArgs.clear(); |
| 226 | for (auto &&[i, arg] : llvm::enumerate(term.getResults())) { |
| 227 | if (i == argNumber) |
| 228 | continue; |
| 229 | |
| 230 | newArgs.emplace_back(arg); |
| 231 | } |
| 232 | |
| 233 | OpBuilder::InsertionGuard g(rewriter); |
| 234 | rewriter.setInsertionPoint(term); |
| 235 | rewriter.replaceOpWithNewOp<scf::YieldOp>(term, newArgs); |
| 236 | |
| 237 | // Compute induction var value after loop execution. |
| 238 | rewriter.setInsertionPointAfter(newLoop); |
| 239 | Value one; |
| 240 | if (isa<IndexType>(Val: step.getType())) { |
| 241 | one = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1); |
| 242 | } else { |
| 243 | one = rewriter.create<arith::ConstantIntOp>(location: loc, args: 1, args: step.getType()); |
| 244 | } |
| 245 | |
| 246 | Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one); |
| 247 | Value len = rewriter.create<arith::SubIOp>(loc, ub, lb); |
| 248 | len = rewriter.create<arith::AddIOp>(loc, len, stepDec); |
| 249 | len = rewriter.create<arith::DivSIOp>(loc, len, step); |
| 250 | len = rewriter.create<arith::SubIOp>(loc, len, one); |
| 251 | Value res = rewriter.create<arith::MulIOp>(loc, len, step); |
| 252 | res = rewriter.create<arith::AddIOp>(loc, lb, res); |
| 253 | |
| 254 | // Reconstruct `scf.while` results, inserting final induction var value |
| 255 | // into proper place. |
| 256 | newArgs.clear(); |
| 257 | llvm::append_range(newArgs, newLoop.getResults()); |
| 258 | newArgs.insert(I: newArgs.begin() + argNumber, Elt: res); |
| 259 | if (argReorder) { |
| 260 | // Reorder arguments following the 'after' argument order from the original |
| 261 | // 'while' loop. |
| 262 | SmallVector<Value> results; |
| 263 | for (unsigned order : *argReorder) |
| 264 | results.push_back(Elt: newArgs[order]); |
| 265 | newArgs = results; |
| 266 | } |
| 267 | rewriter.replaceOp(loop, newArgs); |
| 268 | return newLoop; |
| 269 | } |
| 270 | |
| 271 | void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) { |
| 272 | patterns.add<UpliftWhileOp>(arg: patterns.getContext()); |
| 273 | } |
| 274 | |