| 1 | //===- AffineCanonicalizationUtils.cpp - Affine Canonicalization in SCF ---===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // Utility functions to canonicalize affine ops within SCF op regions. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include <utility> |
| 14 | |
| 15 | #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" |
| 16 | #include "mlir/Dialect/Affine/Analysis/Utils.h" |
| 17 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 18 | #include "mlir/Dialect/Affine/IR/AffineValueMap.h" |
| 19 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 20 | #include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h" |
| 21 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
| 22 | #include "mlir/IR/AffineMap.h" |
| 23 | #include "mlir/IR/Matchers.h" |
| 24 | #include "mlir/IR/PatternMatch.h" |
| 25 | #include "llvm/Support/Debug.h" |
| 26 | |
| 27 | #define DEBUG_TYPE "mlir-scf-affine-utils" |
| 28 | |
| 29 | using namespace mlir; |
| 30 | using namespace affine; |
| 31 | using namespace presburger; |
| 32 | |
| 33 | LogicalResult scf::matchForLikeLoop(Value iv, OpFoldResult &lb, |
| 34 | OpFoldResult &ub, OpFoldResult &step) { |
| 35 | if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) { |
| 36 | lb = forOp.getLowerBound(); |
| 37 | ub = forOp.getUpperBound(); |
| 38 | step = forOp.getStep(); |
| 39 | return success(); |
| 40 | } |
| 41 | if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) { |
| 42 | for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) { |
| 43 | if (parOp.getInductionVars()[idx] == iv) { |
| 44 | lb = parOp.getLowerBound()[idx]; |
| 45 | ub = parOp.getUpperBound()[idx]; |
| 46 | step = parOp.getStep()[idx]; |
| 47 | return success(); |
| 48 | } |
| 49 | } |
| 50 | return failure(); |
| 51 | } |
| 52 | if (scf::ForallOp forallOp = scf::getForallOpThreadIndexOwner(iv)) { |
| 53 | for (int64_t idx = 0; idx < forallOp.getRank(); ++idx) { |
| 54 | if (forallOp.getInductionVar(idx) == iv) { |
| 55 | lb = forallOp.getMixedLowerBound()[idx]; |
| 56 | ub = forallOp.getMixedUpperBound()[idx]; |
| 57 | step = forallOp.getMixedStep()[idx]; |
| 58 | return success(); |
| 59 | } |
| 60 | } |
| 61 | return failure(); |
| 62 | } |
| 63 | return failure(); |
| 64 | } |
| 65 | |
| 66 | static FailureOr<AffineApplyOp> |
| 67 | canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op, |
| 68 | FlatAffineValueConstraints constraints) { |
| 69 | RewriterBase::InsertionGuard guard(rewriter); |
| 70 | rewriter.setInsertionPoint(op); |
| 71 | FailureOr<AffineValueMap> simplified = |
| 72 | affine::simplifyConstrainedMinMaxOp(op, constraints: std::move(constraints)); |
| 73 | if (failed(Result: simplified)) |
| 74 | return failure(); |
| 75 | return rewriter.replaceOpWithNewOp<AffineApplyOp>( |
| 76 | op, simplified->getAffineMap(), simplified->getOperands()); |
| 77 | } |
| 78 | |
| 79 | LogicalResult scf::addLoopRangeConstraints(FlatAffineValueConstraints &cstr, |
| 80 | Value iv, OpFoldResult lb, |
| 81 | OpFoldResult ub, OpFoldResult step) { |
| 82 | Builder b(iv.getContext()); |
| 83 | |
| 84 | // IntegerPolyhedron does not support semi-affine expressions. |
| 85 | // Therefore, only constant step values are supported. |
| 86 | auto stepInt = getConstantIntValue(ofr: step); |
| 87 | if (!stepInt) |
| 88 | return failure(); |
| 89 | |
| 90 | unsigned dimIv = cstr.appendDimVar(vals: iv); |
| 91 | auto lbv = llvm::dyn_cast_if_present<Value>(Val&: lb); |
| 92 | unsigned symLb = |
| 93 | lbv ? cstr.appendSymbolVar(vals: lbv) : cstr.appendSymbolVar(/*num=*/1); |
| 94 | auto ubv = llvm::dyn_cast_if_present<Value>(Val&: ub); |
| 95 | unsigned symUb = |
| 96 | ubv ? cstr.appendSymbolVar(vals: ubv) : cstr.appendSymbolVar(/*num=*/1); |
| 97 | |
| 98 | // If loop lower/upper bounds are constant: Add EQ constraint. |
| 99 | std::optional<int64_t> lbInt = getConstantIntValue(ofr: lb); |
| 100 | std::optional<int64_t> ubInt = getConstantIntValue(ofr: ub); |
| 101 | if (lbInt) |
| 102 | cstr.addBound(type: BoundType::EQ, pos: symLb, value: *lbInt); |
| 103 | if (ubInt) |
| 104 | cstr.addBound(type: BoundType::EQ, pos: symUb, value: *ubInt); |
| 105 | |
| 106 | // Lower bound: iv >= lb (equiv.: iv - lb >= 0) |
| 107 | SmallVector<int64_t> ineqLb(cstr.getNumCols(), 0); |
| 108 | ineqLb[dimIv] = 1; |
| 109 | ineqLb[symLb] = -1; |
| 110 | cstr.addInequality(inEq: ineqLb); |
| 111 | |
| 112 | // Upper bound |
| 113 | AffineExpr ivUb; |
| 114 | if (lbInt && ubInt && (*lbInt + *stepInt >= *ubInt)) { |
| 115 | // The loop has at most one iteration. |
| 116 | // iv < lb + 1 |
| 117 | // TODO: Try to derive this constraint by simplifying the expression in |
| 118 | // the else-branch. |
| 119 | ivUb = b.getAffineSymbolExpr(position: symLb - cstr.getNumDimVars()) + 1; |
| 120 | } else { |
| 121 | // The loop may have more than one iteration. |
| 122 | // iv < lb + step * ((ub - lb - 1) floorDiv step) + 1 |
| 123 | AffineExpr exprLb = |
| 124 | lbInt ? b.getAffineConstantExpr(constant: *lbInt) |
| 125 | : b.getAffineSymbolExpr(position: symLb - cstr.getNumDimVars()); |
| 126 | AffineExpr exprUb = |
| 127 | ubInt ? b.getAffineConstantExpr(constant: *ubInt) |
| 128 | : b.getAffineSymbolExpr(position: symUb - cstr.getNumDimVars()); |
| 129 | ivUb = exprLb + 1 + (*stepInt * ((exprUb - exprLb - 1).floorDiv(v: *stepInt))); |
| 130 | } |
| 131 | auto map = AffineMap::get( |
| 132 | /*dimCount=*/cstr.getNumDimVars(), |
| 133 | /*symbolCount=*/cstr.getNumSymbolVars(), /*result=*/ivUb); |
| 134 | |
| 135 | return cstr.addBound(type: BoundType::UB, pos: dimIv, boundMap: map); |
| 136 | } |
| 137 | |
| 138 | /// Canonicalize min/max operations in the context of for loops with a known |
| 139 | /// range. Call `canonicalizeMinMaxOp` and add the following constraints to |
| 140 | /// the constraint system (along with the missing dimensions): |
| 141 | /// |
| 142 | /// * iv >= lb |
| 143 | /// * iv < lb + step * ((ub - lb - 1) floorDiv step) + 1 |
| 144 | /// |
| 145 | /// Note: Due to limitations of IntegerPolyhedron, only constant step sizes |
| 146 | /// are currently supported. |
| 147 | LogicalResult scf::canonicalizeMinMaxOpInLoop(RewriterBase &rewriter, |
| 148 | Operation *op, |
| 149 | LoopMatcherFn loopMatcher) { |
| 150 | FlatAffineValueConstraints constraints; |
| 151 | DenseSet<Value> allIvs; |
| 152 | |
| 153 | // Find all iteration variables among `minOp`'s operands add constrain them. |
| 154 | for (Value operand : op->getOperands()) { |
| 155 | // Skip duplicate ivs. |
| 156 | if (allIvs.contains(V: operand)) |
| 157 | continue; |
| 158 | |
| 159 | // If `operand` is an iteration variable: Find corresponding loop |
| 160 | // bounds and step. |
| 161 | Value iv = operand; |
| 162 | OpFoldResult lb, ub, step; |
| 163 | if (failed(Result: loopMatcher(operand, lb, ub, step))) |
| 164 | continue; |
| 165 | allIvs.insert(V: iv); |
| 166 | |
| 167 | if (failed(Result: addLoopRangeConstraints(cstr&: constraints, iv, lb, ub, step))) |
| 168 | return failure(); |
| 169 | } |
| 170 | |
| 171 | return canonicalizeMinMaxOp(rewriter, op, constraints); |
| 172 | } |
| 173 | |
| 174 | /// Try to simplify the given affine.min/max operation `op` after loop peeling. |
| 175 | /// This function can simplify min/max operations such as (ub is the previous |
| 176 | /// upper bound of the unpeeled loop): |
| 177 | /// ``` |
| 178 | /// #map = affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)> |
| 179 | /// %r = affine.min #affine.min #map(%iv)[%step, %ub] |
| 180 | /// ``` |
| 181 | /// and rewrites them into (in the case the peeled loop): |
| 182 | /// ``` |
| 183 | /// %r = %step |
| 184 | /// ``` |
| 185 | /// min/max operations inside the partial iteration are rewritten in a similar |
| 186 | /// way. |
| 187 | /// |
| 188 | /// This function builds up a set of constraints, capable of proving that: |
| 189 | /// * Inside the peeled loop: min(step, ub - iv) == step |
| 190 | /// * Inside the partial iteration: min(step, ub - iv) == ub - iv |
| 191 | /// |
| 192 | /// Returns `success` if the given operation was replaced by a new operation; |
| 193 | /// `failure` otherwise. |
| 194 | /// |
| 195 | /// Note: `ub` is the previous upper bound of the loop (before peeling). |
| 196 | /// `insideLoop` must be true for min/max ops inside the loop and false for |
| 197 | /// affine.min ops inside the partial iteration. For an explanation of the other |
| 198 | /// parameters, see comment of `canonicalizeMinMaxOpInLoop`. |
| 199 | LogicalResult scf::rewritePeeledMinMaxOp(RewriterBase &rewriter, Operation *op, |
| 200 | Value iv, Value ub, Value step, |
| 201 | bool insideLoop) { |
| 202 | FlatAffineValueConstraints constraints; |
| 203 | constraints.appendDimVar(vals: {iv}); |
| 204 | constraints.appendSymbolVar(vals: {ub, step}); |
| 205 | if (auto constUb = getConstantIntValue(ofr: ub)) |
| 206 | constraints.addBound(type: BoundType::EQ, pos: 1, value: *constUb); |
| 207 | if (auto constStep = getConstantIntValue(ofr: step)) |
| 208 | constraints.addBound(type: BoundType::EQ, pos: 2, value: *constStep); |
| 209 | |
| 210 | // Add loop peeling invariant. This is the main piece of knowledge that |
| 211 | // enables AffineMinOp simplification. |
| 212 | if (insideLoop) { |
| 213 | // ub - iv >= step (equiv.: -iv + ub - step + 0 >= 0) |
| 214 | // Intuitively: Inside the peeled loop, every iteration is a "full" |
| 215 | // iteration, i.e., step divides the iteration space `ub - lb` evenly. |
| 216 | constraints.addInequality(inEq: {-1, 1, -1, 0}); |
| 217 | } else { |
| 218 | // ub - iv < step (equiv.: iv + -ub + step - 1 >= 0) |
| 219 | // Intuitively: `iv` is the split bound here, i.e., the iteration variable |
| 220 | // value of the very last iteration (in the unpeeled loop). At that point, |
| 221 | // there are less than `step` elements remaining. (Otherwise, the peeled |
| 222 | // loop would run for at least one more iteration.) |
| 223 | constraints.addInequality(inEq: {1, -1, 1, -1}); |
| 224 | } |
| 225 | |
| 226 | return canonicalizeMinMaxOp(rewriter, op, constraints); |
| 227 | } |
| 228 | |