| 1 | //===- ParallelLoopTiling.cpp - Tiles scf.parallel ------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements loop tiling on parallel loops. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Dialect/SCF/Transforms/Passes.h" |
| 14 | |
| 15 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 17 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 18 | #include "mlir/Dialect/SCF/Transforms/Transforms.h" |
| 19 | #include "mlir/Dialect/SCF/Utils/Utils.h" |
| 20 | |
| 21 | namespace mlir { |
| 22 | #define GEN_PASS_DEF_SCFPARALLELLOOPTILING |
| 23 | #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" |
| 24 | } // namespace mlir |
| 25 | |
| 26 | using namespace mlir; |
| 27 | using namespace mlir::scf; |
| 28 | |
| 29 | /// Tile a parallel loop of the form |
| 30 | /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) |
| 31 | /// step (%arg4, %arg5) |
| 32 | /// |
| 33 | /// into |
| 34 | /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) |
| 35 | /// step (%arg4*tileSize[0], |
| 36 | /// %arg5*tileSize[1]) |
| 37 | /// scf.parallel (%j0, %j1) = (0, 0) to (min(%arg4*tileSize[0], %arg2-%i0) |
| 38 | /// min(%arg5*tileSize[1], %arg3-%i1)) |
| 39 | /// step (%arg4, %arg5) |
| 40 | /// |
| 41 | /// or, when no-min-max-bounds is true, into |
| 42 | /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) |
| 43 | /// step (%arg4*tileSize[0], |
| 44 | /// %arg5*tileSize[1]) |
| 45 | /// scf.parallel (%j0, %j1) = (0, 0) to (%arg4*tileSize[0], |
| 46 | /// %arg5*tileSize[1]) |
| 47 | /// step (%arg4, %arg5) |
| 48 | /// %inbound = (%j0 * %arg4 + %i0 < %arg2) && |
| 49 | /// (%j1 * %arg5 + %i1 < %arg3) |
| 50 | /// scf.if (%inbound) |
| 51 | /// .... |
| 52 | /// |
| 53 | /// where the uses of %i0 and %i1 in the loop body are replaced by |
| 54 | /// %i0 + j0 and %i1 + %j1. |
| 55 | /// |
| 56 | /// The old loop is replaced with the new one. |
| 57 | std::pair<ParallelOp, ParallelOp> |
| 58 | mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes, |
| 59 | bool noMinMaxBounds) { |
| 60 | OpBuilder b(op); |
| 61 | auto zero = b.create<arith::ConstantIndexOp>(op.getLoc(), 0); |
| 62 | SmallVector<Value, 2> tileSizeConstants; |
| 63 | tileSizeConstants.reserve(op.getUpperBound().size()); |
| 64 | for (size_t i = 0, end = op.getUpperBound().size(); i != end; ++i) { |
| 65 | if (i < tileSizes.size()) |
| 66 | tileSizeConstants.push_back( |
| 67 | b.create<arith::ConstantIndexOp>(op.getLoc(), tileSizes[i])); |
| 68 | else |
| 69 | // Just pick 1 for the remaining dimensions. |
| 70 | tileSizeConstants.push_back( |
| 71 | b.create<arith::ConstantIndexOp>(op.getLoc(), 1)); |
| 72 | } |
| 73 | |
| 74 | // Create the outer loop with adjusted steps. |
| 75 | SmallVector<Value, 2> newSteps; |
| 76 | newSteps.reserve(op.getStep().size()); |
| 77 | for (auto step : llvm::zip(op.getStep(), tileSizeConstants)) { |
| 78 | newSteps.push_back(b.create<arith::MulIOp>(op.getLoc(), std::get<0>(step), |
| 79 | std::get<1>(step))); |
| 80 | } |
| 81 | auto outerLoop = b.create<ParallelOp>(op.getLoc(), op.getLowerBound(), |
| 82 | op.getUpperBound(), newSteps); |
| 83 | b.setInsertionPointToStart(outerLoop.getBody()); |
| 84 | |
| 85 | // Compute min(size, dim - offset) to avoid out-of-bounds accesses. |
| 86 | auto minMap = AffineMap::get( |
| 87 | /*dimCount=*/3, /*symbolCount=*/0, |
| 88 | {getAffineDimExpr(/*position=*/0, context: b.getContext()), |
| 89 | getAffineDimExpr(/*position=*/1, context: b.getContext()) - |
| 90 | getAffineDimExpr(/*position=*/2, context: b.getContext())}, |
| 91 | b.getContext()); |
| 92 | |
| 93 | // Create the inner loop with adjusted bounds. |
| 94 | SmallVector<Value, 2> newBounds; |
| 95 | newBounds.reserve(op.getUpperBound().size()); |
| 96 | bool needInboundCheck = false; |
| 97 | for (auto [lowerBound, upperBound, newStep, iv, step, tileSizeConstant] : |
| 98 | llvm::zip(outerLoop.getLowerBound(), outerLoop.getUpperBound(), |
| 99 | outerLoop.getStep(), outerLoop.getInductionVars(), |
| 100 | op.getStep(), tileSizeConstants)) { |
| 101 | // Collect the statically known loop bounds |
| 102 | auto lowerBoundConstant = |
| 103 | dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp()); |
| 104 | auto upperBoundConstant = |
| 105 | dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp()); |
| 106 | auto stepConstant = |
| 107 | dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp()); |
| 108 | auto tileSize = |
| 109 | cast<arith::ConstantIndexOp>(tileSizeConstant.getDefiningOp()).value(); |
| 110 | // If the loop bounds and the loop step are constant and if the number of |
| 111 | // loop iterations is an integer multiple of the tile size, we use a static |
| 112 | // bound for the inner loop. |
| 113 | if (lowerBoundConstant && upperBoundConstant && stepConstant) { |
| 114 | auto numIterations = llvm::divideCeil(upperBoundConstant.value() - |
| 115 | lowerBoundConstant.value(), |
| 116 | stepConstant.value()); |
| 117 | if (numIterations % tileSize == 0) { |
| 118 | newBounds.push_back(newStep); |
| 119 | continue; |
| 120 | } |
| 121 | } |
| 122 | |
| 123 | // For InboundCheck mode, just use the variable outer step |
| 124 | if (noMinMaxBounds) { |
| 125 | newBounds.push_back(newStep); |
| 126 | needInboundCheck = true; |
| 127 | continue; |
| 128 | } |
| 129 | |
| 130 | // Otherwise, we dynamically compute the bound for |
| 131 | // each iteration of the outer loop. |
| 132 | newBounds.push_back( |
| 133 | b.create<affine::AffineMinOp>(op.getLoc(), b.getIndexType(), minMap, |
| 134 | ValueRange{newStep, upperBound, iv})); |
| 135 | } |
| 136 | auto innerLoop = b.create<ParallelOp>( |
| 137 | op.getLoc(), SmallVector<Value, 2>(newBounds.size(), zero), newBounds, |
| 138 | op.getStep()); |
| 139 | |
| 140 | if (noMinMaxBounds && needInboundCheck) { |
| 141 | b.setInsertionPointToStart(innerLoop.getBody()); |
| 142 | // Insert in-bound check |
| 143 | Value inbound = |
| 144 | b.create<arith::ConstantIntOp>(op.getLoc(), 1, b.getIntegerType(1)); |
| 145 | for (auto [outerUpperBound, outerIV, innerIV, innerStep] : |
| 146 | llvm::zip(outerLoop.getUpperBound(), outerLoop.getInductionVars(), |
| 147 | innerLoop.getInductionVars(), innerLoop.getStep())) { |
| 148 | // %in_bound = %in_bound && |
| 149 | // (%inner_iv * %inner_step + %outer_iv < %outer_upper_bound) |
| 150 | Value index = b.create<arith::AddIOp>( |
| 151 | op.getLoc(), b.create<arith::MulIOp>(op.getLoc(), innerIV, innerStep), |
| 152 | outerIV); |
| 153 | Value dimInbound = b.create<arith::CmpIOp>( |
| 154 | op.getLoc(), arith::CmpIPredicate::ult, index, outerUpperBound); |
| 155 | inbound = b.create<arith::AndIOp>(op.getLoc(), inbound, dimInbound); |
| 156 | } |
| 157 | auto ifInbound = b.create<IfOp>(op.getLoc(), |
| 158 | /*resultTypes*/ ArrayRef<Type>{}, inbound, |
| 159 | /*hasElseRegion*/ false); |
| 160 | ifInbound.getThenRegion().takeBody(op.getRegion()); |
| 161 | Block &thenBlock = ifInbound.getThenRegion().front(); |
| 162 | // Replace the scf.reduce terminator with an scf.yield terminator. |
| 163 | Operation *reduceOp = thenBlock.getTerminator(); |
| 164 | b.setInsertionPointToEnd(&thenBlock); |
| 165 | b.create<scf::YieldOp>(reduceOp->getLoc()); |
| 166 | reduceOp->erase(); |
| 167 | b.setInsertionPointToStart(innerLoop.getBody()); |
| 168 | for (const auto &ivs : llvm::enumerate(llvm::zip( |
| 169 | innerLoop.getInductionVars(), outerLoop.getInductionVars()))) { |
| 170 | auto newIndex = b.create<arith::AddIOp>( |
| 171 | op.getLoc(), std::get<0>(ivs.value()), std::get<1>(ivs.value())); |
| 172 | thenBlock.getArgument(ivs.index()) |
| 173 | .replaceAllUsesExcept(newIndex, newIndex); |
| 174 | } |
| 175 | thenBlock.eraseArguments(start: 0, num: thenBlock.getNumArguments()); |
| 176 | } else { |
| 177 | innerLoop.getRegion().takeBody(op.getRegion()); |
| 178 | b.setInsertionPointToStart(innerLoop.getBody()); |
| 179 | for (auto ivs : llvm::zip(innerLoop.getInductionVars(), |
| 180 | outerLoop.getInductionVars())) { |
| 181 | Value innerIndex = std::get<0>(ivs); |
| 182 | auto newIndex = b.create<arith::AddIOp>(op.getLoc(), std::get<0>(ivs), |
| 183 | std::get<1>(ivs)); |
| 184 | innerIndex.replaceAllUsesExcept(newIndex, newIndex); |
| 185 | } |
| 186 | } |
| 187 | |
| 188 | op.erase(); |
| 189 | return std::make_pair(outerLoop, innerLoop); |
| 190 | } |
| 191 | |
| 192 | namespace { |
| 193 | struct ParallelLoopTiling |
| 194 | : public impl::SCFParallelLoopTilingBase<ParallelLoopTiling> { |
| 195 | ParallelLoopTiling() = default; |
| 196 | explicit ParallelLoopTiling(ArrayRef<int64_t> tileSizes, |
| 197 | bool noMinMaxBounds = false) { |
| 198 | this->tileSizes = tileSizes; |
| 199 | this->noMinMaxBounds = noMinMaxBounds; |
| 200 | } |
| 201 | |
| 202 | void runOnOperation() override { |
| 203 | for (auto tileSize : tileSizes) |
| 204 | if (tileSize == 0) { |
| 205 | mlir::emitError(mlir::UnknownLoc::get(&Pass::getContext()), |
| 206 | "tile size cannot be 0" ); |
| 207 | return signalPassFailure(); |
| 208 | } |
| 209 | auto *parentOp = getOperation(); |
| 210 | SmallVector<ParallelOp, 2> innermostPloops; |
| 211 | getInnermostParallelLoops(parentOp, innermostPloops); |
| 212 | for (ParallelOp ploop : innermostPloops) { |
| 213 | // FIXME: Add reduction support. |
| 214 | if (ploop.getNumReductions() == 0) |
| 215 | tileParallelLoop(ploop, tileSizes, noMinMaxBounds); |
| 216 | } |
| 217 | } |
| 218 | }; |
| 219 | } // namespace |
| 220 | |
| 221 | std::unique_ptr<Pass> |
| 222 | mlir::createParallelLoopTilingPass(ArrayRef<int64_t> tileSizes, |
| 223 | bool noMinMaxBounds) { |
| 224 | return std::make_unique<ParallelLoopTiling>(args&: tileSizes, args&: noMinMaxBounds); |
| 225 | } |
| 226 | |