| 1 | //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements the tiling using TilingInterface. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" |
| 14 | |
| 15 | #include "mlir/Analysis/SliceAnalysis.h" |
| 16 | #include "mlir/Analysis/TopologicalSortUtils.h" |
| 17 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 18 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 19 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
| 20 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 21 | #include "mlir/Dialect/SCF/Utils/Utils.h" |
| 22 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 23 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
| 24 | #include "mlir/IR/Dominance.h" |
| 25 | #include "mlir/IR/Matchers.h" |
| 26 | #include "mlir/IR/PatternMatch.h" |
| 27 | #include "mlir/Interfaces/DestinationStyleOpInterface.h" |
| 28 | #include "mlir/Interfaces/TilingInterface.h" |
| 29 | #include "mlir/Rewrite/FrozenRewritePatternSet.h" |
| 30 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 31 | #include "llvm/ADT/ScopeExit.h" |
| 32 | #include "llvm/ADT/TypeSwitch.h" |
| 33 | #include "llvm/Support/Debug.h" |
| 34 | #include <optional> |
| 35 | |
| 36 | #define DEBUG_TYPE "tile-using-interface" |
| 37 | |
| 38 | using namespace mlir; |
| 39 | |
| 40 | scf::SCFTilingOptions & |
| 41 | scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) { |
| 42 | assert(!tileSizeComputationFunction && "tile sizes already set" ); |
| 43 | auto tileSizes = llvm::to_vector(Range&: ts); |
| 44 | tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { |
| 45 | return tileSizes; |
| 46 | }; |
| 47 | return *this; |
| 48 | } |
| 49 | |
| 50 | scf::SCFTilingOptions & |
| 51 | scf::SCFTilingOptions::setNumThreads(ArrayRef<OpFoldResult> nt) { |
| 52 | assert(!numThreadsComputationFunction && "num tiles already set" ); |
| 53 | auto numThreads = llvm::to_vector(Range&: nt); |
| 54 | numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) { |
| 55 | return numThreads; |
| 56 | }; |
| 57 | return *this; |
| 58 | } |
| 59 | |
| 60 | /// Helper method to adjust the interchange vector to match the iteration |
| 61 | /// domain. |
| 62 | static SmallVector<int64_t> |
| 63 | fillInterchangeVector(ArrayRef<int64_t> interchangeVector, |
| 64 | size_t iterationDomainSize) { |
| 65 | SmallVector<int64_t> filledVector = llvm::to_vector(Range&: interchangeVector); |
| 66 | if (filledVector.size() < iterationDomainSize) { |
| 67 | auto range = llvm::seq<int64_t>(Begin: filledVector.size(), End: iterationDomainSize); |
| 68 | filledVector.append(in_start: range.begin(), in_end: range.end()); |
| 69 | } |
| 70 | if (filledVector.size() > iterationDomainSize) |
| 71 | filledVector.resize(N: iterationDomainSize); |
| 72 | return filledVector; |
| 73 | } |
| 74 | |
| 75 | //===----------------------------------------------------------------------===// |
| 76 | // tileUsingSCF implementation. |
| 77 | //===----------------------------------------------------------------------===// |
| 78 | |
| 79 | /// Verify the tile size options are set in a consistent manner. |
| 80 | static LogicalResult |
| 81 | verifyTileSizeOptions(RewriterBase &rewriter, Location loc, |
| 82 | const scf::SCFTilingOptions &options) { |
| 83 | // Specifying number of threads is only supported on `scf.forall` op. |
| 84 | if (options.numThreadsComputationFunction && |
| 85 | options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) { |
| 86 | return rewriter.notifyMatchFailure( |
| 87 | arg&: loc, msg: "number of threads can only by specified when loop type is " |
| 88 | "set to use `scf.forall`" ); |
| 89 | } |
| 90 | |
| 91 | // If specified, check that the interchange vector is a permutation. |
| 92 | if (!options.interchangeVector.empty()) { |
| 93 | if (!isPermutationVector(interchange: options.interchangeVector)) { |
| 94 | return rewriter.notifyMatchFailure( |
| 95 | arg&: loc, msg: "invalid interchange vector, not a permutation of the entire " |
| 96 | "iteration space" ); |
| 97 | } |
| 98 | } |
| 99 | return success(); |
| 100 | } |
| 101 | |
| 102 | /// Method to instantiate the tile sizes and/or number of threads specified |
| 103 | /// by the user. |
| 104 | static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>> |
| 105 | getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, |
| 106 | ArrayRef<Range> iterationDomain, |
| 107 | const scf::SCFTilingOptions &options) { |
| 108 | OpFoldResult zero = rewriter.getIndexAttr(0); |
| 109 | SmallVector<OpFoldResult> tileSizes, numThreads; |
| 110 | size_t numLoops = iterationDomain.size(); |
| 111 | |
| 112 | // Check whether the number of tiles to use is specified. |
| 113 | if (options.numThreadsComputationFunction) { |
| 114 | numThreads = options.numThreadsComputationFunction(rewriter, op); |
| 115 | numThreads.resize(N: numLoops, NV: zero); |
| 116 | |
| 117 | // If the number of tiles is also specified, use that. |
| 118 | if (options.tileSizeComputationFunction) { |
| 119 | tileSizes = options.tileSizeComputationFunction(rewriter, op); |
| 120 | tileSizes.resize(N: numLoops, NV: zero); |
| 121 | return {tileSizes, numThreads}; |
| 122 | } |
| 123 | |
| 124 | // Compute the tile sizes from the iteration domain and number |
| 125 | // of tiles as follows |
| 126 | // - niters = ceilDiv(ub - lb, step) |
| 127 | // - tileSize = ceilDiv(niters, numThreads) |
| 128 | AffineExpr s0, s1, s2; |
| 129 | bindSymbols(ctx: rewriter.getContext(), exprs&: s0, exprs&: s1, exprs&: s2); |
| 130 | // TODO: The step here is assumed to be 1. |
| 131 | AffineExpr numItersExpr = (s1 - s0); |
| 132 | AffineExpr tileSizeExpr = numItersExpr.ceilDiv(other: s2); |
| 133 | tileSizes.resize(N: numLoops, NV: zero); |
| 134 | for (auto [index, range, nt] : |
| 135 | llvm::enumerate(First&: iterationDomain, Rest&: numThreads)) { |
| 136 | if (isZeroInteger(v: nt)) |
| 137 | continue; |
| 138 | |
| 139 | tileSizes[index] = affine::makeComposedFoldedAffineApply( |
| 140 | rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt}); |
| 141 | } |
| 142 | tileSizes.resize(N: numLoops, NV: zero); |
| 143 | return {tileSizes, numThreads}; |
| 144 | } |
| 145 | |
| 146 | // Enforce the convention that "tiling by zero" |
| 147 | // skips tiling a particular dimension. This convention is significantly |
| 148 | // simpler to handle instead of adjusting affine maps to account for missing |
| 149 | // dimensions. |
| 150 | assert(options.tileSizeComputationFunction && |
| 151 | "expected tile sizes to be specified" ); |
| 152 | tileSizes = options.tileSizeComputationFunction(rewriter, op); |
| 153 | tileSizes.resize(N: numLoops, NV: zero); |
| 154 | |
| 155 | return {tileSizes, numThreads}; |
| 156 | } |
| 157 | |
| 158 | /// Checks if any of the tiled loops are not parallel. |
| 159 | static void checkSafeToTileToForall(TilingInterface op, |
| 160 | ArrayRef<OpFoldResult> tileSizes, |
| 161 | ArrayRef<OpFoldResult> numThreads) { |
| 162 | auto iterators = op.getLoopIteratorTypes(); |
| 163 | assert(iterators.size() == tileSizes.size() && |
| 164 | "expected as many tile size values as number of loops" ); |
| 165 | assert((numThreads.empty() || (numThreads.size() == iterators.size())) && |
| 166 | "when specified, expected number of threads to use for each loop" ); |
| 167 | |
| 168 | for (auto [index, iterator, tileSize] : |
| 169 | llvm::enumerate(iterators, tileSizes)) { |
| 170 | // If num threads is specified, check that it is greater than one only for |
| 171 | // parallel dimensions. |
| 172 | if (!numThreads.empty()) { |
| 173 | if (std::optional<int64_t> constNumThreads = |
| 174 | getConstantIntValue(numThreads[index])) { |
| 175 | if (constNumThreads.value() > 1 && |
| 176 | iterator != utils::IteratorType::parallel) { |
| 177 | op.emitWarning() << "tiling is not thread safe at axis #" << index; |
| 178 | } |
| 179 | } |
| 180 | continue; |
| 181 | } |
| 182 | |
| 183 | if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) { |
| 184 | if (constTileSize.value() > 0 && |
| 185 | iterator != utils::IteratorType::parallel) { |
| 186 | op.emitWarning() << "tiling is not thread safe at axis #" << index; |
| 187 | } |
| 188 | } |
| 189 | } |
| 190 | } |
| 191 | |
| 192 | /// Check if `stride` evenly divides the trip count `size - offset`. |
| 193 | static bool tileDividesIterationDomain(Range loopRange) { |
| 194 | std::optional<int64_t> offsetAsInt = getConstantIntValue(ofr: loopRange.offset); |
| 195 | if (!offsetAsInt) |
| 196 | return false; |
| 197 | std::optional<int64_t> sizeAsInt = getConstantIntValue(ofr: loopRange.size); |
| 198 | if (!sizeAsInt) |
| 199 | return false; |
| 200 | std::optional<int64_t> strideAsInt = getConstantIntValue(ofr: loopRange.stride); |
| 201 | if (!strideAsInt) |
| 202 | return false; |
| 203 | return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); |
| 204 | } |
| 205 | |
| 206 | /// Returns the bounded tile size given the current `offset`, `loopRange` and |
| 207 | /// `tileSize`, i.e., `min(tileSize, range.end() - offset)`. |
| 208 | static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, |
| 209 | Range loopRange, OpFoldResult offset, |
| 210 | OpFoldResult tileSize) { |
| 211 | std::optional<int64_t> ts = getConstantIntValue(ofr: tileSize); |
| 212 | if (ts && ts.value() == 1) |
| 213 | return tileSize; |
| 214 | |
| 215 | if (tileDividesIterationDomain( |
| 216 | loopRange: Range{.offset: loopRange.offset, .size: loopRange.size, .stride: tileSize})) |
| 217 | return tileSize; |
| 218 | |
| 219 | // The tile size to use (to avoid out of bounds access) is minimum of |
| 220 | // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled |
| 221 | // loop. |
| 222 | AffineExpr s0, s1, d0; |
| 223 | bindDims(ctx: b.getContext(), exprs&: d0); |
| 224 | bindSymbols(ctx: b.getContext(), exprs&: s0, exprs&: s1); |
| 225 | AffineMap minMap = AffineMap::get(dimCount: 1, symbolCount: 2, results: {s0 - d0, s1}, context: b.getContext()); |
| 226 | Value size = getValueOrCreateConstantIndexOp(b, loc, ofr: loopRange.size); |
| 227 | return affine::makeComposedFoldedAffineMin( |
| 228 | b, loc, map: minMap, operands: SmallVector<OpFoldResult>{offset, size, tileSize}); |
| 229 | } |
| 230 | |
| 231 | /// Returns true if the maximum tile offset `tileSize * numThreads-1` is less |
| 232 | /// than `iterationSize`. |
| 233 | static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, |
| 234 | OpFoldResult numThreads, |
| 235 | OpFoldResult iterationSize) { |
| 236 | std::optional<int64_t> tileSizeConst = getConstantIntValue(ofr: tileSize); |
| 237 | std::optional<int64_t> numThreadsConst = getConstantIntValue(ofr: numThreads); |
| 238 | std::optional<int64_t> iterSizeConst = getConstantIntValue(ofr: iterationSize); |
| 239 | if (!tileSizeConst || !numThreadsConst || !iterSizeConst) |
| 240 | return false; |
| 241 | return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst; |
| 242 | } |
| 243 | |
| 244 | /// Compute the `OpFoldResult`s that represents the multi-dimensional |
| 245 | /// `offset`s and `size`s of the tile of the iteration space that the |
| 246 | /// innermost loop body of the generated tiled loops corresponds to. |
| 247 | static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>> |
| 248 | getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, |
| 249 | ArrayRef<Range> iterationDomain, |
| 250 | ArrayRef<OpFoldResult> tileSizes, |
| 251 | ArrayRef<OpFoldResult> numThreads) { |
| 252 | SmallVector<OpFoldResult> offsets, sizes; |
| 253 | int materializedLoopNum = 0; |
| 254 | |
| 255 | if (!numThreads.empty()) { |
| 256 | AffineExpr d0, d1, s0, s1; |
| 257 | AffineExpr offsetExpr, residualTileSizeExpr; |
| 258 | bindDims(ctx: rewriter.getContext(), exprs&: d0, exprs&: d1); |
| 259 | bindSymbols(ctx: rewriter.getContext(), exprs&: s0, exprs&: s1); |
| 260 | offsetExpr = d0 + d1 * s0; |
| 261 | residualTileSizeExpr = s1 - (d0 + d1 * s0); |
| 262 | |
| 263 | for (auto [nt, tileSize, loopRange] : |
| 264 | llvm::zip_equal(t&: numThreads, u&: tileSizes, args&: iterationDomain)) { |
| 265 | |
| 266 | // Non-tiled cases, set the offset and size to the |
| 267 | // `loopRange.offset/size`. |
| 268 | if (isZeroInteger(v: nt)) { |
| 269 | offsets.push_back(Elt: loopRange.offset); |
| 270 | sizes.push_back(Elt: loopRange.size); |
| 271 | continue; |
| 272 | } |
| 273 | |
| 274 | Value iv = ivs[materializedLoopNum++]; |
| 275 | OpFoldResult offset = affine::makeComposedFoldedAffineApply( |
| 276 | b&: rewriter, loc, expr: offsetExpr, |
| 277 | operands: ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize}); |
| 278 | OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply( |
| 279 | b&: rewriter, loc, expr: residualTileSizeExpr, |
| 280 | operands: {loopRange.offset, nt, tileSize, loopRange.size}); |
| 281 | |
| 282 | OpFoldResult size = tileSize; |
| 283 | if (!isZeroInteger(v: residualTileSize)) { |
| 284 | OpFoldResult sizeMinusOffsetPerThread = |
| 285 | affine::makeComposedFoldedAffineApply(b&: rewriter, loc, expr: s0 - d0, |
| 286 | operands: {offset, loopRange.size}); |
| 287 | size = affine::makeComposedFoldedAffineMin( |
| 288 | b&: rewriter, loc, |
| 289 | map: AffineMap::getMultiDimIdentityMap(numDims: 2, context: rewriter.getContext()), |
| 290 | operands: {sizeMinusOffsetPerThread, tileSize}); |
| 291 | } |
| 292 | |
| 293 | // Consider the case where the original loop was `[0, 100)`. |
| 294 | // If number of threads are `7`, the tile size would be computed as |
| 295 | // `ceilDiv(100, 7) = 15`. For the last thread (thread_id = 6) |
| 296 | // - `offset = 0 + 6 * 15 = 105` |
| 297 | // - `tileSize = min(15, 100 - 105) = -5` |
| 298 | // To avoid negative tile sizes, we need to do a further |
| 299 | // `nonNegativeTileSize = affine.max(0, tileSize)`. |
| 300 | // This `max` can be avoided if |
| 301 | // `offset + tileSize * (numThreads - 1) < (ub - lb)` |
| 302 | if (!canOmitTileOffsetInBoundsCheck(tileSize, numThreads: nt, iterationSize: loopRange.size)) { |
| 303 | AffineMap maxMap = |
| 304 | AffineMap::getMultiDimIdentityMap(numDims: 2, context: rewriter.getContext()); |
| 305 | size = affine::makeComposedFoldedAffineMax( |
| 306 | rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size}); |
| 307 | } |
| 308 | |
| 309 | offsets.push_back(Elt: offset); |
| 310 | sizes.push_back(Elt: size); |
| 311 | } |
| 312 | return {offsets, sizes}; |
| 313 | } else { |
| 314 | for (auto [tileSize, loopRange] : |
| 315 | llvm::zip_equal(t&: tileSizes, u&: iterationDomain)) { |
| 316 | |
| 317 | // Non-tiled cases, set the offset and size to the |
| 318 | // `loopRange.offset/size`. |
| 319 | if (isZeroInteger(v: tileSize)) { |
| 320 | offsets.push_back(Elt: loopRange.offset); |
| 321 | sizes.push_back(Elt: loopRange.size); |
| 322 | continue; |
| 323 | } |
| 324 | |
| 325 | Value iv = ivs[materializedLoopNum++]; |
| 326 | OpFoldResult offset = getAsOpFoldResult(val: iv); |
| 327 | offsets.push_back(Elt: offset); |
| 328 | OpFoldResult size = |
| 329 | getBoundedTileSize(b&: rewriter, loc, loopRange, offset, tileSize); |
| 330 | sizes.push_back(Elt: size); |
| 331 | } |
| 332 | return {offsets, sizes}; |
| 333 | } |
| 334 | } |
| 335 | |
| 336 | /// Function to return the bounds of the loops to be generated. |
| 337 | static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>, |
| 338 | SmallVector<OpFoldResult>> |
| 339 | getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, |
| 340 | ArrayRef<OpFoldResult> tileSizes) { |
| 341 | SmallVector<OpFoldResult> lbs, ubs, steps; |
| 342 | for (auto [loopRange, tileSize] : llvm::zip_equal(t&: loopRanges, u&: tileSizes)) { |
| 343 | // No loop if the tile size is 0. |
| 344 | if (isZeroInteger(v: tileSize)) |
| 345 | continue; |
| 346 | lbs.push_back(Elt: loopRange.offset); |
| 347 | ubs.push_back(Elt: loopRange.size); |
| 348 | steps.push_back(Elt: tileSize); |
| 349 | } |
| 350 | return {lbs, ubs, steps}; |
| 351 | } |
| 352 | |
| 353 | /// A function that allows returning additional yielded values during |
| 354 | /// `yieldTiledValuesAndReplace`. |
| 355 | /// - `ivs` induction variable for the loop. |
| 356 | /// - `newBbArgs` basic block arguments corresponding to newly added iter_args. |
| 357 | /// - `tiledValues` the tiled values to return. Must be of same size as |
| 358 | /// `newbbArgs`, each element of this array is inserted into the corresponding |
| 359 | /// element in `newbbArgs`. |
| 360 | /// - `resultOffsets` is of the same size as `tiledValues` and represents |
| 361 | /// the offsets to use when inserting corresponding element from `tiledValues` |
| 362 | /// into the element from `newBbArgs`. |
| 363 | /// - `resultSizes` is of the same size as `tiledValues` and represents |
| 364 | /// the size of the corresponding element from `tiledValues` inserted into |
| 365 | /// the element from `newBbArgs`. |
| 366 | /// In case the method needs to return `failure()` the method is expected |
| 367 | /// to clean up any inserted operations. |
| 368 | using YieldTiledValuesFn = std::function<LogicalResult( |
| 369 | RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, |
| 370 | SmallVector<Value> &tiledValues, |
| 371 | SmallVector<SmallVector<OpFoldResult>> &resultOffsets, |
| 372 | SmallVector<SmallVector<OpFoldResult>> &resultSizes)>; |
| 373 | |
| 374 | /// Clones the operation and updates the destination if the operation |
| 375 | /// implements the `DestinationStyleOpInterface`. |
| 376 | static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, |
| 377 | Operation *op, |
| 378 | ValueRange newDestArgs) { |
| 379 | Operation *clonedOp = rewriter.clone(op&: *op); |
| 380 | if (newDestArgs.empty()) |
| 381 | return clonedOp; |
| 382 | if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp)) |
| 383 | destinationStyleOp.getDpsInitsMutable().assign(newDestArgs); |
| 384 | return clonedOp; |
| 385 | } |
| 386 | |
| 387 | /// Generate the tile-loop nest using `scf.for` operation. |
| 388 | /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. |
| 389 | /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. |
| 390 | /// - `destinationTensors` are the init values to use for the outer most loop. |
| 391 | /// - `yieldTiledValuesFn` is called to generated the loop body of the inner |
| 392 | /// most |
| 393 | /// loop. |
| 394 | /// - `loops` is an in-out parameter into which the generated loops are |
| 395 | /// populated. |
| 396 | static LogicalResult generateLoopNestUsingForOp( |
| 397 | RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, |
| 398 | ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors, |
| 399 | YieldTiledValuesFn yieldTiledValuesFn, |
| 400 | SmallVector<LoopLikeOpInterface> &loops) { |
| 401 | assert(!loopRanges.empty() && "unexpected empty loop ranges" ); |
| 402 | assert(loopRanges.size() == tileSizes.size() && |
| 403 | "expected as many tile sizes as loop ranges" ); |
| 404 | OpBuilder::InsertionGuard guard(rewriter); |
| 405 | |
| 406 | SmallVector<OpFoldResult> lbs, ubs, steps; |
| 407 | std::tie(args&: lbs, args&: ubs, args&: steps) = |
| 408 | getLoopBounds(rewriter, loc, loopRanges, tileSizes); |
| 409 | SmallVector<Value> lbVals = |
| 410 | getValueOrCreateConstantIndexOp(b&: rewriter, loc, valueOrAttrVec: lbs); |
| 411 | SmallVector<Value> ubVals = |
| 412 | getValueOrCreateConstantIndexOp(b&: rewriter, loc, valueOrAttrVec: ubs); |
| 413 | SmallVector<Value> stepVals = |
| 414 | getValueOrCreateConstantIndexOp(b&: rewriter, loc, valueOrAttrVec: steps); |
| 415 | |
| 416 | SmallVector<Value> ivs; |
| 417 | for (auto [lb, ub, step] : llvm::zip_equal(t&: lbVals, u&: ubVals, args&: stepVals)) { |
| 418 | auto loop = |
| 419 | rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors, |
| 420 | [](OpBuilder &bodyBuilder, Location bodyLoc, |
| 421 | Value iv, ValueRange /*iterArgs*/) {}); |
| 422 | loops.push_back(loop); |
| 423 | ivs.push_back(Elt: loop.getInductionVar()); |
| 424 | rewriter.setInsertionPointToEnd(loop.getBody()); |
| 425 | destinationTensors = loop.getRegionIterArgs(); |
| 426 | } |
| 427 | |
| 428 | SmallVector<Value> tiledResults; |
| 429 | SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; |
| 430 | if (failed(Result: yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors, |
| 431 | tiledResults, resultOffsets, resultSizes))) { |
| 432 | return rewriter.notifyMatchFailure( |
| 433 | arg&: loc, msg: "failed to generate inner tile loop body" ); |
| 434 | } |
| 435 | if (loops.empty()) |
| 436 | return success(); |
| 437 | |
| 438 | assert(tiledResults.size() == destinationTensors.size() && |
| 439 | "Number of results of body should be equal to number of iter args" ); |
| 440 | |
| 441 | // 6. Yield all the results of the tiled operation. |
| 442 | SmallVector<Value> yieldedValues; |
| 443 | for (auto [tiledValue, destinationTensor, resultOffset, resultSize] : |
| 444 | llvm::zip_equal(t&: tiledResults, u&: destinationTensors, args&: resultOffsets, |
| 445 | args&: resultSizes)) { |
| 446 | SmallVector<OpFoldResult> resultStride(resultOffset.size(), |
| 447 | rewriter.getIndexAttr(1)); |
| 448 | auto insertSlice = rewriter.create<tensor::InsertSliceOp>( |
| 449 | loc, tiledValue, destinationTensor, resultOffset, resultSize, |
| 450 | resultStride); |
| 451 | yieldedValues.push_back(Elt: insertSlice); |
| 452 | } |
| 453 | rewriter.create<scf::YieldOp>(loc, yieldedValues); |
| 454 | |
| 455 | // Add the scf.yield operations for all the outer loops. |
| 456 | for (auto [outerLoop, innerLoop] : |
| 457 | llvm::zip_equal(MutableArrayRef(loops).drop_back(), |
| 458 | MutableArrayRef(loops).drop_front())) { |
| 459 | rewriter.setInsertionPointToEnd( |
| 460 | cast<scf::ForOp>(outerLoop.getOperation()).getBody()); |
| 461 | rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults()); |
| 462 | } |
| 463 | return success(); |
| 464 | } |
| 465 | |
| 466 | /// Generate the tile-loop nest using `scf.forall` operation. |
| 467 | /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. |
| 468 | /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. |
| 469 | /// - `destinationTensors` are the init values to use for the outer most loop. |
| 470 | /// - `mappingVector` is the mapping attributes to use for loop construction. |
| 471 | /// Can be empty. |
| 472 | /// - `yieldTiledValuesFn` is called to generated the loop body of the inner |
| 473 | /// most |
| 474 | /// loop. |
| 475 | /// - `loops` is an in-out parameter into which the generated loops are |
| 476 | /// populated. |
| 477 | static LogicalResult generateLoopNestUsingForallOp( |
| 478 | RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, |
| 479 | ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads, |
| 480 | ArrayRef<Attribute> mappingVector, ValueRange destinationTensors, |
| 481 | YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) { |
| 482 | assert(!loopRanges.empty() && "unexpected empty loop ranges" ); |
| 483 | assert(loopRanges.size() == tileSizes.size() && |
| 484 | "expected as many tile sizes as loop ranges" ); |
| 485 | OpBuilder::InsertionGuard guard(rewriter); |
| 486 | |
| 487 | std::optional<ArrayAttr> mappingAttr; |
| 488 | if (!mappingVector.empty()) |
| 489 | mappingAttr = rewriter.getArrayAttr(mappingVector); |
| 490 | |
| 491 | scf::ForallOp forallOp; |
| 492 | bool useNumThreads = !numThreads.empty(); |
| 493 | |
| 494 | if (useNumThreads) { |
| 495 | // Prune the zero numthreads. |
| 496 | SmallVector<OpFoldResult> nonZeroNumThreads; |
| 497 | for (auto nt : numThreads) { |
| 498 | if (isZeroInteger(v: nt)) |
| 499 | continue; |
| 500 | nonZeroNumThreads.push_back(Elt: nt); |
| 501 | } |
| 502 | forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads, |
| 503 | destinationTensors, mappingAttr); |
| 504 | } else { |
| 505 | SmallVector<OpFoldResult> lbs, ubs, steps; |
| 506 | std::tie(args&: lbs, args&: ubs, args&: steps) = |
| 507 | getLoopBounds(rewriter, loc, loopRanges, tileSizes); |
| 508 | forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps, |
| 509 | destinationTensors, mappingAttr); |
| 510 | } |
| 511 | loops.push_back(forallOp); |
| 512 | |
| 513 | rewriter.setInsertionPoint(forallOp.getTerminator()); |
| 514 | destinationTensors = forallOp.getRegionOutArgs(); |
| 515 | |
| 516 | SmallVector<Value> tiledResults; |
| 517 | SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; |
| 518 | if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(), |
| 519 | destinationTensors, tiledResults, resultOffsets, |
| 520 | resultSizes))) |
| 521 | return rewriter.notifyMatchFailure(arg&: loc, msg: "failed to generate loop body" ); |
| 522 | |
| 523 | rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody()); |
| 524 | for (auto [tiledValue, destinationTensor, resultOffset, resultSize] : |
| 525 | llvm::zip_equal(t&: tiledResults, u&: destinationTensors, args&: resultOffsets, |
| 526 | args&: resultSizes)) { |
| 527 | SmallVector<OpFoldResult> resultStride(resultOffset.size(), |
| 528 | rewriter.getIndexAttr(1)); |
| 529 | |
| 530 | rewriter.create<tensor::ParallelInsertSliceOp>( |
| 531 | loc, tiledValue, destinationTensor, resultOffset, resultSize, |
| 532 | resultStride); |
| 533 | } |
| 534 | return success(); |
| 535 | } |
| 536 | |
| 537 | /// Generate the tile-loop nest using the loop construct specifed in `options`. |
| 538 | /// - `options`: Tiling options specified. |
| 539 | /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. |
| 540 | /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. |
| 541 | /// - `destinationTensors` are the init values to use for the outer most loop. |
| 542 | /// - `yieldTiledValuesFn` is called to generated the loop body of the inner |
| 543 | /// most |
| 544 | /// loop. |
| 545 | /// - `loops` is an in-out parameter into which the generated loops are |
| 546 | /// populated. |
| 547 | static LogicalResult generateLoopNest( |
| 548 | RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, |
| 549 | ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes, |
| 550 | ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors, |
| 551 | YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) { |
| 552 | // If the tile sizes are all zero, no loops are generated. Just call the |
| 553 | // callback function to handle untiled case. |
| 554 | if (llvm::all_of(Range&: tileSizes, P: isZeroInteger)) { |
| 555 | SmallVector<Value> tiledResults; |
| 556 | SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; |
| 557 | return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors, |
| 558 | tiledResults, resultOffsets, resultSizes); |
| 559 | } |
| 560 | if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) { |
| 561 | return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes, |
| 562 | destinationTensors, tiledBodyFn, loops); |
| 563 | } |
| 564 | if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) { |
| 565 | return generateLoopNestUsingForallOp( |
| 566 | rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector, |
| 567 | destinationTensors, tiledBodyFn, loops); |
| 568 | } |
| 569 | return rewriter.notifyMatchFailure(arg&: loc, msg: "unhandled loop type" ); |
| 570 | } |
| 571 | |
| 572 | static FailureOr<SmallVector<Value>> |
| 573 | createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op, |
| 574 | ArrayRef<OpFoldResult> tileSizes, |
| 575 | const scf::SCFTilingOptions &options) { |
| 576 | SmallVector<Value> initTensors; |
| 577 | Location loc = op->getLoc(); |
| 578 | switch (options.reductionStrategy) { |
| 579 | case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction: |
| 580 | if (failed(tensor::getOrCreateDestinations(b&: rewriter, loc, op: op, result&: initTensors))) |
| 581 | return failure(); |
| 582 | return initTensors; |
| 583 | case scf::SCFTilingOptions::ReductionTilingStrategy:: |
| 584 | PartialReductionOuterReduction: { |
| 585 | auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation()); |
| 586 | if (!redOp) { |
| 587 | return rewriter.notifyMatchFailure( |
| 588 | op, "PartialReductionOuterReduction tiling strategy is only supported" |
| 589 | "for operations implementing PartialReductionOpInterface" ); |
| 590 | } |
| 591 | // Get reduction dimensions. |
| 592 | // TODO: PartialReductionOpInterface should really query TilingInterface |
| 593 | // itself and find reduction dimensions. |
| 594 | SmallVector<int> reductionDims; |
| 595 | for (auto [idx, iteratorType] : |
| 596 | llvm::enumerate(op.getLoopIteratorTypes())) { |
| 597 | if (iteratorType == utils::IteratorType::reduction) |
| 598 | reductionDims.push_back(idx); |
| 599 | } |
| 600 | return redOp.generateInitialTensorForPartialReduction( |
| 601 | rewriter, loc, tileSizes, reductionDims); |
| 602 | } |
| 603 | default: |
| 604 | return rewriter.notifyMatchFailure(op, |
| 605 | "unhandled reduction tiling strategy" ); |
| 606 | } |
| 607 | } |
| 608 | |
| 609 | static FailureOr<TilingResult> |
| 610 | getTiledImplementation(RewriterBase &rewriter, TilingInterface op, |
| 611 | ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets, |
| 612 | ArrayRef<OpFoldResult> sizes, |
| 613 | const scf::SCFTilingOptions &options) { |
| 614 | switch (options.reductionStrategy) { |
| 615 | case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction: |
| 616 | return op.getTiledImplementation(rewriter, offsets, sizes); |
| 617 | case scf::SCFTilingOptions::ReductionTilingStrategy:: |
| 618 | PartialReductionOuterReduction: { |
| 619 | auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation()); |
| 620 | if (!redOp) { |
| 621 | return rewriter.notifyMatchFailure( |
| 622 | op, "PartialReductionOuterReduction tiling strategy is only " |
| 623 | "supported for operations " |
| 624 | "implementing PartialReductionOpInterface" ); |
| 625 | } |
| 626 | // Get reduction dimensions. |
| 627 | // TODO: PartialReductionOpInterface should really query TilingInterface |
| 628 | // itself and find reduction dimensions. |
| 629 | SmallVector<int> reductionDims; |
| 630 | for (auto [idx, iteratorType] : |
| 631 | llvm::enumerate(op.getLoopIteratorTypes())) { |
| 632 | if (iteratorType == utils::IteratorType::reduction) |
| 633 | reductionDims.push_back(idx); |
| 634 | } |
| 635 | return redOp.tileToPartialReduction(rewriter, op.getLoc(), regionIterArg, |
| 636 | offsets, sizes, reductionDims); |
| 637 | } |
| 638 | default: |
| 639 | return rewriter.notifyMatchFailure(op, |
| 640 | "unhandled reduction tiling strategy" ); |
| 641 | } |
| 642 | } |
| 643 | |
| 644 | static LogicalResult |
| 645 | getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, |
| 646 | TilingInterface op, ArrayRef<OpFoldResult> offsets, |
| 647 | ArrayRef<OpFoldResult> sizes, |
| 648 | SmallVector<OpFoldResult> &resultOffset, |
| 649 | SmallVector<OpFoldResult> &resultSize, |
| 650 | const scf::SCFTilingOptions &options) { |
| 651 | |
| 652 | switch (options.reductionStrategy) { |
| 653 | case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction: |
| 654 | return op.getResultTilePosition(rewriter, index, offsets, sizes, |
| 655 | resultOffset, resultSize); |
| 656 | case scf::SCFTilingOptions::ReductionTilingStrategy:: |
| 657 | PartialReductionOuterReduction: { |
| 658 | auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation()); |
| 659 | if (!redOp) { |
| 660 | return rewriter.notifyMatchFailure( |
| 661 | op, "PartialReductionOuterReduction tiling strategy is only supported" |
| 662 | "for operations implementing PartialReductionOpInterface" ); |
| 663 | } |
| 664 | // Get reduction dimensions. |
| 665 | // TODO: PartialReductionOpInterface should really query TilingInterface |
| 666 | // itself and find reduction dimensions. |
| 667 | SmallVector<int> reductionDims; |
| 668 | for (auto [idx, iteratorType] : |
| 669 | llvm::enumerate(op.getLoopIteratorTypes())) { |
| 670 | if (iteratorType == utils::IteratorType::reduction) |
| 671 | reductionDims.push_back(idx); |
| 672 | } |
| 673 | return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes, |
| 674 | resultOffset, resultSize, |
| 675 | reductionDims); |
| 676 | } |
| 677 | default: |
| 678 | return rewriter.notifyMatchFailure(op, |
| 679 | "unhandled reduction tiling strategy" ); |
| 680 | } |
| 681 | } |
| 682 | |
| 683 | static FailureOr<MergeResult> |
| 684 | mergeTilingResults(RewriterBase &rewriter, TilingInterface op, |
| 685 | ValueRange partialResults, |
| 686 | const scf::SCFTilingOptions &options) { |
| 687 | switch (options.reductionStrategy) { |
| 688 | case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction: |
| 689 | // No need to merge results for reduction tiling strategy. |
| 690 | return MergeResult{.mergeOps: {}, .replacements: partialResults}; |
| 691 | case scf::SCFTilingOptions::ReductionTilingStrategy:: |
| 692 | PartialReductionOuterReduction: { |
| 693 | auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation()); |
| 694 | if (!redOp) { |
| 695 | return rewriter.notifyMatchFailure( |
| 696 | op, "PartialReductionOuterReduction tiling strategy is only " |
| 697 | "supported for operations " |
| 698 | "implementing PartialReductionOpInterface" ); |
| 699 | } |
| 700 | // Get reduction dimensions. |
| 701 | // TODO: PartialReductionOpInterface should really query TilingInterface |
| 702 | // itself and find reduction dimensions. |
| 703 | SmallVector<int> reductionDims; |
| 704 | for (auto [idx, iteratorType] : |
| 705 | llvm::enumerate(op.getLoopIteratorTypes())) { |
| 706 | if (iteratorType == utils::IteratorType::reduction) |
| 707 | reductionDims.push_back(idx); |
| 708 | } |
| 709 | return redOp.mergeReductions(rewriter, op.getLoc(), partialResults, |
| 710 | reductionDims); |
| 711 | } |
| 712 | default: |
| 713 | return rewriter.notifyMatchFailure(op, |
| 714 | "unhandled reduction tiling strategy" ); |
| 715 | } |
| 716 | } |
| 717 | |
| 718 | /// Append the specified additional `newInitOperands` operands to the |
| 719 | /// loops existing `init` operands (or similar), and replace `loopOp` with |
| 720 | /// the new loop that has the additional init operands. The loop body of |
| 721 | /// this loop is moved over to the new loop. `yieldTiledValuesFn` |
| 722 | /// is called to get the new tiled values returned, and the offset |
| 723 | /// and sizes at which the tiled value is inserted into the |
| 724 | /// new region iter_args that correspond to the newly added init operands. |
| 725 | template <typename LoopType> |
| 726 | FailureOr<LoopLikeOpInterface> |
| 727 | yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, |
| 728 | ValueRange newInitOperands, |
| 729 | YieldTiledValuesFn yieldTiledValuesFn) { |
| 730 | return rewriter.notifyMatchFailure(loopOp, "unhandled loop type" ); |
| 731 | } |
| 732 | |
| 733 | /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`. |
| 734 | template <> |
| 735 | FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>( |
| 736 | scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, |
| 737 | YieldTiledValuesFn yieldTiledValuesFn) { |
| 738 | OpBuilder::InsertionGuard g(rewriter); |
| 739 | Location loc = loopOp.getLoc(); |
| 740 | rewriter.setInsertionPoint(loopOp); |
| 741 | |
| 742 | auto inits = llvm::to_vector(loopOp.getInitArgs()); |
| 743 | inits.append(newInitOperands.begin(), newInitOperands.end()); |
| 744 | auto newLoop = rewriter.create<scf::ForOp>( |
| 745 | loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(), |
| 746 | inits, [](OpBuilder &, Location, Value, ValueRange) {}); |
| 747 | |
| 748 | // Move the loop body to the new op. |
| 749 | Block *loopBody = loopOp.getBody(); |
| 750 | Block *newLoopBody = newLoop.getBody(); |
| 751 | rewriter.mergeBlocks( |
| 752 | loopBody, newLoopBody, |
| 753 | newLoopBody->getArguments().take_front(loopBody->getNumArguments())); |
| 754 | |
| 755 | auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator()); |
| 756 | rewriter.setInsertionPoint(yieldOp); |
| 757 | |
| 758 | SmallVector<Value> tiledValues; |
| 759 | SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; |
| 760 | ValueRange newRegionIterArgs = |
| 761 | newLoop.getRegionIterArgs().take_back(newInitOperands.size()); |
| 762 | if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(), |
| 763 | newRegionIterArgs, tiledValues, resultOffsets, |
| 764 | resultSizes))) { |
| 765 | rewriter.eraseOp(newLoop); |
| 766 | return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values" ); |
| 767 | } |
| 768 | |
| 769 | SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands()); |
| 770 | for (auto [tiledValue, regionIterArg, resultOffset, resultSize] : |
| 771 | llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets, |
| 772 | resultSizes)) { |
| 773 | SmallVector<OpFoldResult> resultStride(resultOffset.size(), |
| 774 | rewriter.getIndexAttr(1)); |
| 775 | Value insert = rewriter.create<tensor::InsertSliceOp>( |
| 776 | yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize, |
| 777 | resultStride); |
| 778 | newYieldValues.push_back(insert); |
| 779 | } |
| 780 | |
| 781 | rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues); |
| 782 | rewriter.replaceOp(loopOp, |
| 783 | newLoop->getResults().take_front(loopOp.getNumResults())); |
| 784 | return cast<LoopLikeOpInterface>(newLoop.getOperation()); |
| 785 | } |
| 786 | |
| 787 | /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall` |
| 788 | template <> |
| 789 | FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>( |
| 790 | scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, |
| 791 | YieldTiledValuesFn yieldTiledValuesFn) { |
| 792 | OpBuilder::InsertionGuard g(rewriter); |
| 793 | Location loc = loopOp.getLoc(); |
| 794 | rewriter.setInsertionPoint(loopOp); |
| 795 | auto inits = llvm::to_vector(loopOp.getOutputs()); |
| 796 | inits.append(newInitOperands.begin(), newInitOperands.end()); |
| 797 | auto newLoop = rewriter.create<scf::ForallOp>( |
| 798 | loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(), |
| 799 | loopOp.getMixedStep(), inits, loopOp.getMapping(), |
| 800 | [](OpBuilder &, Location, ValueRange) {}); |
| 801 | |
| 802 | // Move the region of the current block to the newly created op. |
| 803 | Block *loopBody = loopOp.getBody(); |
| 804 | Block *newLoopBody = newLoop.getBody(); |
| 805 | rewriter.mergeBlocks( |
| 806 | loopBody, newLoopBody, |
| 807 | newLoopBody->getArguments().take_front(loopBody->getNumArguments())); |
| 808 | |
| 809 | auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator()); |
| 810 | rewriter.setInsertionPoint(terminator); |
| 811 | SmallVector<Value> tiledValues; |
| 812 | SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; |
| 813 | ValueRange regionIterArgs = |
| 814 | newLoop.getRegionIterArgs().take_back(newInitOperands.size()); |
| 815 | if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(), |
| 816 | regionIterArgs, tiledValues, resultOffsets, |
| 817 | resultSizes))) { |
| 818 | rewriter.eraseOp(newLoop); |
| 819 | return rewriter.notifyMatchFailure(loopOp, |
| 820 | "failed to get yielded tiled values" ); |
| 821 | } |
| 822 | |
| 823 | // Update the terminator. |
| 824 | rewriter.setInsertionPointToEnd(terminator.getBody()); |
| 825 | |
| 826 | for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal( |
| 827 | tiledValues, regionIterArgs, resultOffsets, resultSizes)) { |
| 828 | SmallVector<OpFoldResult> resultStride(resultOffset.size(), |
| 829 | rewriter.getIndexAttr(1)); |
| 830 | rewriter.create<tensor::ParallelInsertSliceOp>( |
| 831 | terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize, |
| 832 | resultStride); |
| 833 | } |
| 834 | |
| 835 | rewriter.replaceOp(loopOp, |
| 836 | newLoop->getResults().take_front(loopOp.getNumResults())); |
| 837 | return cast<LoopLikeOpInterface>(newLoop.getOperation()); |
| 838 | } |
| 839 | |
| 840 | /// Implementation of `yieldTiledValuesAndReplaceLoop` for |
| 841 | /// `LoopLikeOpInterface`, that just dispatches to the implementation for each |
| 842 | /// supported loop type. |
| 843 | FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop( |
| 844 | LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter, |
| 845 | ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) { |
| 846 | return TypeSwitch<Operation *, FailureOr<LoopLikeOpInterface>>( |
| 847 | loopLikeOp.getOperation()) |
| 848 | .Case<scf::ForOp, scf::ForallOp>( |
| 849 | [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> { |
| 850 | return yieldTiledValuesAndReplaceLoop( |
| 851 | loopOp, rewriter, newInitOperands, yieldTiledValuesFn); |
| 852 | }) |
| 853 | .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> { |
| 854 | return rewriter.notifyMatchFailure(loopOp, "unhandled loop type" ); |
| 855 | }); |
| 856 | } |
| 857 | |
| 858 | /// Method to add new init values to a loop nest. Updates `loops` in-place |
| 859 | /// with new loops that use the `newInitValues`. The outer-loops are updated |
| 860 | /// to yield the new result values of the inner loop. For the innermost loop, |
| 861 | /// the call back `getNewYields` is invoked to get the additional values to |
| 862 | /// yield form the innermost loop. |
| 863 | static LogicalResult addInitOperandsToLoopNest( |
| 864 | RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops, |
| 865 | ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) { |
| 866 | if (loops.empty()) |
| 867 | return success(); |
| 868 | OpBuilder::InsertionGuard g(rewriter); |
| 869 | rewriter.setInsertionPoint(loops.front()); |
| 870 | |
| 871 | SmallVector<Value> ivs; |
| 872 | for (auto &loop : loops.drop_back()) { |
| 873 | rewriter.setInsertionPoint(loop); |
| 874 | |
| 875 | // if loops.size() > 1 we assume that scf.for is used for the loops. |
| 876 | auto forLoop = cast<scf::ForOp>(loop.getOperation()); |
| 877 | |
| 878 | // Create a new loop with the new init values for this loop. |
| 879 | SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs()); |
| 880 | newInits.append(newInitValues.begin(), newInitValues.end()); |
| 881 | auto newLoop = rewriter.create<scf::ForOp>( |
| 882 | forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(), |
| 883 | forLoop.getStep(), newInits, |
| 884 | [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {}); |
| 885 | |
| 886 | // Merge the body of the new loop with the body of the old loops. |
| 887 | SmallVector<Value> sourceBlockArgs; |
| 888 | sourceBlockArgs.push_back(newLoop.getInductionVar()); |
| 889 | auto newRegionIterArgs = newLoop.getRegionIterArgs(); |
| 890 | sourceBlockArgs.append( |
| 891 | newRegionIterArgs.begin(), |
| 892 | std::next(newRegionIterArgs.begin(), forLoop.getNumResults())); |
| 893 | rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs); |
| 894 | rewriter.replaceOp( |
| 895 | forLoop, newLoop.getResults().take_front(forLoop.getNumResults())); |
| 896 | loop = newLoop; |
| 897 | ivs.push_back(newLoop.getInductionVar()); |
| 898 | newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size()); |
| 899 | } |
| 900 | |
| 901 | // Update the loop body of the innermost loop to get new yield values. |
| 902 | LoopLikeOpInterface innerMostLoop = loops.back(); |
| 903 | FailureOr<LoopLikeOpInterface> newInnerMostLoop = |
| 904 | yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues, |
| 905 | getNewTiledYieldsFn); |
| 906 | |
| 907 | if (failed(newInnerMostLoop)) |
| 908 | return innerMostLoop.emitOpError("failed to return additional yields" ); |
| 909 | loops.back() = newInnerMostLoop.value(); |
| 910 | |
| 911 | // Make all other loops except the innermost loops yield the values returned |
| 912 | // by the inner loop. |
| 913 | for (auto [outerLoop, innerLoop] : |
| 914 | llvm::zip_equal(loops.drop_back(), loops.drop_front())) { |
| 915 | // Again assume that all the outer loops are scf.for operations. |
| 916 | auto outerForLoop = cast<scf::ForOp>(outerLoop); |
| 917 | auto outerLoopYield = |
| 918 | cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator()); |
| 919 | SmallVector<Value> newYields = |
| 920 | llvm::to_vector(outerLoopYield.getOperands()); |
| 921 | ValueRange additionalYields = |
| 922 | innerLoop->getResults().take_back(newInitValues.size()); |
| 923 | newYields.append(additionalYields.begin(), additionalYields.end()); |
| 924 | rewriter.setInsertionPoint(outerLoopYield); |
| 925 | rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields); |
| 926 | } |
| 927 | return success(); |
| 928 | } |
| 929 | |
| 930 | /// Implementation of tiling transformation of `op` that implements the |
| 931 | /// `TilingInterface` using `scf.for` to iterate over the tiles. |
| 932 | FailureOr<scf::SCFTilingResult> |
| 933 | mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, |
| 934 | const scf::SCFTilingOptions &options) { |
| 935 | if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) { |
| 936 | return failure(); |
| 937 | } |
| 938 | |
| 939 | OpBuilder::InsertionGuard guard(rewriter); |
| 940 | rewriter.setInsertionPointAfter(op); |
| 941 | |
| 942 | // 1. Get the range of the loops that are represented by the operation. |
| 943 | SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); |
| 944 | |
| 945 | // 2. Materialize the tile sizes and/or number of threads; |
| 946 | SmallVector<OpFoldResult> tileSizes, numThreads; |
| 947 | std::tie(args&: tileSizes, args&: numThreads) = |
| 948 | getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options); |
| 949 | |
| 950 | // Check if it is safe to tile. This is hold over from previous iterations |
| 951 | // of tile to for-all. Consider dropping it. |
| 952 | if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) { |
| 953 | checkSafeToTileToForall(op, tileSizes, numThreads); |
| 954 | } |
| 955 | |
| 956 | // 3. If there is an interchange specified, permute the iteration domain and |
| 957 | // the tile sizes. |
| 958 | SmallVector<int64_t> interchangeVector; |
| 959 | if (!options.interchangeVector.empty()) { |
| 960 | interchangeVector = fillInterchangeVector(interchangeVector: options.interchangeVector, |
| 961 | iterationDomainSize: iterationDomain.size()); |
| 962 | assert(isPermutationVector(interchangeVector) && |
| 963 | "expected interchange vector to be a permutation" ); |
| 964 | |
| 965 | applyPermutationToVector(inVec&: iterationDomain, permutation: interchangeVector); |
| 966 | applyPermutationToVector(inVec&: tileSizes, permutation: interchangeVector); |
| 967 | if (!numThreads.empty()) |
| 968 | applyPermutationToVector(inVec&: numThreads, permutation: interchangeVector); |
| 969 | } |
| 970 | |
| 971 | FailureOr<TilingResult> tilingResult; |
| 972 | // 4. Define the lambda function used later to generate the body of the |
| 973 | // innermost tiled loop. |
| 974 | YieldTiledValuesFn innerYieldTiledValuesFn = |
| 975 | [&](RewriterBase &rewriter, Location loc, ValueRange ivs, |
| 976 | ValueRange regionIterArgs, SmallVector<Value> &tiledResults, |
| 977 | SmallVector<SmallVector<OpFoldResult>> &resultOffsets, |
| 978 | SmallVector<SmallVector<OpFoldResult>> &resultSizes) |
| 979 | -> LogicalResult { |
| 980 | // 4a. Compute the `offsets` and `sizes` to use for tiling. |
| 981 | SmallVector<OpFoldResult> offsets, sizes; |
| 982 | std::tie(args&: offsets, args&: sizes) = getTileOffsetAndSizes( |
| 983 | rewriter, loc, ivs, iterationDomain, tileSizes, numThreads); |
| 984 | |
| 985 | // 4b. If interchange was provided, apply inverse of the interchange |
| 986 | // to get back the offsets/sizes in the order to be specified. |
| 987 | if (!interchangeVector.empty()) { |
| 988 | auto inversePermutation = invertPermutationVector(permutation: interchangeVector); |
| 989 | applyPermutationToVector(inVec&: offsets, permutation: inversePermutation); |
| 990 | applyPermutationToVector(inVec&: sizes, permutation: inversePermutation); |
| 991 | } |
| 992 | |
| 993 | // 5. Generate the tiled implementation within the inner most loop. |
| 994 | |
| 995 | // 5a. Clone the operation within the loop body. |
| 996 | auto clonedOp = cast<TilingInterface>( |
| 997 | cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs)); |
| 998 | |
| 999 | // 5b. Early return cloned op if tiling is not happening. We can not |
| 1000 | // return the original op because it could lead to `rewriter.replaceOp(op, |
| 1001 | // op->getResults())` and users would get crash. |
| 1002 | if (llvm::all_of(Range&: tileSizes, P: isZeroInteger)) { |
| 1003 | tiledResults.append(clonedOp->result_begin(), clonedOp->result_end()); |
| 1004 | tilingResult = |
| 1005 | TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(), |
| 1006 | /*generatedSlices=*/{}}; |
| 1007 | return success(); |
| 1008 | } |
| 1009 | |
| 1010 | // 5c. Tile the cloned operation. |
| 1011 | tilingResult = getTiledImplementation(rewriter, clonedOp, regionIterArgs, |
| 1012 | offsets, sizes, options); |
| 1013 | if (failed(Result: tilingResult)) { |
| 1014 | rewriter.eraseOp(op: clonedOp); |
| 1015 | return op.emitOpError("faild to tile operation" ); |
| 1016 | } |
| 1017 | |
| 1018 | // 5d. Delete the cloned operation. |
| 1019 | rewriter.eraseOp(op: clonedOp); |
| 1020 | |
| 1021 | // 5e. Compute the offsets at which the result values are to be inserted |
| 1022 | // back into its destinations. |
| 1023 | for (auto [index, tiledValue] : |
| 1024 | llvm::enumerate(First&: tilingResult->tiledValues)) { |
| 1025 | tiledResults.push_back(Elt: tiledValue); |
| 1026 | SmallVector<OpFoldResult> resultOffset, resultSize; |
| 1027 | if (failed(getResultTilePosition(rewriter, index, tiledValue, op, offsets, |
| 1028 | sizes, resultOffset, resultSize, |
| 1029 | options))) { |
| 1030 | for (auto op : tilingResult->tiledOps) { |
| 1031 | rewriter.eraseOp(op); |
| 1032 | } |
| 1033 | return rewriter.notifyMatchFailure( |
| 1034 | op, "failed to get slice of result produced" ); |
| 1035 | } |
| 1036 | resultOffsets.emplace_back(Args: std::move(resultOffset)); |
| 1037 | resultSizes.emplace_back(Args: std::move(resultSize)); |
| 1038 | } |
| 1039 | |
| 1040 | return success(); |
| 1041 | }; |
| 1042 | |
| 1043 | // 6. Find the destination tensors to use for the operation. |
| 1044 | FailureOr<SmallVector<Value>> maybeInits = |
| 1045 | createInitialTensorsForTiling(rewriter, op, tileSizes, options); |
| 1046 | if (failed(Result: maybeInits)) { |
| 1047 | return rewriter.notifyMatchFailure( |
| 1048 | op, "unable to create initial tensors for tiling" ); |
| 1049 | } |
| 1050 | SmallVector<Value> &initTensors = maybeInits.value(); |
| 1051 | |
| 1052 | // 7. Generate the tiled loops nest using the callback defined above. |
| 1053 | SmallVector<LoopLikeOpInterface> loops; |
| 1054 | if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain, |
| 1055 | tileSizes, numThreads, initTensors, |
| 1056 | innerYieldTiledValuesFn, loops))) |
| 1057 | return op.emitOpError("failed to generate tiling loops" ); |
| 1058 | assert(succeeded(tilingResult) && |
| 1059 | "expected tiling result to be computed after loop generation" ); |
| 1060 | |
| 1061 | if (loops.empty()) { |
| 1062 | // If loops are empty, the tiled op is used as the replacement for the |
| 1063 | // untiled op. |
| 1064 | return scf::SCFTilingResult{tilingResult->tiledOps, |
| 1065 | initTensors, |
| 1066 | loops, |
| 1067 | tilingResult->tiledValues, |
| 1068 | tilingResult->generatedSlices, |
| 1069 | {}}; |
| 1070 | } |
| 1071 | |
| 1072 | auto loopResults = llvm::map_to_vector(loops.front()->getResults(), |
| 1073 | [](OpResult r) -> Value { return r; }); |
| 1074 | |
| 1075 | // For the full reduction case, there is nothing more to do. |
| 1076 | if (options.reductionStrategy == |
| 1077 | scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction) { |
| 1078 | return scf::SCFTilingResult{ |
| 1079 | tilingResult->tiledOps, initTensors, loops, loopResults, |
| 1080 | tilingResult->generatedSlices, {}}; |
| 1081 | } |
| 1082 | |
| 1083 | // The results of the loop needs to be merged. |
| 1084 | FailureOr<MergeResult> mergeResult = |
| 1085 | mergeTilingResults(rewriter, op, loopResults, options); |
| 1086 | if (failed(Result: mergeResult)) { |
| 1087 | return rewriter.notifyMatchFailure( |
| 1088 | op, "Failed to merge partial results from tiling" ); |
| 1089 | } |
| 1090 | return scf::SCFTilingResult{tilingResult->tiledOps, |
| 1091 | initTensors, |
| 1092 | loops, |
| 1093 | mergeResult->replacements, |
| 1094 | tilingResult->generatedSlices, |
| 1095 | mergeResult->mergeOps}; |
| 1096 | } |
| 1097 | |
| 1098 | FailureOr<scf::SCFTilingResult> |
| 1099 | mlir::scf::tileReductionUsingScf(RewriterBase &b, |
| 1100 | PartialReductionOpInterface op, |
| 1101 | ArrayRef<OpFoldResult> tileSize) { |
| 1102 | scf::SCFTilingOptions options; |
| 1103 | options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp); |
| 1104 | options.setReductionTilingStrategy( |
| 1105 | scf::SCFTilingOptions::ReductionTilingStrategy:: |
| 1106 | PartialReductionOuterReduction); |
| 1107 | options.setTileSizes(tileSize); |
| 1108 | return tileUsingSCF(b, op, options); |
| 1109 | } |
| 1110 | |
| 1111 | //===----------------------------------------------------------------------===// |
| 1112 | // tileConsumerAndFuseProducersUsingSCF implementation. |
| 1113 | //===----------------------------------------------------------------------===// |
| 1114 | |
| 1115 | /// Return the untiled producer whose slice is used in a tiled consumer. The |
| 1116 | /// method traverses the tile loop nest (`loops`) if needed, and returns the |
| 1117 | /// `iter_args` of the outer most that is encountered. Traversing the |
| 1118 | /// iter_args indicates that this is a destination operand of the consumer. If |
| 1119 | /// there was no loop traversal needed, the second value of the returned tuple |
| 1120 | /// is empty. |
| 1121 | static std::tuple<OpResult, std::optional<OpOperand *>> |
| 1122 | getUntiledProducerFromSliceSource(OpOperand *source, |
| 1123 | ArrayRef<LoopLikeOpInterface> loops) { |
| 1124 | std::optional<OpOperand *> destinationIterArg; |
| 1125 | assert(!loops.empty() && "expected non empty loops container" ); |
| 1126 | auto loopIt = loops.rbegin(); |
| 1127 | while (loopIt != loops.rend() && isa<BlockArgument>(Val: source->get())) { |
| 1128 | auto iterArg = cast<BlockArgument>(Val: source->get()); |
| 1129 | auto loop = *loopIt; |
| 1130 | if (iterArg.getOwner()->getParentOp() != loop) |
| 1131 | break; |
| 1132 | source = loop.getTiedLoopInit(iterArg); |
| 1133 | loopIt++; |
| 1134 | } |
| 1135 | if (loopIt == loops.rend()) |
| 1136 | destinationIterArg = source; |
| 1137 | return {dyn_cast<OpResult>(Val: source->get()), destinationIterArg}; |
| 1138 | } |
| 1139 | |
| 1140 | /// Implementation of fusing producer of a single slice by computing the |
| 1141 | /// slice of the producer in-place. |
| 1142 | std::optional<scf::SCFFuseProducerOfSliceResult> |
| 1143 | mlir::scf::tileAndFuseProducerOfSlice( |
| 1144 | RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, |
| 1145 | MutableArrayRef<LoopLikeOpInterface> loops) { |
| 1146 | // 1. Get the producer of the source (potentially walking through |
| 1147 | // `iter_args` of nested `scf.for`) |
| 1148 | auto [fusableProducer, destinationInitArg] = |
| 1149 | getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(), |
| 1150 | loops); |
| 1151 | if (!fusableProducer) |
| 1152 | return std::nullopt; |
| 1153 | unsigned resultNumber = fusableProducer.getResultNumber(); |
| 1154 | |
| 1155 | OpBuilder::InsertionGuard g(rewriter); |
| 1156 | rewriter.setInsertionPoint(candidateSliceOp); |
| 1157 | |
| 1158 | // 2. Clone the fused producer |
| 1159 | // 2a. Compute the destination operands to use for the cloned operation. |
| 1160 | SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors; |
| 1161 | Operation *fusableProducerOp = fusableProducer.getOwner(); |
| 1162 | if (isa<DestinationStyleOpInterface>(fusableProducerOp) && |
| 1163 | failed(tensor::getOrCreateDestinations( |
| 1164 | rewriter, fusableProducerOp->getLoc(), fusableProducerOp, |
| 1165 | origDestinationTensors))) |
| 1166 | return std::nullopt; |
| 1167 | |
| 1168 | clonedOpDestinationTensors = origDestinationTensors; |
| 1169 | if (destinationInitArg && |
| 1170 | isa<DestinationStyleOpInterface>(fusableProducerOp)) { |
| 1171 | // 2b. If the producer is also destination style, then to maintain the |
| 1172 | // destination passing style, update the destination of the producer to be |
| 1173 | // the source of the slice. |
| 1174 | clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource(); |
| 1175 | } |
| 1176 | // 2c. Clone the fused producer. |
| 1177 | Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs( |
| 1178 | rewriter, op: fusableProducerOp, newDestArgs: clonedOpDestinationTensors); |
| 1179 | // 2d. Update the source of the candidateSlice to be the cloned producer. |
| 1180 | // Easier to just clone the slice with different source since |
| 1181 | // replacements and DCE of cloned ops becomes easier |
| 1182 | SmallVector<Value> candidateSliceOpOperands = |
| 1183 | llvm::to_vector(candidateSliceOp->getOperands()); |
| 1184 | candidateSliceOpOperands[0] = clonedProducerOp->getResult(idx: resultNumber); |
| 1185 | tensor::ExtractSliceOp clonedCandidateSliceOp = |
| 1186 | mlir::clone(rewriter, candidateSliceOp, |
| 1187 | candidateSliceOp->getResultTypes(), candidateSliceOpOperands); |
| 1188 | |
| 1189 | // 3. Generate the tiled implementation of the producer of the source |
| 1190 | FailureOr<TilingResult> tileAndFuseResult = |
| 1191 | tensor::replaceExtractSliceWithTiledProducer( |
| 1192 | rewriter, clonedCandidateSliceOp, |
| 1193 | clonedProducerOp->getResult(idx: resultNumber)); |
| 1194 | if (failed(Result: tileAndFuseResult)) |
| 1195 | return std::nullopt; |
| 1196 | // Note: Do not delete the candidateSliceOp, since its passed in from the |
| 1197 | // caller. |
| 1198 | rewriter.replaceAllUsesWith(candidateSliceOp, |
| 1199 | tileAndFuseResult->tiledValues[0]); |
| 1200 | rewriter.eraseOp(op: clonedCandidateSliceOp); |
| 1201 | rewriter.eraseOp(op: clonedProducerOp); |
| 1202 | |
| 1203 | // 3. If the slice is for a destination operand, for example, |
| 1204 | // |
| 1205 | // ```mlir |
| 1206 | // %0 = linalg.init |
| 1207 | // %1 = linalg.fill .. outs(%0 : ) |
| 1208 | // %2 = scf.for .. iter_args(%arg0 = %1) { |
| 1209 | // %3 = scf.for .. iter_args(%arg1 = %arg0) { |
| 1210 | // %4 = tensor.extract_slice %arg1 [..] |
| 1211 | // .. = linalg.matmul .. outs(%4 : ) |
| 1212 | // } |
| 1213 | // } |
| 1214 | // ``` |
| 1215 | // |
| 1216 | // the IR is currently |
| 1217 | // |
| 1218 | // ``` |
| 1219 | // %0 = linalg.init |
| 1220 | // %1 = linalg.fill |
| 1221 | // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) { |
| 1222 | // %3 = scf.for .. iter_args(%arg1 = %arg0) { |
| 1223 | // %4 = tensor.extract_slice %arg1[..] |
| 1224 | // %5 = linalg.fill .. outs(%4 : ) |
| 1225 | // .. = linalg.matmul .. outs(%5 : ) |
| 1226 | // } |
| 1227 | // } |
| 1228 | // ``` |
| 1229 | // |
| 1230 | // The untiled `linalg.fill` is still used as the `init_value` since it |
| 1231 | // was originally a destination operand of the untiled `linalg.matmul`. |
| 1232 | // When fusing an operand that is a destination operand, the iter_arg of |
| 1233 | // the outer most loop should be changed to use the destination of the |
| 1234 | // fused operation. With this the IR will be. |
| 1235 | // |
| 1236 | // ``` |
| 1237 | // %0 = linalg.init |
| 1238 | // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) { |
| 1239 | // %2 = scf.for .. iter_args(%arg1 = %arg0) { |
| 1240 | // %3 = tensor.extract_slice %arg1[..] |
| 1241 | // %4 = linalg.fill .. outs(%3 : ) |
| 1242 | // .. = linalg.matmul .. outs(%4 : ) |
| 1243 | // } |
| 1244 | // } |
| 1245 | // ``` |
| 1246 | if (destinationInitArg && |
| 1247 | isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) { |
| 1248 | loops.front() |
| 1249 | ->getOpOperands()[destinationInitArg.value()->getOperandNumber()] |
| 1250 | .set(origDestinationTensors[resultNumber]); |
| 1251 | } |
| 1252 | return scf::SCFFuseProducerOfSliceResult{ |
| 1253 | fusableProducer, tileAndFuseResult->tiledValues[0], |
| 1254 | tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices}; |
| 1255 | } |
| 1256 | |
| 1257 | /// Reconstruct the fused producer from within the tiled-and-fused code. |
| 1258 | FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer( |
| 1259 | RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, |
| 1260 | scf::SCFFuseProducerOfSliceResult fusedProducerInfo, |
| 1261 | MutableArrayRef<LoopLikeOpInterface> loops, |
| 1262 | ArrayRef<unsigned> yieldResultNumber) { |
| 1263 | if (loops.empty()) |
| 1264 | return success(); |
| 1265 | |
| 1266 | Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(), |
| 1267 | *tiledOwner = fusedProducerInfo.tiledOps[0]; |
| 1268 | |
| 1269 | Location loc = originalOwner->getLoc(); |
| 1270 | // a. collect all init Value to be appended |
| 1271 | SmallVector<unsigned> initNumberList = |
| 1272 | yieldResultNumber.empty() ? llvm::to_vector(Range: llvm::seq<unsigned>( |
| 1273 | Begin: 0, End: originalOwner->getNumResults())) |
| 1274 | : llvm::to_vector(Range&: yieldResultNumber); |
| 1275 | SmallVector<Value> initValueList; |
| 1276 | for (const auto &resultNumber : initNumberList) { |
| 1277 | FailureOr<Value> initValue = tensor::getOrCreateDestination( |
| 1278 | b&: rewriter, loc, opResult: originalOwner->getResult(idx: resultNumber)); |
| 1279 | if (succeeded(Result: initValue)) { |
| 1280 | initValueList.push_back(Elt: initValue.value()); |
| 1281 | } else { |
| 1282 | return failure(); |
| 1283 | } |
| 1284 | } |
| 1285 | |
| 1286 | SmallVector<Operation *> generatedSlices; |
| 1287 | YieldTiledValuesFn newYieldValuesFn = |
| 1288 | [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/, |
| 1289 | ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult, |
| 1290 | SmallVector<SmallVector<OpFoldResult>> &tiledOffset, |
| 1291 | SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult { |
| 1292 | OpBuilder::InsertionGuard g(innerRewriter); |
| 1293 | |
| 1294 | // get sliceOp tile information |
| 1295 | SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(), |
| 1296 | sliceSizes = sliceOp.getMixedSizes(); |
| 1297 | |
| 1298 | // expect all strides of sliceOp being 1 |
| 1299 | if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger)) |
| 1300 | return failure(); |
| 1301 | |
| 1302 | unsigned sliceResultNumber = |
| 1303 | fusedProducerInfo.origProducer.getResultNumber(); |
| 1304 | |
| 1305 | auto tilableOp = cast<TilingInterface>(originalOwner); |
| 1306 | // b. get iterDomain Offset and Sizes based on sliceOp tile |
| 1307 | SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes; |
| 1308 | // skip tensor.pack/unpack/pad, which expects single opResult |
| 1309 | if (tilableOp->getNumResults() > 1 && |
| 1310 | failed(tilableOp.getIterationDomainTileFromResultTile( |
| 1311 | rewriter, sliceResultNumber, sliceOffset, sliceSizes, |
| 1312 | iterDomainOffset, iterDomainSizes))) { |
| 1313 | // In theory, it is unnecessary to raise an error here. Actually |
| 1314 | // although it fails to reconstruct the result tensor, it should not |
| 1315 | // broke current fusion anyway. The reason why we must return failure |
| 1316 | // currently is that the callback function `newYieldValuesFn` will be |
| 1317 | // called after new init operand(s) has already been appended. It will |
| 1318 | // take more refactoring to make sure the init operands are added |
| 1319 | // consistently in the future. For more details, please refer to: |
| 1320 | // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814 |
| 1321 | return failure(); |
| 1322 | } |
| 1323 | |
| 1324 | // c. calculate offsets and sizes info of all OpResults respectively based |
| 1325 | // on iteration Domain Tile |
| 1326 | SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList; |
| 1327 | for (const auto &resultNumber : initNumberList) { |
| 1328 | if (resultNumber == sliceResultNumber) { |
| 1329 | offsetList.push_back(Elt: sliceOffset); |
| 1330 | sizesList.push_back(Elt: sliceSizes); |
| 1331 | } else { |
| 1332 | assert(!iterDomainOffset.empty() && !iterDomainSizes.empty()); |
| 1333 | // infer result tile according to the iteration domain tile |
| 1334 | SmallVector<OpFoldResult> offset, sizes; |
| 1335 | if (failed(tilableOp.getResultTilePosition( |
| 1336 | rewriter, resultNumber, iterDomainOffset, iterDomainSizes, |
| 1337 | offset, sizes))) { |
| 1338 | return failure(); |
| 1339 | } |
| 1340 | offsetList.push_back(Elt: offset); |
| 1341 | sizesList.push_back(Elt: sizes); |
| 1342 | } |
| 1343 | } |
| 1344 | |
| 1345 | // d. create `extract_slice` for `iter_args` for DPS operation if |
| 1346 | // necessary |
| 1347 | if (auto tiledDestStyleOp = |
| 1348 | dyn_cast<DestinationStyleOpInterface>(tiledOwner)) { |
| 1349 | rewriter.setInsertionPoint(tiledDestStyleOp); |
| 1350 | for (const auto &&[index, newRegionArg] : |
| 1351 | llvm::enumerate(First&: newRegionIterArgs)) { |
| 1352 | auto destSlice = rewriter.create<tensor::ExtractSliceOp>( |
| 1353 | loc, newRegionArg, offsetList[index], sizesList[index], |
| 1354 | SmallVector<OpFoldResult>(offsetList[index].size(), |
| 1355 | rewriter.getIndexAttr(1))); |
| 1356 | generatedSlices.push_back(Elt: destSlice); |
| 1357 | unsigned resultNumber = initNumberList[index]; |
| 1358 | rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() { |
| 1359 | tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice); |
| 1360 | }); |
| 1361 | } |
| 1362 | } |
| 1363 | |
| 1364 | // e. prepare tiled offset and sizes for later `insert_slice` creation by |
| 1365 | // caller |
| 1366 | Block *block = rewriter.getInsertionPoint()->getBlock(); |
| 1367 | rewriter.setInsertionPoint(block->getTerminator()); |
| 1368 | for (const auto &&[index, resultNumber] : llvm::enumerate(First&: initNumberList)) { |
| 1369 | tiledResult.push_back(Elt: tiledOwner->getResult(idx: resultNumber)); |
| 1370 | tiledOffset.emplace_back(Args&: offsetList[index]); |
| 1371 | tiledSizes.emplace_back(Args&: sizesList[index]); |
| 1372 | } |
| 1373 | return success(); |
| 1374 | }; |
| 1375 | |
| 1376 | if (failed(addInitOperandsToLoopNest(rewriter, loops, initValueList, |
| 1377 | newYieldValuesFn))) { |
| 1378 | return failure(); |
| 1379 | } |
| 1380 | return generatedSlices; |
| 1381 | } |
| 1382 | |
| 1383 | namespace { |
| 1384 | |
| 1385 | //===----------------------------------------------------------------------===// |
| 1386 | // SliceTrackingListener |
| 1387 | //===----------------------------------------------------------------------===// |
| 1388 | |
| 1389 | /// This class is a listener for tracking the insertion and removal of |
| 1390 | /// `tensor.extract_slice` ops in a worklist. This can be used in a greedy |
| 1391 | /// fusion algorithm to apply cleanup patterns in between fusion steps. |
| 1392 | class SliceTrackingListener : public RewriterBase::Listener { |
| 1393 | public: |
| 1394 | explicit SliceTrackingListener( |
| 1395 | std::optional<FrozenRewritePatternSet> patterns); |
| 1396 | SliceTrackingListener() = default; |
| 1397 | |
| 1398 | /// Adds the given list of operations to the worklist, and if present, |
| 1399 | /// applies the list of `patterns` to the newly added operations. This only |
| 1400 | /// processes the given operations and any newly inserted ones by the |
| 1401 | /// pattern set. |
| 1402 | LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps); |
| 1403 | |
| 1404 | /// Add to the new operation worklist if it is an extract_slice. |
| 1405 | void notifyOperationInserted(Operation *op, |
| 1406 | OpBuilder::InsertPoint previous) override; |
| 1407 | |
| 1408 | /// Shared helper for operation removal from the worklist. |
| 1409 | void removeOp(Operation *op); |
| 1410 | |
| 1411 | /// Remove the operation from the worklist. |
| 1412 | void notifyOperationErased(Operation *op) override; |
| 1413 | |
| 1414 | /// Remove the operation from the worklist. |
| 1415 | void notifyOperationReplaced(Operation *op, ValueRange replacement) override; |
| 1416 | |
| 1417 | /// The worklist for this transformation keeps track of the slices to visit |
| 1418 | /// next for fusion. |
| 1419 | std::deque<tensor::ExtractSliceOp> worklist; |
| 1420 | |
| 1421 | private: |
| 1422 | /// Optional pattern set to apply when adding new operations to the |
| 1423 | /// worklist. |
| 1424 | std::optional<FrozenRewritePatternSet> patterns = std::nullopt; |
| 1425 | }; |
| 1426 | |
| 1427 | SliceTrackingListener::SliceTrackingListener( |
| 1428 | std::optional<FrozenRewritePatternSet> p) { |
| 1429 | patterns = std::move(p); |
| 1430 | } |
| 1431 | |
| 1432 | LogicalResult |
| 1433 | SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) { |
| 1434 | for (Operation *op : ops) { |
| 1435 | if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op)) |
| 1436 | worklist.push_back(slice); |
| 1437 | } |
| 1438 | |
| 1439 | if (!patterns) |
| 1440 | return success(); |
| 1441 | |
| 1442 | return applyOpPatternsGreedily( |
| 1443 | ops, patterns: patterns.value(), |
| 1444 | config: GreedyRewriteConfig().setListener(this).setStrictness( |
| 1445 | GreedyRewriteStrictness::ExistingAndNewOps)); |
| 1446 | } |
| 1447 | |
| 1448 | void SliceTrackingListener::notifyOperationInserted( |
| 1449 | Operation *op, OpBuilder::InsertPoint previous) { |
| 1450 | auto slice = dyn_cast<tensor::ExtractSliceOp>(op); |
| 1451 | if (!slice) |
| 1452 | return; |
| 1453 | worklist.push_back(slice); |
| 1454 | } |
| 1455 | |
| 1456 | // Scan the worklist for the given op and remove it if present. The |
| 1457 | // expectation is for the worklist to be small and for removal to be |
| 1458 | // relatively rare. |
| 1459 | void SliceTrackingListener::removeOp(Operation *op) { |
| 1460 | if (!isa<tensor::ExtractSliceOp>(op)) |
| 1461 | return; |
| 1462 | auto iter = worklist.begin(); |
| 1463 | while (iter != worklist.end()) { |
| 1464 | if (*iter == op) |
| 1465 | break; |
| 1466 | iter++; |
| 1467 | } |
| 1468 | if (iter == worklist.end()) |
| 1469 | return; |
| 1470 | |
| 1471 | worklist.erase(iter); |
| 1472 | } |
| 1473 | |
| 1474 | void SliceTrackingListener::notifyOperationErased(Operation *op) { |
| 1475 | removeOp(op); |
| 1476 | } |
| 1477 | |
| 1478 | void SliceTrackingListener::notifyOperationReplaced(Operation *op, |
| 1479 | ValueRange replacement) { |
| 1480 | removeOp(op); |
| 1481 | } |
| 1482 | |
| 1483 | //===----------------------------------------------------------------------===// |
| 1484 | // ReplacementListener |
| 1485 | //===----------------------------------------------------------------------===// |
| 1486 | |
| 1487 | /// Listener that tracks updates replacements for values which can be mutated. |
| 1488 | /// This listener runs on top of the existing listener for the rewriter, |
| 1489 | /// to make sure external users can still run listeners. |
| 1490 | class ReplacementListener : public RewriterBase::ForwardingListener { |
| 1491 | public: |
| 1492 | ReplacementListener(DenseMap<Value, Value> &replacements, |
| 1493 | OpBuilder::Listener *listener) |
| 1494 | : ForwardingListener(listener), replacements(replacements) {} |
| 1495 | |
| 1496 | void updateReplacementValues(ValueRange origValues, |
| 1497 | ValueRange replaceValues) { |
| 1498 | // This can probably be written better, but just iterates over the map |
| 1499 | // and the new replacements for now. |
| 1500 | for (auto &[key, val] : replacements) { |
| 1501 | for (auto [orig, replace] : llvm::zip_equal(t&: origValues, u&: replaceValues)) { |
| 1502 | if (val == orig) { |
| 1503 | val = replace; |
| 1504 | } |
| 1505 | } |
| 1506 | } |
| 1507 | } |
| 1508 | |
| 1509 | void notifyOperationReplaced(Operation *op, Operation *newOp) override { |
| 1510 | ForwardingListener::notifyOperationReplaced(op, newOp); |
| 1511 | updateReplacementValues(origValues: op->getResults(), replaceValues: newOp->getResults()); |
| 1512 | } |
| 1513 | |
| 1514 | void notifyOperationReplaced(Operation *op, ValueRange values) override { |
| 1515 | ForwardingListener::notifyOperationReplaced(op, replacement: values); |
| 1516 | updateReplacementValues(origValues: op->getResults(), replaceValues: values); |
| 1517 | } |
| 1518 | |
| 1519 | private: |
| 1520 | DenseMap<Value, Value> &replacements; |
| 1521 | }; |
| 1522 | |
| 1523 | } // namespace |
| 1524 | |
| 1525 | /// Implementation of tile consumer and fuse producer greedily. |
| 1526 | FailureOr<scf::SCFTileAndFuseResult> |
| 1527 | mlir::scf::tileConsumerAndFuseProducersUsingSCF( |
| 1528 | RewriterBase &rewriter, TilingInterface consumer, |
| 1529 | const scf::SCFTileAndFuseOptions &options) { |
| 1530 | // This transformation is only valid for ops that return values (i.e. not |
| 1531 | // valid to use with operations that have memref operands). |
| 1532 | if (!consumer->getNumResults()) { |
| 1533 | return rewriter.notifyMatchFailure( |
| 1534 | consumer, "invalid pattern for op with no results" ); |
| 1535 | } |
| 1536 | |
| 1537 | // 1. First tile the consumer. |
| 1538 | SetVector<Operation *> fusedProducers, tiledAndFusedOps; |
| 1539 | |
| 1540 | FailureOr<scf::SCFTilingResult> tilingResult = |
| 1541 | tileUsingSCF(rewriter, consumer, options.tilingOptions); |
| 1542 | |
| 1543 | if (failed(Result: tilingResult)) |
| 1544 | return rewriter.notifyMatchFailure(consumer, "failed to tile consumer" ); |
| 1545 | tiledAndFusedOps.insert_range(tilingResult->tiledOps); |
| 1546 | |
| 1547 | DenseMap<Value, Value> replacements; |
| 1548 | for (auto [origVal, replacement] : |
| 1549 | llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) { |
| 1550 | replacements[origVal] = replacement; |
| 1551 | } |
| 1552 | |
| 1553 | // If there are no loops generated, fusion is immaterial. |
| 1554 | auto &loops = tilingResult->loops; |
| 1555 | if (loops.empty()) { |
| 1556 | return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops, |
| 1557 | replacements}; |
| 1558 | } |
| 1559 | |
| 1560 | // Since the loop gets potentially replaced during fusion, we need to track |
| 1561 | // the mutation of replacement values. To do this, we attach a listener to |
| 1562 | // update the replacements as they happen. |
| 1563 | OpBuilder::Listener *previousListener = rewriter.getListener(); |
| 1564 | auto resetListener = |
| 1565 | llvm::make_scope_exit(F: [&]() { rewriter.setListener(previousListener); }); |
| 1566 | ReplacementListener replaceListener(replacements, previousListener); |
| 1567 | rewriter.setListener(&replaceListener); |
| 1568 | |
| 1569 | // 2. Typically, the operands of the tiled operation are slices of the |
| 1570 | // operands of the untiled operation. These are expressed in IR using |
| 1571 | // `tensor.extract_slice` operations with source being the operands of |
| 1572 | // the untiled operation. Create a worklist of these |
| 1573 | // `tensor.extract_slice` operations. If the producers of the source of |
| 1574 | // the `tensor.extract_slice` can be tiled such that the tiled value is |
| 1575 | // generated in-place, that effectively tiles + fuses the operations. |
| 1576 | struct WorklistItem { |
| 1577 | tensor::ExtractSliceOp candidateSlice; |
| 1578 | SCFTileAndFuseOptions::ControlFnResult controlFnResult; |
| 1579 | }; |
| 1580 | |
| 1581 | SliceTrackingListener sliceTracker = |
| 1582 | SliceTrackingListener(options.cleanupPatterns); |
| 1583 | |
| 1584 | if (failed( |
| 1585 | sliceTracker.insertAndApplyPatterns(ops: tilingResult->generatedSlices))) { |
| 1586 | return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed" ); |
| 1587 | } |
| 1588 | OpBuilder::InsertionGuard g(rewriter); |
| 1589 | while (!sliceTracker.worklist.empty()) { |
| 1590 | auto candidateSlice = sliceTracker.worklist.front(); |
| 1591 | sliceTracker.worklist.pop_front(); |
| 1592 | |
| 1593 | auto [fusableProducer, destinationInitArg] = |
| 1594 | getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(), |
| 1595 | loops); |
| 1596 | if (!fusableProducer) |
| 1597 | continue; |
| 1598 | |
| 1599 | std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult = |
| 1600 | options.fusionControlFn(candidateSlice, fusableProducer, |
| 1601 | destinationInitArg.has_value()); |
| 1602 | if (!controlFnResult) |
| 1603 | continue; |
| 1604 | |
| 1605 | WorklistItem worklistItem = {candidateSlice, controlFnResult.value()}; |
| 1606 | |
| 1607 | // The operands of the fused producer might themselved be slices of |
| 1608 | // values produced by operations that implement the `TilingInterface`. |
| 1609 | // Add these operations to the worklist. |
| 1610 | std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult = |
| 1611 | tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice, |
| 1612 | loops); |
| 1613 | if (!fusedResult) |
| 1614 | continue; |
| 1615 | |
| 1616 | SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices; |
| 1617 | |
| 1618 | if (worklistItem.controlFnResult.yieldProducerReplacement) { |
| 1619 | // Reconstruct and yield all opResult of fusableProducerOp by default. |
| 1620 | // The caller can specific which one to yield by designating optional |
| 1621 | // argument named `yieldResultNumber` of |
| 1622 | // `yieldReplacementForFusedProducer`. |
| 1623 | Operation *fusableProducerOp = fusedResult->origProducer.getOwner(); |
| 1624 | FailureOr<SmallVector<Operation *>> newSlices = |
| 1625 | yieldReplacementForFusedProducer(rewriter, |
| 1626 | worklistItem.candidateSlice, |
| 1627 | fusedResult.value(), loops); |
| 1628 | if (failed(Result: newSlices)) { |
| 1629 | return rewriter.notifyMatchFailure( |
| 1630 | arg&: fusableProducerOp, msg: "failed to replacement value for this " |
| 1631 | "operation from within the tiled loop" ); |
| 1632 | } |
| 1633 | worklistCandidates.append(RHS: newSlices.value()); |
| 1634 | for (auto [index, result] : |
| 1635 | llvm::enumerate(First: fusableProducerOp->getResults())) { |
| 1636 | replacements[result] = loops.front()->getResult( |
| 1637 | loops.front()->getNumResults() - |
| 1638 | fusableProducerOp->getNumResults() + index); |
| 1639 | } |
| 1640 | } |
| 1641 | if (Operation *tiledAndFusedOp = |
| 1642 | fusedResult->tiledAndFusedProducer.getDefiningOp()) { |
| 1643 | fusedProducers.insert(X: fusedResult->origProducer.getDefiningOp()); |
| 1644 | tiledAndFusedOps.insert(X: tiledAndFusedOp); |
| 1645 | } |
| 1646 | |
| 1647 | if (failed(Result: sliceTracker.insertAndApplyPatterns(ops: worklistCandidates))) { |
| 1648 | return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed" ); |
| 1649 | } |
| 1650 | } |
| 1651 | |
| 1652 | return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops, |
| 1653 | replacements}; |
| 1654 | } |
| 1655 | |
| 1656 | //===----------------------------------------------------------------------===// |
| 1657 | // tileAndFuseConsumerUsingSCF implementation. |
| 1658 | //===----------------------------------------------------------------------===// |
| 1659 | |
| 1660 | /// A utility function that checks whether the only use of the result of a |
| 1661 | /// tensor.insert_slice op is in a scf.yield op. |
| 1662 | static LogicalResult |
| 1663 | checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) { |
| 1664 | Value result = candidateSliceOp.getResult(); |
| 1665 | Value::use_range uses = result.getUses(); |
| 1666 | if (!llvm::hasSingleElement(C&: uses)) { |
| 1667 | LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n" ); |
| 1668 | return failure(); |
| 1669 | } |
| 1670 | OpOperand &operandUse = (*uses.begin()); |
| 1671 | Operation *userOp = operandUse.getOwner(); |
| 1672 | if (!isa<scf::YieldOp>(userOp)) { |
| 1673 | LLVM_DEBUG(llvm::dbgs() |
| 1674 | << "Expected scf.yield to be the only user, but got -> " |
| 1675 | << (*userOp)); |
| 1676 | return failure(); |
| 1677 | } |
| 1678 | if (result.getDefiningOp()->getBlock() != userOp->getBlock()) { |
| 1679 | LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to " |
| 1680 | "be in the same block\n" ); |
| 1681 | return failure(); |
| 1682 | } |
| 1683 | return success(); |
| 1684 | } |
| 1685 | |
| 1686 | /// An utility to get the first user of the given loopOp. If any of user stay |
| 1687 | /// in different block of loopOp, return failure. |
| 1688 | static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) { |
| 1689 | if (!isa<LoopLikeOpInterface>(loopOp)) |
| 1690 | return failure(); |
| 1691 | Operation *firstUserOfLoop = nullptr; |
| 1692 | for (Operation *userOp : loopOp->getUsers()) { |
| 1693 | // `ParallelInsertSlice` located inside `InParallelOp` has no same parent |
| 1694 | // block with any other types of operation. Thus, just redirecting to its |
| 1695 | // parent `InParallelOp`. E.g. |
| 1696 | // |
| 1697 | // ``` |
| 1698 | // %1 = scf.for { |
| 1699 | // ... |
| 1700 | // } |
| 1701 | // %2 = consumerOp ins(%1, ...) |
| 1702 | // scf.forall.in_parallel { |
| 1703 | // tensor.parallel_insert_slice %1 |
| 1704 | // } |
| 1705 | // ``` |
| 1706 | // where `InParallelOp` but not `ParallelInsertSlice` stays in the same |
| 1707 | // same block with `consumerOp`. |
| 1708 | if (isa<tensor::ParallelInsertSliceOp>(userOp)) |
| 1709 | userOp = userOp->getParentOfType<scf::InParallelOp>(); |
| 1710 | |
| 1711 | if (loopOp->getBlock() != userOp->getBlock()) |
| 1712 | return failure(); |
| 1713 | |
| 1714 | if (!firstUserOfLoop || userOp->isBeforeInBlock(other: firstUserOfLoop)) |
| 1715 | firstUserOfLoop = userOp; |
| 1716 | } |
| 1717 | return firstUserOfLoop; |
| 1718 | } |
| 1719 | |
| 1720 | /// This utility currently checks whether the first userOp of loop is NOT |
| 1721 | /// before the last defineOp of consumer operand. Because that we need to move |
| 1722 | /// the whole loop structure right before the `firstUserOfLoop`. This utility |
| 1723 | /// thus helps ensuring that no invalid IR is formed, i.e. no backward slice |
| 1724 | /// of consumerOp is dominated by the `firstUserOfLoop`. Saying that: |
| 1725 | /// |
| 1726 | /// ``` |
| 1727 | /// %0 = scf.for() { |
| 1728 | /// ... |
| 1729 | /// } |
| 1730 | /// ... |
| 1731 | /// %1 = firstUserOfLoop(%0) |
| 1732 | /// ... |
| 1733 | /// %2 = lastDefOfConsumerOperand |
| 1734 | /// ... |
| 1735 | /// %3 = consumerOp(%2) |
| 1736 | /// ``` |
| 1737 | /// |
| 1738 | /// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it |
| 1739 | /// would be invalid to move the `loopOp` right before the `firstUserOfLoop`, |
| 1740 | /// a.k.a. use-def chain violation: |
| 1741 | /// |
| 1742 | /// ``` |
| 1743 | /// %0:2 = scf.for() { |
| 1744 | /// // use before define error |
| 1745 | /// %3 = tiledConsumerOp(%2) |
| 1746 | /// } |
| 1747 | /// %1 = firstUserOfLoop(%0) |
| 1748 | /// ... |
| 1749 | /// %2 = lastDefOfConsumerOperand |
| 1750 | /// ``` |
| 1751 | /// |
| 1752 | /// @param loopOp: loop operation |
| 1753 | /// @param consumerOp: consumer operation |
| 1754 | /// @param reorderOperations: the flag controls whether to reorder the |
| 1755 | /// backward slice w.r.t. the defineOp of `consumerOp` operands. |
| 1756 | /// @return: computed backward slice of consumerOp, but excluding those |
| 1757 | /// already dominates `firstUserOfLoop`. |
| 1758 | static FailureOr<llvm::SetVector<Operation *>> |
| 1759 | checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp, |
| 1760 | bool reorderOperations) { |
| 1761 | FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp); |
| 1762 | if (failed(Result: firstUserOfLoop)) |
| 1763 | return failure(); |
| 1764 | |
| 1765 | BackwardSliceOptions options; |
| 1766 | DominanceInfo dominanceInfo; |
| 1767 | options.inclusive = true; |
| 1768 | options.omitBlockArguments = true; |
| 1769 | bool includeLoopOp = false; |
| 1770 | options.filter = [&](Operation *op) { |
| 1771 | if (op == loopOp) { |
| 1772 | includeLoopOp = true; |
| 1773 | return false; |
| 1774 | } |
| 1775 | // Cut off the slice to not include any operation that already dominates |
| 1776 | // firstUserOfLoop. |
| 1777 | return !dominanceInfo.properlyDominates(a: op, b: *firstUserOfLoop); |
| 1778 | }; |
| 1779 | llvm::SetVector<Operation *> slice; |
| 1780 | for (auto operand : consumerOp->getOperands()) { |
| 1781 | LogicalResult result = getBackwardSlice(root: operand, backwardSlice: &slice, options); |
| 1782 | assert(result.succeeded() && "expected a backward slice" ); |
| 1783 | (void)result; |
| 1784 | } |
| 1785 | |
| 1786 | if (!slice.empty()) { |
| 1787 | // If consumerOp has one producer, which is also the user of loopOp. |
| 1788 | // E.g. |
| 1789 | // ``` |
| 1790 | // %0 = %loopOp |
| 1791 | // %1 = consumerOp1 ins(%0) |
| 1792 | // %2 = consumerOp2 ins(%0, %1) |
| 1793 | // ``` |
| 1794 | // We can not fuse consumerOp2 into loopOp due to UD chain, unless |
| 1795 | // consumerOp1 has already been fused into loopOp before. |
| 1796 | if (includeLoopOp || !reorderOperations) |
| 1797 | return failure(); |
| 1798 | } |
| 1799 | |
| 1800 | return slice; |
| 1801 | } |
| 1802 | |
| 1803 | /// Fetches the OpOperand of the first valid user (and use) of the value `val` |
| 1804 | /// which implements `TilingInterface` and `DestinationStyleOpInterface`. |
| 1805 | /// Returns failure otherwise. |
| 1806 | static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter, |
| 1807 | Operation *loopOp, |
| 1808 | unsigned resultNumber) { |
| 1809 | if (!isa<LoopLikeOpInterface>(loopOp)) |
| 1810 | return failure(); |
| 1811 | Value val = loopOp->getResult(idx: resultNumber); |
| 1812 | Block *loopBlock = loopOp->getBlock(); |
| 1813 | for (OpOperand &opOperand : val.getUses()) { |
| 1814 | Operation *consumerOp = opOperand.getOwner(); |
| 1815 | // Step 1. Check if the user is tilable. |
| 1816 | if (!isa<TilingInterface>(consumerOp) || |
| 1817 | !isa<DestinationStyleOpInterface>(consumerOp)) { |
| 1818 | // TODO: We have to init result of consumer before scf.for, use |
| 1819 | // DestinationStyleOpInterface to get result shape from init for now. |
| 1820 | // Add support for other op such as op has InferTypeOpInterface. |
| 1821 | continue; |
| 1822 | } |
| 1823 | // Step 2. Check if user stay in the same block. |
| 1824 | if (loopBlock != consumerOp->getBlock()) |
| 1825 | continue; |
| 1826 | // Step 3. Check if user has succeeding user. Otherwise, it usually |
| 1827 | // represents already tiled. |
| 1828 | if (consumerOp->use_empty()) |
| 1829 | continue; |
| 1830 | // Step 4. Check assumption for loop with `reorderOperations` enabled. |
| 1831 | FailureOr<llvm::SetVector<Operation *>> slice = |
| 1832 | checkAssumptionForLoop(loopOp, consumerOp, reorderOperations: true); |
| 1833 | if (failed(Result: slice)) |
| 1834 | continue; |
| 1835 | // Step 5. If backward sice is not empty, move them before |
| 1836 | // firstUserOfLoop. |
| 1837 | if (!slice->empty()) { |
| 1838 | mlir::topologicalSort(toSort: *slice); |
| 1839 | FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp); |
| 1840 | assert(succeeded(firstUserOfLoop) && "First user of loop is not found" ); |
| 1841 | for (auto op : *slice) { |
| 1842 | rewriter.moveOpBefore(op, existingOp: *firstUserOfLoop); |
| 1843 | } |
| 1844 | } |
| 1845 | return &opOperand; |
| 1846 | } |
| 1847 | return failure(); |
| 1848 | } |
| 1849 | |
| 1850 | /// Check that the loop is perfectly nested. |
| 1851 | /// The loops are expected to be ordered from outer most to inner most. |
| 1852 | /// For example: |
| 1853 | /// ``` |
| 1854 | /// %0 = scf.for() |
| 1855 | /// %1 = scf.for() |
| 1856 | /// %2 = scf.for() |
| 1857 | /// %3 = ... |
| 1858 | /// yield %3 |
| 1859 | /// yield %2 |
| 1860 | /// yield %1 |
| 1861 | /// ``` |
| 1862 | /// Here loops should be [%0, %1]. |
| 1863 | static bool |
| 1864 | isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) { |
| 1865 | assert(!loops.empty() && "unexpected empty loop nest" ); |
| 1866 | if (loops.size() == 1) { |
| 1867 | return isa_and_nonnull<scf::ForOp>(loops.front().getOperation()); |
| 1868 | } |
| 1869 | for (auto [outerLoop, innerLoop] : |
| 1870 | llvm::zip_equal(loops.drop_back(), loops.drop_front())) { |
| 1871 | auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation()); |
| 1872 | auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation()); |
| 1873 | if (!outerFor || !innerFor) { |
| 1874 | return false; |
| 1875 | } |
| 1876 | auto outerBBArgs = outerFor.getRegionIterArgs(); |
| 1877 | auto innerIterArgs = innerFor.getInitArgs(); |
| 1878 | if (outerBBArgs.size() != innerIterArgs.size()) { |
| 1879 | return false; |
| 1880 | } |
| 1881 | |
| 1882 | for (auto [outerBBArg, innerIterArg] : |
| 1883 | llvm::zip_equal(outerBBArgs, innerIterArgs)) { |
| 1884 | if (!llvm::hasSingleElement(outerBBArg.getUses()) || |
| 1885 | innerIterArg != outerBBArg) { |
| 1886 | return false; |
| 1887 | } |
| 1888 | } |
| 1889 | |
| 1890 | ValueRange outerYields = |
| 1891 | cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands(); |
| 1892 | ValueRange innerResults = innerFor.getResults(); |
| 1893 | if (outerYields.size() != innerResults.size()) { |
| 1894 | return false; |
| 1895 | } |
| 1896 | for (auto [outerYield, innerResult] : |
| 1897 | llvm::zip_equal(outerYields, innerResults)) { |
| 1898 | if (!llvm::hasSingleElement(innerResult.getUses()) || |
| 1899 | outerYield != innerResult) { |
| 1900 | return false; |
| 1901 | } |
| 1902 | } |
| 1903 | } |
| 1904 | return true; |
| 1905 | } |
| 1906 | |
| 1907 | /// Fetch the untiled consumer of the outermost scf.for's result which is |
| 1908 | /// yielded by a tensor.insert_slice from the innermost scf.for. This function |
| 1909 | /// makes the following assumptions : |
| 1910 | /// 1. tensor.insert_slice has scf.yield as its only user. |
| 1911 | /// 2. scf.for's corresponding result has only one use. |
| 1912 | /// 3. The `loops` passed in are perfectly nested `scf.for` operations. |
| 1913 | static FailureOr<OpOperand *> |
| 1914 | getUntiledConsumerFromSlice(RewriterBase &rewriter, |
| 1915 | tensor::InsertSliceOp candidateSliceOp, |
| 1916 | MutableArrayRef<LoopLikeOpInterface> loops) { |
| 1917 | assert(!loops.empty() && "unexpected loops to be empty" ); |
| 1918 | // 1. Expect slice to be part of the body of the inner most loop. |
| 1919 | Operation *containingOp = candidateSliceOp->getParentOp(); |
| 1920 | if (containingOp != loops.back()) { |
| 1921 | return rewriter.notifyMatchFailure( |
| 1922 | candidateSliceOp, |
| 1923 | "expected slice to be within body of inner-most loop" ); |
| 1924 | } |
| 1925 | |
| 1926 | // 2. Check that the loop is perfectly nested. |
| 1927 | if (!isPerfectlyNestedForLoops(loops)) { |
| 1928 | return rewriter.notifyMatchFailure( |
| 1929 | candidateSliceOp, "expected passed loops to be perfectly nested." ); |
| 1930 | } |
| 1931 | |
| 1932 | if (failed(checkAssumptionForFusingConsumer(candidateSliceOp))) |
| 1933 | return failure(); |
| 1934 | Value sliceResult = candidateSliceOp.getResult(); |
| 1935 | |
| 1936 | // 3. Fetch the corresponding output. |
| 1937 | OpOperand &yieldOpOperand = (*sliceResult.getUses().begin()); |
| 1938 | unsigned resultNumber = yieldOpOperand.getOperandNumber(); |
| 1939 | |
| 1940 | scf::ForOp topLevelForOp = cast<scf::ForOp>(loops.front().getOperation()); |
| 1941 | |
| 1942 | return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber); |
| 1943 | } |
| 1944 | |
| 1945 | /// Fetch the first untiled consumer of a scf.forall's result which is yielded |
| 1946 | /// by a tensor.parallel_insert_slice. |
| 1947 | static FailureOr<OpOperand *> |
| 1948 | getUntiledConsumerFromSlice(RewriterBase &rewriter, |
| 1949 | tensor::ParallelInsertSliceOp candidateSliceOp, |
| 1950 | MutableArrayRef<LoopLikeOpInterface> loops) { |
| 1951 | assert(!loops.empty() && "unexpected loops to be empty" ); |
| 1952 | // 1. Check that the surrounding loop is a single scf.forall loop. |
| 1953 | if (loops.size() != 1) { |
| 1954 | return rewriter.notifyMatchFailure( |
| 1955 | candidateSliceOp, "expected single surrounding scf.forall" ); |
| 1956 | } |
| 1957 | auto forallOp = dyn_cast<scf::ForallOp>(loops.front().getOperation()); |
| 1958 | if (!forallOp) { |
| 1959 | return rewriter.notifyMatchFailure( |
| 1960 | candidateSliceOp, "expected single surrounding scf.forall" ); |
| 1961 | } |
| 1962 | |
| 1963 | // 2. Fetch the corresponding output |
| 1964 | Value sliceDest = candidateSliceOp.getDest(); |
| 1965 | auto iterArg = dyn_cast<BlockArgument>(Val&: sliceDest); |
| 1966 | if (!iterArg) |
| 1967 | return failure(); |
| 1968 | if (iterArg.getOwner()->getParentOp() != forallOp) |
| 1969 | return failure(); |
| 1970 | |
| 1971 | unsigned resultNumber = |
| 1972 | forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg)) |
| 1973 | .getResultNumber(); |
| 1974 | |
| 1975 | return getConsumerFromLoopUses(rewriter, forallOp, resultNumber); |
| 1976 | } |
| 1977 | |
| 1978 | /// A utility to fetch an untiled consumer of |
| 1979 | /// tensor.insert_slice/tensor.parallel_insert_slice. |
| 1980 | static FailureOr<OpOperand *> |
| 1981 | getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp, |
| 1982 | MutableArrayRef<LoopLikeOpInterface> loops) { |
| 1983 | assert(!loops.empty() && "unexpected empty loops" ); |
| 1984 | if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) { |
| 1985 | return getUntiledConsumerFromSlice(rewriter, insertSlice, loops); |
| 1986 | } else if (auto parallelInsertSlice = |
| 1987 | dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) { |
| 1988 | return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice, loops); |
| 1989 | } else { |
| 1990 | return failure(); |
| 1991 | } |
| 1992 | } |
| 1993 | |
| 1994 | /// Implementation of fusing consumer of a single slice by computing the |
| 1995 | /// slice of the consumer in-place for scf loop. |
| 1996 | FailureOr<scf::SCFFuseConsumerOfSliceResult> |
| 1997 | mlir::scf::tileAndFuseConsumerOfSlice( |
| 1998 | RewriterBase &rewriter, Operation *candidateSliceOp, |
| 1999 | MutableArrayRef<LoopLikeOpInterface> loops) { |
| 2000 | // Return if `loops` is empty, return an error for now. Caller is expected |
| 2001 | // to handle this case. |
| 2002 | if (loops.empty()) { |
| 2003 | return candidateSliceOp->emitOpError( |
| 2004 | message: "cannot call tile and fuse consumer with an empty loop nest" ); |
| 2005 | } |
| 2006 | if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>( |
| 2007 | candidateSliceOp)) |
| 2008 | return failure(); |
| 2009 | |
| 2010 | // 1. Get the consumer of scf.for for the result yielded by |
| 2011 | // tensor.insert_slice/parallel_insert_slice. |
| 2012 | FailureOr<OpOperand *> maybeConsumerOpOperand = |
| 2013 | getUntiledConsumerFromSlice(rewriter, candidateSliceOp, loops); |
| 2014 | if (failed(Result: maybeConsumerOpOperand)) { |
| 2015 | return rewriter.notifyMatchFailure(arg&: candidateSliceOp, |
| 2016 | msg: "could not fetch consumer to fuse" ); |
| 2017 | } |
| 2018 | OpOperand *consumerOpOperand = *maybeConsumerOpOperand; |
| 2019 | Operation *consumerOp = consumerOpOperand->getOwner(); |
| 2020 | unsigned operandNumber = consumerOpOperand->getOperandNumber(); |
| 2021 | unsigned resultNumber = 0; |
| 2022 | if (auto producerResult = dyn_cast<OpResult>(Val: consumerOpOperand->get())) { |
| 2023 | resultNumber = producerResult.getResultNumber(); |
| 2024 | } else { |
| 2025 | return rewriter.notifyMatchFailure( |
| 2026 | arg&: consumerOp, msg: "consumer op's operand doesn't seem to be an OpResult" ); |
| 2027 | } |
| 2028 | |
| 2029 | LoopLikeOpInterface outerMostLoop = loops.front(); |
| 2030 | LoopLikeOpInterface innerMostLoop = loops.back(); |
| 2031 | |
| 2032 | // Check assumption for loop with `reorderOperations` disabled. |
| 2033 | if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) { |
| 2034 | return rewriter.notifyMatchFailure( |
| 2035 | outerMostLoop, "the first user of loop should not dominate any define " |
| 2036 | "of consumer operand(s)" ); |
| 2037 | } |
| 2038 | |
| 2039 | OpBuilder::InsertionGuard g(rewriter); |
| 2040 | |
| 2041 | // 2. Check consumer is not using scf loop's output as init. |
| 2042 | auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp); |
| 2043 | if (!dstOp) |
| 2044 | return rewriter.notifyMatchFailure(arg&: consumerOp, |
| 2045 | msg: "consumer op is not DPS operation" ); |
| 2046 | SmallVector<Value> dpsInits = |
| 2047 | llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; }); |
| 2048 | if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) { |
| 2049 | return rewriter.notifyMatchFailure( |
| 2050 | arg&: consumerOp, |
| 2051 | msg: "consumer op taking the result of scf.for as init is not supported" ); |
| 2052 | } |
| 2053 | SmallVector<Value> newInits = dpsInits; |
| 2054 | |
| 2055 | Location loc = outerMostLoop->getLoc(); |
| 2056 | |
| 2057 | // 3. Move the whole loop structure right before firstUserOfLoop, the |
| 2058 | // dominance should be already ensured by `checkAssumptionForLoop`. |
| 2059 | FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(outerMostLoop); |
| 2060 | if (failed(Result: firstUserOfLoop)) { |
| 2061 | return rewriter.notifyMatchFailure( |
| 2062 | outerMostLoop, "could not find the first user of outer most loop" ); |
| 2063 | } |
| 2064 | rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop); |
| 2065 | |
| 2066 | // 4. Set insertion point before terminator op of the loop and create a new |
| 2067 | // tensor.insert_slice. In the scf.for case this is a clone of the |
| 2068 | // candidateSliceOp whereas in the scf.forall case this is created from the |
| 2069 | // operands of tensor.parallel_insert_slice. |
| 2070 | tensor::InsertSliceOp clonedInsertSliceOp; |
| 2071 | if (auto sliceOp = |
| 2072 | dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) { |
| 2073 | auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation()); |
| 2074 | rewriter.setInsertionPoint(newForallOp.getTerminator()); |
| 2075 | clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>( |
| 2076 | loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(), |
| 2077 | sliceOp.getMixedSizes(), sliceOp.getMixedStrides()); |
| 2078 | } else { |
| 2079 | rewriter.setInsertionPoint(candidateSliceOp); |
| 2080 | clonedInsertSliceOp = |
| 2081 | cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp)); |
| 2082 | } |
| 2083 | |
| 2084 | // 5.a. Clone consumer op. |
| 2085 | auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(op&: *consumerOp)); |
| 2086 | |
| 2087 | // 5.b. Replace all uses of the loop result with the result of the cloned |
| 2088 | // tensor.insert_slice. |
| 2089 | OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber); |
| 2090 | rewriter.modifyOpInPlace(clonedConsumerOp, [&]() { |
| 2091 | operandToReplace.set(clonedInsertSliceOp.getResult()); |
| 2092 | }); |
| 2093 | |
| 2094 | // 6. Perform tiling of the cloned consumer and replace the operand at |
| 2095 | // `operandNumber` with the source of the cloned tensor.insert_slice op. |
| 2096 | auto ossSliceOp = |
| 2097 | cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation()); |
| 2098 | FailureOr<TilingResult> tileAndFuseResult = |
| 2099 | tensor::replaceInsertSliceWithTiledConsumer( |
| 2100 | builder&: rewriter, sliceOp: ossSliceOp, consumerOp&: clonedConsumerOp->getOpOperand(operandNumber)); |
| 2101 | if (failed(Result: tileAndFuseResult)) { |
| 2102 | return failure(); |
| 2103 | } |
| 2104 | auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]); |
| 2105 | rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber), |
| 2106 | clonedInsertSliceOp.getSource()); |
| 2107 | |
| 2108 | // 7. Reconstruct [nested] loop with new inits. |
| 2109 | YieldTiledValuesFn newYieldValuesFn = |
| 2110 | [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/, |
| 2111 | ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult, |
| 2112 | SmallVector<SmallVector<OpFoldResult>> &tiledOffset, |
| 2113 | SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult { |
| 2114 | OpBuilder::InsertionGuard g(innerRewriter); |
| 2115 | // 8. Set inner insertPoint right before tiled consumer op. |
| 2116 | innerRewriter.setInsertionPoint(tiledConsumerOp); |
| 2117 | |
| 2118 | SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets(); |
| 2119 | SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes(); |
| 2120 | SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides(); |
| 2121 | |
| 2122 | // 9. Check all insert stride is 1. |
| 2123 | if (!llvm::all_of(Range&: strides, P: isOneInteger)) { |
| 2124 | return rewriter.notifyMatchFailure( |
| 2125 | arg&: candidateSliceOp, msg: "containingOp's result yield with stride" ); |
| 2126 | } |
| 2127 | |
| 2128 | // 10. Try to get iter domain position from input position. Use |
| 2129 | // clonedConsumerOp instead of tiledConsumerOp, because the iteration |
| 2130 | // domain may require index computation based on the result size. The |
| 2131 | // sizes and offsets should be the same either way, but using |
| 2132 | // tiledConsumerOp could lead to some chained unnecessary extra index |
| 2133 | // computation. |
| 2134 | SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes; |
| 2135 | if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile( |
| 2136 | rewriter, operandNumber, offsets, sizes, iterDomainOffsets, |
| 2137 | iterDomainSizes))) { |
| 2138 | return rewriter.notifyMatchFailure( |
| 2139 | clonedConsumerOp, |
| 2140 | "can't get iter domain position from input position" ); |
| 2141 | } |
| 2142 | |
| 2143 | // 11. Try to fetch the offset and size for all results of the cloned |
| 2144 | // consumer. This would then be used to form the corresponding |
| 2145 | // tensor.insert_slice/parallel_insert_slice later. |
| 2146 | unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults(); |
| 2147 | SmallVector<SmallVector<OpFoldResult>> resultOffsets( |
| 2148 | totalNumResultsOfConsumer); |
| 2149 | SmallVector<SmallVector<OpFoldResult>> resultSizes( |
| 2150 | totalNumResultsOfConsumer); |
| 2151 | for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) { |
| 2152 | if (failed(tiledConsumerOp.getResultTilePosition( |
| 2153 | rewriter, idx, iterDomainOffsets, iterDomainSizes, |
| 2154 | resultOffsets[idx], resultSizes[idx]))) { |
| 2155 | return rewriter.notifyMatchFailure( |
| 2156 | tiledConsumerOp, |
| 2157 | "can't get result domain position from iter domain position" ); |
| 2158 | } |
| 2159 | } |
| 2160 | |
| 2161 | // 12. Create `extract_slice` for `iter_args` for DPS operation if |
| 2162 | // necessary. |
| 2163 | if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>( |
| 2164 | tiledConsumerOp.getOperation())) { |
| 2165 | rewriter.setInsertionPoint(tiledDestStyleOp); |
| 2166 | for (const auto &&[index, newRegionArg] : |
| 2167 | llvm::enumerate(First&: newRegionIterArgs)) { |
| 2168 | auto destSlice = rewriter.create<tensor::ExtractSliceOp>( |
| 2169 | loc, newRegionArg, resultOffsets[index], resultSizes[index], |
| 2170 | SmallVector<OpFoldResult>(resultOffsets[index].size(), |
| 2171 | rewriter.getIndexAttr(1))); |
| 2172 | // Make a copy of index to avoid a capturing structured binding, which |
| 2173 | // is a C++20 extension. |
| 2174 | auto dstNumber = index; |
| 2175 | rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() { |
| 2176 | tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice); |
| 2177 | }); |
| 2178 | } |
| 2179 | } |
| 2180 | |
| 2181 | // 13. Prepare tiled offset and sizes for later `insert_slice` creation by |
| 2182 | // caller. |
| 2183 | Block *block = rewriter.getInsertionPoint()->getBlock(); |
| 2184 | rewriter.setInsertionPoint(block->getTerminator()); |
| 2185 | for (const auto &&[index, result] : |
| 2186 | llvm::enumerate(tiledConsumerOp->getResults())) { |
| 2187 | tiledResult.push_back(result); |
| 2188 | tiledOffset.emplace_back(resultOffsets[index]); |
| 2189 | tiledSizes.emplace_back(resultSizes[index]); |
| 2190 | } |
| 2191 | return success(); |
| 2192 | }; |
| 2193 | // 14. Add new inits to [nested] loops. |
| 2194 | if (failed(addInitOperandsToLoopNest(rewriter, loops, newInits, |
| 2195 | newYieldValuesFn))) { |
| 2196 | return rewriter.notifyMatchFailure(tiledConsumerOp, |
| 2197 | "unable to add new inits to nest loop" ); |
| 2198 | } |
| 2199 | |
| 2200 | // 15. Replace the result of scf loop and consumer op with new loop's |
| 2201 | // results. |
| 2202 | |
| 2203 | for (auto &&[oldResult, newResult] : |
| 2204 | llvm::zip(consumerOp->getResults(), |
| 2205 | loops.front()->getResults().take_back(newInits.size()))) { |
| 2206 | rewriter.replaceAllUsesWith(oldResult, newResult); |
| 2207 | } |
| 2208 | |
| 2209 | // 16. Need to erase the old scf loop and the cloned consumer op. |
| 2210 | rewriter.eraseOp(op: clonedConsumerOp); |
| 2211 | |
| 2212 | return scf::SCFFuseConsumerOfSliceResult{ |
| 2213 | .origConsumerOperand: consumerOpOperand, |
| 2214 | .tiledAndFusedConsumerOperand: &(tileAndFuseResult->tiledOps[0]->getOpOperand(idx: operandNumber)), |
| 2215 | .tiledOps: tileAndFuseResult->tiledOps}; |
| 2216 | } |
| 2217 | |
| 2218 | //===----------------------------------------------------------------------===// |
| 2219 | // lowerToLoopsUsingSCFForOp implementation. |
| 2220 | //===----------------------------------------------------------------------===// |
| 2221 | |
| 2222 | FailureOr<SmallVector<scf::ForOp>> |
| 2223 | mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, |
| 2224 | TilingInterface op) { |
| 2225 | // TODO: Handle cases where the op has results if needed. |
| 2226 | if (op->getNumResults() > 0) { |
| 2227 | return rewriter.notifyMatchFailure( |
| 2228 | op, "unable to lower to loops operations with return values" ); |
| 2229 | } |
| 2230 | |
| 2231 | SmallVector<Range> domain = op.getIterationDomain(rewriter); |
| 2232 | SmallVector<Value> ivs; |
| 2233 | SmallVector<scf::ForOp> loops; |
| 2234 | Location loc = op.getLoc(); |
| 2235 | for (auto loopRange : domain) { |
| 2236 | Value offsetVal = |
| 2237 | getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset); |
| 2238 | Value sizeVal = |
| 2239 | getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size); |
| 2240 | Value strideVal = |
| 2241 | getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride); |
| 2242 | auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal, |
| 2243 | strideVal, ValueRange{}); |
| 2244 | loops.push_back(loop); |
| 2245 | ivs.push_back(loop.getInductionVar()); |
| 2246 | rewriter.setInsertionPoint(loop.getBody()->getTerminator()); |
| 2247 | } |
| 2248 | if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) { |
| 2249 | return failure(); |
| 2250 | } |
| 2251 | return loops; |
| 2252 | } |
| 2253 | |