| 1 | //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements the tiling using TilingInterface. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" |
| 14 | |
| 15 | #include "mlir/Analysis/SliceAnalysis.h" |
| 16 | #include "mlir/Analysis/TopologicalSortUtils.h" |
| 17 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 18 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 19 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
| 20 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 21 | #include "mlir/Dialect/SCF/Utils/Utils.h" |
| 22 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 23 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
| 24 | #include "mlir/IR/Dominance.h" |
| 25 | #include "mlir/IR/Matchers.h" |
| 26 | #include "mlir/IR/PatternMatch.h" |
| 27 | #include "mlir/Interfaces/DestinationStyleOpInterface.h" |
| 28 | #include "mlir/Interfaces/TilingInterface.h" |
| 29 | #include "mlir/Rewrite/FrozenRewritePatternSet.h" |
| 30 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 31 | #include "llvm/ADT/ScopeExit.h" |
| 32 | #include "llvm/ADT/TypeSwitch.h" |
| 33 | #include "llvm/Support/Debug.h" |
| 34 | #include <optional> |
| 35 | |
| 36 | #define DEBUG_TYPE "tile-using-interface" |
| 37 | |
| 38 | using namespace mlir; |
| 39 | |
| 40 | scf::SCFTilingOptions & |
| 41 | scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) { |
| 42 | assert(!tileSizeComputationFunction && "tile sizes already set" ); |
| 43 | auto tileSizes = llvm::to_vector(Range&: ts); |
| 44 | tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { |
| 45 | return tileSizes; |
| 46 | }; |
| 47 | return *this; |
| 48 | } |
| 49 | |
| 50 | scf::SCFTilingOptions & |
| 51 | scf::SCFTilingOptions::setNumThreads(ArrayRef<OpFoldResult> nt) { |
| 52 | assert(!numThreadsComputationFunction && "num tiles already set" ); |
| 53 | auto numThreads = llvm::to_vector(Range&: nt); |
| 54 | numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) { |
| 55 | return numThreads; |
| 56 | }; |
| 57 | return *this; |
| 58 | } |
| 59 | |
| 60 | /// Helper method to adjust the interchange vector to match the iteration |
| 61 | /// domain. |
| 62 | static SmallVector<int64_t> |
| 63 | fillInterchangeVector(ArrayRef<int64_t> interchangeVector, |
| 64 | size_t iterationDomainSize) { |
| 65 | SmallVector<int64_t> filledVector = llvm::to_vector(Range&: interchangeVector); |
| 66 | if (filledVector.size() < iterationDomainSize) { |
| 67 | auto range = llvm::seq<int64_t>(Begin: filledVector.size(), End: iterationDomainSize); |
| 68 | filledVector.append(in_start: range.begin(), in_end: range.end()); |
| 69 | } |
| 70 | if (filledVector.size() > iterationDomainSize) |
| 71 | filledVector.resize(N: iterationDomainSize); |
| 72 | return filledVector; |
| 73 | } |
| 74 | |
| 75 | //===----------------------------------------------------------------------===// |
| 76 | // tileUsingSCF implementation. |
| 77 | //===----------------------------------------------------------------------===// |
| 78 | |
| 79 | /// Verify the tile size options are set in a consistent manner. |
| 80 | static LogicalResult verifyOptions(RewriterBase &rewriter, Location loc, |
| 81 | const scf::SCFTilingOptions &options) { |
| 82 | // Specifying number of threads is only supported on `scf.forall` op. |
| 83 | if (options.numThreadsComputationFunction && |
| 84 | options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) { |
| 85 | return rewriter.notifyMatchFailure( |
| 86 | arg&: loc, msg: "number of threads can only by specified when loop type is " |
| 87 | "set to use `scf.forall`" ); |
| 88 | } |
| 89 | |
| 90 | // If specified, check that the interchange vector is a permutation. |
| 91 | if (!options.interchangeVector.empty()) { |
| 92 | if (!isPermutationVector(interchange: options.interchangeVector)) { |
| 93 | return rewriter.notifyMatchFailure( |
| 94 | arg&: loc, msg: "invalid interchange vector, not a permutation of the entire " |
| 95 | "iteration space" ); |
| 96 | } |
| 97 | } |
| 98 | return success(); |
| 99 | } |
| 100 | |
| 101 | /// Method to instantiate the tile sizes and/or number of threads specified |
| 102 | /// by the user. |
| 103 | static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>> |
| 104 | getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, |
| 105 | ArrayRef<Range> iterationDomain, |
| 106 | const scf::SCFTilingOptions &options) { |
| 107 | OpFoldResult zero = rewriter.getIndexAttr(value: 0); |
| 108 | SmallVector<OpFoldResult> tileSizes, numThreads; |
| 109 | size_t numLoops = iterationDomain.size(); |
| 110 | |
| 111 | // Check whether the number of tiles to use is specified. |
| 112 | if (options.numThreadsComputationFunction) { |
| 113 | numThreads = options.numThreadsComputationFunction(rewriter, op); |
| 114 | numThreads.resize(N: numLoops, NV: zero); |
| 115 | |
| 116 | // If the number of tiles is also specified, use that. |
| 117 | if (options.tileSizeComputationFunction) { |
| 118 | tileSizes = options.tileSizeComputationFunction(rewriter, op); |
| 119 | tileSizes.resize(N: numLoops, NV: zero); |
| 120 | return {tileSizes, numThreads}; |
| 121 | } |
| 122 | |
| 123 | // Compute the tile sizes from the iteration domain and number |
| 124 | // of tiles as follows |
| 125 | // - niters = ceilDiv(ub - lb, step) |
| 126 | // - tileSize = ceilDiv(niters, numThreads) |
| 127 | AffineExpr s0, s1, s2; |
| 128 | bindSymbols(ctx: rewriter.getContext(), exprs&: s0, exprs&: s1, exprs&: s2); |
| 129 | // TODO: The step here is assumed to be 1. |
| 130 | AffineExpr numItersExpr = (s1 - s0); |
| 131 | AffineExpr tileSizeExpr = numItersExpr.ceilDiv(other: s2); |
| 132 | tileSizes.resize(N: numLoops, NV: zero); |
| 133 | for (auto [index, range, nt] : |
| 134 | llvm::enumerate(First&: iterationDomain, Rest&: numThreads)) { |
| 135 | if (isZeroInteger(v: nt)) |
| 136 | continue; |
| 137 | |
| 138 | tileSizes[index] = affine::makeComposedFoldedAffineApply( |
| 139 | b&: rewriter, loc: op.getLoc(), expr: tileSizeExpr, operands: {range.offset, range.size, nt}); |
| 140 | } |
| 141 | tileSizes.resize(N: numLoops, NV: zero); |
| 142 | return {tileSizes, numThreads}; |
| 143 | } |
| 144 | |
| 145 | // Enforce the convention that "tiling by zero" |
| 146 | // skips tiling a particular dimension. This convention is significantly |
| 147 | // simpler to handle instead of adjusting affine maps to account for missing |
| 148 | // dimensions. |
| 149 | assert(options.tileSizeComputationFunction && |
| 150 | "expected tile sizes to be specified" ); |
| 151 | tileSizes = options.tileSizeComputationFunction(rewriter, op); |
| 152 | tileSizes.resize(N: numLoops, NV: zero); |
| 153 | |
| 154 | return {tileSizes, numThreads}; |
| 155 | } |
| 156 | |
| 157 | /// Checks if any of the tiled loops are not parallel. |
| 158 | static LogicalResult checkTileSizes(TilingInterface op, |
| 159 | scf::SCFTilingOptions::LoopType loopType, |
| 160 | ReductionTilingStrategy reductionStrategy, |
| 161 | ArrayRef<OpFoldResult> tileSizes, |
| 162 | ArrayRef<OpFoldResult> numThreads) { |
| 163 | auto iterators = op.getLoopIteratorTypes(); |
| 164 | assert(iterators.size() == tileSizes.size() && |
| 165 | "expected as many tile size values as number of loops" ); |
| 166 | assert((numThreads.empty() || (numThreads.size() == iterators.size())) && |
| 167 | "when specified, expected number of threads to use for each loop" ); |
| 168 | |
| 169 | bool isParallelTiling = false; |
| 170 | for (auto [index, iterator, tileSize] : |
| 171 | llvm::enumerate(First&: iterators, Rest&: tileSizes)) { |
| 172 | if (!isConstantIntValue(ofr: tileSize, value: 0)) { |
| 173 | isParallelTiling |= iterator == utils::IteratorType::parallel; |
| 174 | } |
| 175 | |
| 176 | if (loopType == scf::SCFTilingOptions::LoopType::ForallOp && |
| 177 | reductionStrategy == ReductionTilingStrategy::FullReduction) { |
| 178 | // If num threads is specified, check that it is greater than one only for |
| 179 | // parallel dimensions. |
| 180 | if (!numThreads.empty()) { |
| 181 | if (std::optional<int64_t> constNumThreads = |
| 182 | getConstantIntValue(ofr: numThreads[index])) { |
| 183 | if (constNumThreads.value() > 1 && |
| 184 | iterator != utils::IteratorType::parallel) { |
| 185 | op.emitWarning() << "tiling is not thread safe at axis #" << index; |
| 186 | } |
| 187 | } |
| 188 | continue; |
| 189 | } |
| 190 | |
| 191 | if (std::optional<int64_t> constTileSize = |
| 192 | getConstantIntValue(ofr: tileSize)) { |
| 193 | if (constTileSize.value() > 0 && |
| 194 | iterator != utils::IteratorType::parallel) { |
| 195 | op.emitWarning() << "tiling is not thread safe at axis #" << index; |
| 196 | } |
| 197 | } |
| 198 | } |
| 199 | } |
| 200 | |
| 201 | if (reductionStrategy != ReductionTilingStrategy::FullReduction) { |
| 202 | if (isParallelTiling) { |
| 203 | return op->emitOpError(message: "tiling parallel dimensions is not supported with " |
| 204 | "partial reduction tiling strategies" ); |
| 205 | } |
| 206 | } |
| 207 | return success(); |
| 208 | } |
| 209 | |
| 210 | /// Get the reduction dims that are tiled. This accounts for reduction dims |
| 211 | /// that are specified as tiled, but the tile size is 0. |
| 212 | static SetVector<unsigned> |
| 213 | getSanitizedReductionDims(ArrayRef<OpFoldResult> tileSizes, |
| 214 | const scf::SCFTilingOptions &options) { |
| 215 | SetVector<unsigned> reductionDims; |
| 216 | for (auto dim : options.reductionDims) { |
| 217 | if (isConstantIntValue(ofr: tileSizes[dim], value: 0)) |
| 218 | continue; |
| 219 | reductionDims.insert(X: dim); |
| 220 | } |
| 221 | return reductionDims; |
| 222 | } |
| 223 | |
| 224 | /// Check if `stride` evenly divides the trip count `size - offset`. |
| 225 | static bool tileDividesIterationDomain(Range loopRange) { |
| 226 | std::optional<int64_t> offsetAsInt = getConstantIntValue(ofr: loopRange.offset); |
| 227 | if (!offsetAsInt) |
| 228 | return false; |
| 229 | std::optional<int64_t> sizeAsInt = getConstantIntValue(ofr: loopRange.size); |
| 230 | if (!sizeAsInt) |
| 231 | return false; |
| 232 | std::optional<int64_t> strideAsInt = getConstantIntValue(ofr: loopRange.stride); |
| 233 | if (!strideAsInt) |
| 234 | return false; |
| 235 | return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); |
| 236 | } |
| 237 | |
| 238 | /// Returns the bounded tile size given the current `offset`, `loopRange` and |
| 239 | /// `tileSize`, i.e., `min(tileSize, range.end() - offset)`. |
| 240 | static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, |
| 241 | Range loopRange, OpFoldResult offset, |
| 242 | OpFoldResult tileSize) { |
| 243 | std::optional<int64_t> ts = getConstantIntValue(ofr: tileSize); |
| 244 | if (ts && ts.value() == 1) |
| 245 | return tileSize; |
| 246 | |
| 247 | if (tileDividesIterationDomain( |
| 248 | loopRange: Range{.offset: loopRange.offset, .size: loopRange.size, .stride: tileSize})) |
| 249 | return tileSize; |
| 250 | |
| 251 | // The tile size to use (to avoid out of bounds access) is minimum of |
| 252 | // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled |
| 253 | // loop. |
| 254 | AffineExpr s0, s1, d0; |
| 255 | bindDims(ctx: b.getContext(), exprs&: d0); |
| 256 | bindSymbols(ctx: b.getContext(), exprs&: s0, exprs&: s1); |
| 257 | AffineMap minMap = AffineMap::get(dimCount: 1, symbolCount: 2, results: {s0 - d0, s1}, context: b.getContext()); |
| 258 | Value size = getValueOrCreateConstantIndexOp(b, loc, ofr: loopRange.size); |
| 259 | return affine::makeComposedFoldedAffineMin( |
| 260 | b, loc, map: minMap, operands: SmallVector<OpFoldResult>{offset, size, tileSize}); |
| 261 | } |
| 262 | |
| 263 | /// Returns true if the maximum tile offset `tileSize * numThreads-1` is less |
| 264 | /// than `iterationSize`. |
| 265 | static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, |
| 266 | OpFoldResult numThreads, |
| 267 | OpFoldResult iterationSize) { |
| 268 | std::optional<int64_t> tileSizeConst = getConstantIntValue(ofr: tileSize); |
| 269 | std::optional<int64_t> numThreadsConst = getConstantIntValue(ofr: numThreads); |
| 270 | std::optional<int64_t> iterSizeConst = getConstantIntValue(ofr: iterationSize); |
| 271 | if (!tileSizeConst || !numThreadsConst || !iterSizeConst) |
| 272 | return false; |
| 273 | return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst; |
| 274 | } |
| 275 | |
| 276 | /// Compute the `OpFoldResult`s that represents the multi-dimensional |
| 277 | /// `offset`s and `size`s of the tile of the iteration space that the |
| 278 | /// innermost loop body of the generated tiled loops corresponds to. |
| 279 | static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>> |
| 280 | getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, |
| 281 | ReductionTilingStrategy strategy, ValueRange ivs, |
| 282 | ArrayRef<Range> iterationDomain, |
| 283 | ArrayRef<OpFoldResult> tileSizes, |
| 284 | ArrayRef<OpFoldResult> numThreads, |
| 285 | const llvm::SetVector<unsigned> &reductionDims) { |
| 286 | SmallVector<OpFoldResult> offsets, sizes; |
| 287 | int materializedLoopNum = 0; |
| 288 | |
| 289 | if (!numThreads.empty()) { |
| 290 | AffineExpr d0, d1, s0, s1; |
| 291 | AffineExpr offsetExpr, residualTileSizeExpr; |
| 292 | bindDims(ctx: rewriter.getContext(), exprs&: d0, exprs&: d1); |
| 293 | bindSymbols(ctx: rewriter.getContext(), exprs&: s0, exprs&: s1); |
| 294 | offsetExpr = d0 + d1 * s0; |
| 295 | residualTileSizeExpr = s1 - (d0 + d1 * s0); |
| 296 | |
| 297 | for (auto [index, nt, tileSize, loopRange] : |
| 298 | llvm::enumerate(First&: numThreads, Rest&: tileSizes, Rest&: iterationDomain)) { |
| 299 | |
| 300 | // Non-tiled cases, set the offset and size to the |
| 301 | // `loopRange.offset/size`. |
| 302 | if (isZeroInteger(v: nt)) { |
| 303 | offsets.push_back(Elt: loopRange.offset); |
| 304 | sizes.push_back(Elt: loopRange.size); |
| 305 | continue; |
| 306 | } |
| 307 | |
| 308 | Value iv = ivs[materializedLoopNum++]; |
| 309 | OpFoldResult offset = affine::makeComposedFoldedAffineApply( |
| 310 | b&: rewriter, loc, expr: offsetExpr, |
| 311 | operands: ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize}); |
| 312 | OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply( |
| 313 | b&: rewriter, loc, expr: residualTileSizeExpr, |
| 314 | operands: {loopRange.offset, nt, tileSize, loopRange.size}); |
| 315 | |
| 316 | OpFoldResult size = tileSize; |
| 317 | if (!isZeroInteger(v: residualTileSize)) { |
| 318 | OpFoldResult sizeMinusOffsetPerThread = |
| 319 | affine::makeComposedFoldedAffineApply(b&: rewriter, loc, expr: s0 - d0, |
| 320 | operands: {offset, loopRange.size}); |
| 321 | size = affine::makeComposedFoldedAffineMin( |
| 322 | b&: rewriter, loc, |
| 323 | map: AffineMap::getMultiDimIdentityMap(numDims: 2, context: rewriter.getContext()), |
| 324 | operands: {sizeMinusOffsetPerThread, tileSize}); |
| 325 | } |
| 326 | |
| 327 | // Consider the case where the original loop was `[0, 100)`. |
| 328 | // If number of threads are `7`, the tile size would be computed as |
| 329 | // `ceilDiv(100, 7) = 15`. For the last thread (thread_id = 6) |
| 330 | // - `offset = 0 + 6 * 15 = 105` |
| 331 | // - `tileSize = min(15, 100 - 105) = -5` |
| 332 | // To avoid negative tile sizes, we need to do a further |
| 333 | // `nonNegativeTileSize = affine.max(0, tileSize)`. |
| 334 | // This `max` can be avoided if |
| 335 | // `offset + tileSize * (numThreads - 1) < (ub - lb)` |
| 336 | if (!canOmitTileOffsetInBoundsCheck(tileSize, numThreads: nt, iterationSize: loopRange.size)) { |
| 337 | AffineMap maxMap = |
| 338 | AffineMap::getMultiDimIdentityMap(numDims: 2, context: rewriter.getContext()); |
| 339 | size = affine::makeComposedFoldedAffineMax( |
| 340 | b&: rewriter, loc, map: maxMap, operands: {rewriter.getIndexAttr(value: 0), size}); |
| 341 | } |
| 342 | |
| 343 | offsets.push_back(Elt: offset); |
| 344 | sizes.push_back(Elt: size); |
| 345 | } |
| 346 | return {offsets, sizes}; |
| 347 | } else { |
| 348 | for (auto [tileSize, loopRange] : |
| 349 | llvm::zip_equal(t&: tileSizes, u&: iterationDomain)) { |
| 350 | |
| 351 | // Non-tiled cases, set the offset and size to the |
| 352 | // `loopRange.offset/size`. |
| 353 | if (isZeroInteger(v: tileSize)) { |
| 354 | offsets.push_back(Elt: loopRange.offset); |
| 355 | sizes.push_back(Elt: loopRange.size); |
| 356 | continue; |
| 357 | } |
| 358 | |
| 359 | Value iv = ivs[materializedLoopNum++]; |
| 360 | OpFoldResult offset = getAsOpFoldResult(val: iv); |
| 361 | offsets.push_back(Elt: offset); |
| 362 | OpFoldResult size = |
| 363 | getBoundedTileSize(b&: rewriter, loc, loopRange, offset, tileSize); |
| 364 | sizes.push_back(Elt: size); |
| 365 | } |
| 366 | return {offsets, sizes}; |
| 367 | } |
| 368 | } |
| 369 | |
| 370 | /// Function to return the bounds of the loops to be generated. |
| 371 | static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>, |
| 372 | SmallVector<OpFoldResult>> |
| 373 | getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, |
| 374 | ArrayRef<OpFoldResult> tileSizes) { |
| 375 | SmallVector<OpFoldResult> lbs, ubs, steps; |
| 376 | for (auto [loopRange, tileSize] : llvm::zip_equal(t&: loopRanges, u&: tileSizes)) { |
| 377 | // No loop if the tile size is 0. |
| 378 | if (isZeroInteger(v: tileSize)) |
| 379 | continue; |
| 380 | lbs.push_back(Elt: loopRange.offset); |
| 381 | ubs.push_back(Elt: loopRange.size); |
| 382 | steps.push_back(Elt: tileSize); |
| 383 | } |
| 384 | return {lbs, ubs, steps}; |
| 385 | } |
| 386 | |
| 387 | /// A function that allows returning additional yielded values during |
| 388 | /// `yieldTiledValuesAndReplace`. |
| 389 | /// - `ivs` induction variable for the loop. |
| 390 | /// - `newBbArgs` basic block arguments corresponding to newly added iter_args. |
| 391 | /// - `tiledValues` the tiled values to return. Must be of same size as |
| 392 | /// `newbbArgs`, each element of this array is inserted into the corresponding |
| 393 | /// element in `newbbArgs`. |
| 394 | /// - `resultOffsets` is of the same size as `tiledValues` and represents |
| 395 | /// the offsets to use when inserting corresponding element from `tiledValues` |
| 396 | /// into the element from `newBbArgs`. |
| 397 | /// - `resultSizes` is of the same size as `tiledValues` and represents |
| 398 | /// the size of the corresponding element from `tiledValues` inserted into |
| 399 | /// the element from `newBbArgs`. |
| 400 | /// In case the method needs to return `failure()` the method is expected |
| 401 | /// to clean up any inserted operations. |
| 402 | using YieldTiledValuesFn = std::function<LogicalResult( |
| 403 | RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, |
| 404 | SmallVector<Value> &tiledValues, |
| 405 | SmallVector<SmallVector<OpFoldResult>> &resultOffsets, |
| 406 | SmallVector<SmallVector<OpFoldResult>> &resultSizes)>; |
| 407 | |
| 408 | /// Clones the operation and updates the destination if the operation |
| 409 | /// implements the `DestinationStyleOpInterface`. |
| 410 | static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, |
| 411 | Operation *op, |
| 412 | ValueRange newDestArgs) { |
| 413 | Operation *clonedOp = rewriter.clone(op&: *op); |
| 414 | if (newDestArgs.empty()) |
| 415 | return clonedOp; |
| 416 | if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(Val: clonedOp)) |
| 417 | destinationStyleOp.getDpsInitsMutable().assign(values: newDestArgs); |
| 418 | return clonedOp; |
| 419 | } |
| 420 | |
| 421 | /// Generate the tile-loop nest using `scf.for` operation. |
| 422 | /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. |
| 423 | /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. |
| 424 | /// - `destinationTensors` are the init values to use for the outer most loop. |
| 425 | /// - `yieldTiledValuesFn` is called to generated the loop body of the inner |
| 426 | /// most |
| 427 | /// loop. |
| 428 | /// - `loops` is an in-out parameter into which the generated loops are |
| 429 | /// populated. |
| 430 | static LogicalResult generateLoopNestUsingForOp( |
| 431 | RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, |
| 432 | ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors, |
| 433 | YieldTiledValuesFn yieldTiledValuesFn, |
| 434 | SmallVector<LoopLikeOpInterface> &loops) { |
| 435 | assert(!loopRanges.empty() && "unexpected empty loop ranges" ); |
| 436 | assert(loopRanges.size() == tileSizes.size() && |
| 437 | "expected as many tile sizes as loop ranges" ); |
| 438 | OpBuilder::InsertionGuard guard(rewriter); |
| 439 | |
| 440 | SmallVector<OpFoldResult> lbs, ubs, steps; |
| 441 | std::tie(args&: lbs, args&: ubs, args&: steps) = |
| 442 | getLoopBounds(rewriter, loc, loopRanges, tileSizes); |
| 443 | SmallVector<Value> lbVals = |
| 444 | getValueOrCreateConstantIndexOp(b&: rewriter, loc, valueOrAttrVec: lbs); |
| 445 | SmallVector<Value> ubVals = |
| 446 | getValueOrCreateConstantIndexOp(b&: rewriter, loc, valueOrAttrVec: ubs); |
| 447 | SmallVector<Value> stepVals = |
| 448 | getValueOrCreateConstantIndexOp(b&: rewriter, loc, valueOrAttrVec: steps); |
| 449 | |
| 450 | SmallVector<Value> ivs; |
| 451 | for (auto [lb, ub, step] : llvm::zip_equal(t&: lbVals, u&: ubVals, args&: stepVals)) { |
| 452 | auto loop = |
| 453 | rewriter.create<scf::ForOp>(location: loc, args&: lb, args&: ub, args&: step, args&: destinationTensors, |
| 454 | args: [](OpBuilder &bodyBuilder, Location bodyLoc, |
| 455 | Value iv, ValueRange /*iterArgs*/) {}); |
| 456 | loops.push_back(Elt: loop); |
| 457 | ivs.push_back(Elt: loop.getInductionVar()); |
| 458 | rewriter.setInsertionPointToEnd(loop.getBody()); |
| 459 | destinationTensors = loop.getRegionIterArgs(); |
| 460 | } |
| 461 | |
| 462 | SmallVector<Value> tiledResults; |
| 463 | SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; |
| 464 | if (failed(Result: yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors, |
| 465 | tiledResults, resultOffsets, resultSizes))) { |
| 466 | return rewriter.notifyMatchFailure( |
| 467 | arg&: loc, msg: "failed to generate inner tile loop body" ); |
| 468 | } |
| 469 | if (loops.empty()) |
| 470 | return success(); |
| 471 | |
| 472 | assert(tiledResults.size() == destinationTensors.size() && |
| 473 | "Number of results of body should be equal to number of iter args" ); |
| 474 | |
| 475 | // 6. Yield all the results of the tiled operation. |
| 476 | SmallVector<Value> yieldedValues; |
| 477 | for (auto [tiledValue, destinationTensor, resultOffset, resultSize] : |
| 478 | llvm::zip_equal(t&: tiledResults, u&: destinationTensors, args&: resultOffsets, |
| 479 | args&: resultSizes)) { |
| 480 | SmallVector<OpFoldResult> resultStride(resultOffset.size(), |
| 481 | rewriter.getIndexAttr(value: 1)); |
| 482 | auto insertSlice = rewriter.create<tensor::InsertSliceOp>( |
| 483 | location: loc, args&: tiledValue, args&: destinationTensor, args&: resultOffset, args&: resultSize, |
| 484 | args&: resultStride); |
| 485 | yieldedValues.push_back(Elt: insertSlice); |
| 486 | } |
| 487 | rewriter.create<scf::YieldOp>(location: loc, args&: yieldedValues); |
| 488 | |
| 489 | // Add the scf.yield operations for all the outer loops. |
| 490 | for (auto [outerLoop, innerLoop] : |
| 491 | llvm::zip_equal(t: MutableArrayRef(loops).drop_back(), |
| 492 | u: MutableArrayRef(loops).drop_front())) { |
| 493 | rewriter.setInsertionPointToEnd( |
| 494 | cast<scf::ForOp>(Val: outerLoop.getOperation()).getBody()); |
| 495 | rewriter.create<scf::YieldOp>(location: outerLoop.getLoc(), args: innerLoop->getResults()); |
| 496 | } |
| 497 | return success(); |
| 498 | } |
| 499 | |
| 500 | /// Generate the tile-loop nest using `scf.forall` operation. |
| 501 | /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. |
| 502 | /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. |
| 503 | /// - `destinationTensors` are the init values to use for the outer most loop. |
| 504 | /// - `mappingVector` is the mapping attributes to use for loop construction. |
| 505 | /// Can be empty. |
| 506 | /// - `yieldTiledValuesFn` is called to generated the loop body of the inner |
| 507 | /// most |
| 508 | /// loop. |
| 509 | /// - `loops` is an in-out parameter into which the generated loops are |
| 510 | /// populated. |
| 511 | static LogicalResult generateLoopNestUsingForallOp( |
| 512 | RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, |
| 513 | ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads, |
| 514 | ArrayRef<Attribute> mappingVector, ValueRange destinationTensors, |
| 515 | YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) { |
| 516 | assert(!loopRanges.empty() && "unexpected empty loop ranges" ); |
| 517 | assert(loopRanges.size() == tileSizes.size() && |
| 518 | "expected as many tile sizes as loop ranges" ); |
| 519 | OpBuilder::InsertionGuard guard(rewriter); |
| 520 | |
| 521 | std::optional<ArrayAttr> mappingAttr; |
| 522 | if (!mappingVector.empty()) |
| 523 | mappingAttr = rewriter.getArrayAttr(value: mappingVector); |
| 524 | |
| 525 | scf::ForallOp forallOp; |
| 526 | bool useNumThreads = !numThreads.empty(); |
| 527 | |
| 528 | if (useNumThreads) { |
| 529 | // Prune the zero numthreads. |
| 530 | SmallVector<OpFoldResult> nonZeroNumThreads; |
| 531 | for (auto nt : numThreads) { |
| 532 | if (isZeroInteger(v: nt)) |
| 533 | continue; |
| 534 | nonZeroNumThreads.push_back(Elt: nt); |
| 535 | } |
| 536 | forallOp = rewriter.create<scf::ForallOp>(location: loc, args&: nonZeroNumThreads, |
| 537 | args&: destinationTensors, args&: mappingAttr); |
| 538 | } else { |
| 539 | SmallVector<OpFoldResult> lbs, ubs, steps; |
| 540 | std::tie(args&: lbs, args&: ubs, args&: steps) = |
| 541 | getLoopBounds(rewriter, loc, loopRanges, tileSizes); |
| 542 | forallOp = rewriter.create<scf::ForallOp>(location: loc, args&: lbs, args&: ubs, args&: steps, |
| 543 | args&: destinationTensors, args&: mappingAttr); |
| 544 | } |
| 545 | loops.push_back(Elt: forallOp); |
| 546 | |
| 547 | rewriter.setInsertionPoint(forallOp.getTerminator()); |
| 548 | destinationTensors = forallOp.getRegionOutArgs(); |
| 549 | |
| 550 | SmallVector<Value> tiledResults; |
| 551 | SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; |
| 552 | if (failed(Result: tiledBodyFn(rewriter, loc, forallOp.getInductionVars(), |
| 553 | destinationTensors, tiledResults, resultOffsets, |
| 554 | resultSizes))) |
| 555 | return rewriter.notifyMatchFailure(arg&: loc, msg: "failed to generate loop body" ); |
| 556 | |
| 557 | rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody()); |
| 558 | for (auto [tiledValue, destinationTensor, resultOffset, resultSize] : |
| 559 | llvm::zip_equal(t&: tiledResults, u&: destinationTensors, args&: resultOffsets, |
| 560 | args&: resultSizes)) { |
| 561 | SmallVector<OpFoldResult> resultStride(resultOffset.size(), |
| 562 | rewriter.getIndexAttr(value: 1)); |
| 563 | |
| 564 | rewriter.create<tensor::ParallelInsertSliceOp>( |
| 565 | location: loc, args&: tiledValue, args&: destinationTensor, args&: resultOffset, args&: resultSize, |
| 566 | args&: resultStride); |
| 567 | } |
| 568 | return success(); |
| 569 | } |
| 570 | |
| 571 | /// Generate the tile-loop nest using the loop construct specifed in `options`. |
| 572 | /// - `options`: Tiling options specified. |
| 573 | /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. |
| 574 | /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. |
| 575 | /// - `destinationTensors` are the init values to use for the outer most loop. |
| 576 | /// - `yieldTiledValuesFn` is called to generated the loop body of the inner |
| 577 | /// most |
| 578 | /// loop. |
| 579 | /// - `loops` is an in-out parameter into which the generated loops are |
| 580 | /// populated. |
| 581 | static LogicalResult generateLoopNest( |
| 582 | RewriterBase &rewriter, Location loc, |
| 583 | scf::SCFTilingOptions::LoopType loopType, ArrayRef<Range> loopRanges, |
| 584 | ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads, |
| 585 | ValueRange destinationTensors, ArrayRef<Attribute> mappingVector, |
| 586 | YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) { |
| 587 | // If the tile sizes are all zero, no loops are generated. Just call the |
| 588 | // callback function to handle untiled case. |
| 589 | if (llvm::all_of(Range&: tileSizes, P: isZeroInteger)) { |
| 590 | SmallVector<Value> tiledResults; |
| 591 | SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; |
| 592 | return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors, |
| 593 | tiledResults, resultOffsets, resultSizes); |
| 594 | } |
| 595 | if (loopType == scf::SCFTilingOptions::LoopType::ForOp) { |
| 596 | return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes, |
| 597 | destinationTensors, yieldTiledValuesFn: tiledBodyFn, loops); |
| 598 | } |
| 599 | if (loopType == scf::SCFTilingOptions::LoopType::ForallOp) { |
| 600 | return generateLoopNestUsingForallOp( |
| 601 | rewriter, loc, loopRanges, tileSizes, numThreads, mappingVector, |
| 602 | destinationTensors, tiledBodyFn, loops); |
| 603 | } |
| 604 | return rewriter.notifyMatchFailure(arg&: loc, msg: "unhandled loop type" ); |
| 605 | } |
| 606 | |
| 607 | static FailureOr<SmallVector<Value>> createInitialTensorsForTiling( |
| 608 | RewriterBase &rewriter, TilingInterface op, |
| 609 | ReductionTilingStrategy reductionStrategy, ArrayRef<Range> iterationDomain, |
| 610 | ArrayRef<OpFoldResult> numThreads, ArrayRef<OpFoldResult> tileSizes, |
| 611 | const SetVector<unsigned> &reductionDims) { |
| 612 | SmallVector<Value> initTensors; |
| 613 | Location loc = op->getLoc(); |
| 614 | if (reductionStrategy == ReductionTilingStrategy::FullReduction) { |
| 615 | if (failed(Result: tensor::getOrCreateDestinations(b&: rewriter, loc, op, result&: initTensors))) |
| 616 | return failure(); |
| 617 | return initTensors; |
| 618 | } |
| 619 | |
| 620 | auto redOp = dyn_cast<PartialReductionOpInterface>(Val: op.getOperation()); |
| 621 | if (!redOp) { |
| 622 | return op->emitOpError( |
| 623 | message: "PartialReductionOuterReduction tiling strategy is only supported for " |
| 624 | "operations implementing PartialReductionOpInterface" ); |
| 625 | } |
| 626 | SmallVector<OpFoldResult> sizes(iterationDomain.size()); |
| 627 | AffineExpr s0, s1, s2; |
| 628 | bindSymbols(ctx: rewriter.getContext(), exprs&: s0, exprs&: s1, exprs&: s2); |
| 629 | AffineExpr sizeExpr = ((s0 - s1).ceilDiv(other: s2)); |
| 630 | AffineExpr divExpr = s0.ceilDiv(other: s1); |
| 631 | for (auto [index, domain, tileSize] : |
| 632 | llvm::enumerate(First&: iterationDomain, Rest&: tileSizes)) { |
| 633 | if (!numThreads.empty()) { |
| 634 | // Untiled case. |
| 635 | if (isConstantIntValue(ofr: numThreads[index], value: 0)) { |
| 636 | sizes[index] = affine::makeComposedFoldedAffineApply( |
| 637 | b&: rewriter, loc: op.getLoc(), expr: sizeExpr, |
| 638 | operands: {domain.size, domain.offset, domain.stride}); |
| 639 | continue; |
| 640 | } |
| 641 | sizes[index] = numThreads[index]; |
| 642 | continue; |
| 643 | } |
| 644 | |
| 645 | // Non reduction dimensions/non-tiled dimensions. |
| 646 | if (!reductionDims.contains(key: index) || isConstantIntValue(ofr: tileSize, value: 0)) { |
| 647 | sizes[index] = affine::makeComposedFoldedAffineApply( |
| 648 | b&: rewriter, loc: op.getLoc(), expr: sizeExpr, |
| 649 | operands: {domain.size, domain.offset, domain.stride}); |
| 650 | continue; |
| 651 | } |
| 652 | |
| 653 | if (reductionStrategy == |
| 654 | ReductionTilingStrategy::PartialReductionOuterReduction) { |
| 655 | sizes[index] = tileSize; |
| 656 | continue; |
| 657 | } |
| 658 | |
| 659 | assert(reductionStrategy == |
| 660 | ReductionTilingStrategy::PartialReductionOuterParallel); |
| 661 | OpFoldResult normalizedRange = affine::makeComposedFoldedAffineApply( |
| 662 | b&: rewriter, loc: op.getLoc(), expr: sizeExpr, |
| 663 | operands: {domain.size, domain.offset, domain.stride}); |
| 664 | sizes[index] = affine::makeComposedFoldedAffineApply( |
| 665 | b&: rewriter, loc: op.getLoc(), expr: divExpr, operands: {normalizedRange, tileSize}); |
| 666 | } |
| 667 | return redOp.generateInitialTensorForPartialReduction(b&: rewriter, loc, tileSizes: sizes, |
| 668 | reductionDims); |
| 669 | } |
| 670 | |
| 671 | /// For the case of `ReductionTilingStrategy::PartialReductionOuterParallel` |
| 672 | /// the `PartialReductionOpInterface` methods need the index of the parallel |
| 673 | /// split reduction being executed. |
| 674 | static SmallVector<OpFoldResult> |
| 675 | getSplitReductionIvs(RewriterBase &rewriter, Location loc, |
| 676 | ReductionTilingStrategy reductionStrategy, ValueRange ivs, |
| 677 | ArrayRef<OpFoldResult> numThreads, |
| 678 | ArrayRef<OpFoldResult> tileSizes, |
| 679 | const SetVector<unsigned> &reductionDims) { |
| 680 | SmallVector<OpFoldResult> splitReductionIvs; |
| 681 | splitReductionIvs.resize(N: reductionDims.size(), NV: rewriter.getIndexAttr(value: 0)); |
| 682 | AffineExpr s0, s1; |
| 683 | bindSymbols(ctx: rewriter.getContext(), exprs&: s0, exprs&: s1); |
| 684 | AffineExpr divExpr = s0.floorDiv(other: s1); |
| 685 | int ivIndex = 0; |
| 686 | if (reductionStrategy == |
| 687 | ReductionTilingStrategy::PartialReductionOuterParallel) { |
| 688 | for (auto [index, reductionDim] : llvm::enumerate(First: reductionDims)) { |
| 689 | if (!numThreads.empty()) { |
| 690 | splitReductionIvs[index] = ivs[ivIndex++]; |
| 691 | continue; |
| 692 | } |
| 693 | splitReductionIvs[index] = affine::makeComposedFoldedAffineApply( |
| 694 | b&: rewriter, loc, expr: divExpr, |
| 695 | operands: ArrayRef<OpFoldResult>{ivs[ivIndex++], tileSizes[reductionDim]}); |
| 696 | } |
| 697 | } |
| 698 | return splitReductionIvs; |
| 699 | } |
| 700 | |
| 701 | static FailureOr<TilingResult> |
| 702 | getTiledImplementation(RewriterBase &rewriter, TilingInterface op, |
| 703 | ReductionTilingStrategy reductionStrategy, |
| 704 | ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets, |
| 705 | ArrayRef<OpFoldResult> sizes, ValueRange ivs, |
| 706 | ArrayRef<OpFoldResult> numThreads, |
| 707 | ArrayRef<OpFoldResult> tileSizes, |
| 708 | const SetVector<unsigned> &reductionDims) { |
| 709 | if (reductionStrategy == ReductionTilingStrategy::FullReduction) { |
| 710 | return op.getTiledImplementation(b&: rewriter, offsets, sizes); |
| 711 | } |
| 712 | |
| 713 | auto redOp = dyn_cast<PartialReductionOpInterface>(Val: op.getOperation()); |
| 714 | if (!redOp) { |
| 715 | return rewriter.notifyMatchFailure( |
| 716 | arg&: op, msg: "PartialReductionOuterReduction tiling strategy is only " |
| 717 | "supported for operations " |
| 718 | "implementing PartialReductionOpInterface" ); |
| 719 | } |
| 720 | |
| 721 | SmallVector<OpFoldResult> splitReductionIvs = |
| 722 | getSplitReductionIvs(rewriter, loc: op.getLoc(), reductionStrategy, ivs, |
| 723 | numThreads, tileSizes, reductionDims); |
| 724 | return redOp.tileToPartialReduction(b&: rewriter, loc: op.getLoc(), tilingStrategy: reductionStrategy, |
| 725 | init: regionIterArg, offsets, sizes, |
| 726 | reductionDims, splitReductionIvs); |
| 727 | } |
| 728 | |
| 729 | static LogicalResult getResultTilePosition( |
| 730 | RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, |
| 731 | int64_t index, Value tiledResult, TilingInterface op, |
| 732 | ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, |
| 733 | ValueRange ivs, ArrayRef<OpFoldResult> numThreads, |
| 734 | ArrayRef<OpFoldResult> tileSizes, const SetVector<unsigned> &reductionDims, |
| 735 | SmallVector<OpFoldResult> &resultOffset, |
| 736 | SmallVector<OpFoldResult> &resultSize) { |
| 737 | |
| 738 | if (reductionStrategy == ReductionTilingStrategy::FullReduction) { |
| 739 | return op.getResultTilePosition(b&: rewriter, resultNumber: index, offsets, sizes, |
| 740 | resultOffsets&: resultOffset, resultSizes&: resultSize); |
| 741 | } |
| 742 | auto redOp = dyn_cast<PartialReductionOpInterface>(Val: op.getOperation()); |
| 743 | if (!redOp) { |
| 744 | return rewriter.notifyMatchFailure( |
| 745 | arg&: op, msg: "PartialReductionOuterReduction tiling strategy is only supported" |
| 746 | "for operations implementing PartialReductionOpInterface" ); |
| 747 | } |
| 748 | SmallVector<OpFoldResult> splitReductionIvs = |
| 749 | getSplitReductionIvs(rewriter, loc: op.getLoc(), reductionStrategy, ivs, |
| 750 | numThreads, tileSizes, reductionDims); |
| 751 | return redOp.getPartialResultTilePosition( |
| 752 | b&: rewriter, resultNumber: index, tilingStrategy: reductionStrategy, offsets, sizes, reductionDims, |
| 753 | splitReductionIvs, resultOffsets&: resultOffset, resultSizes&: resultSize); |
| 754 | } |
| 755 | |
| 756 | static FailureOr<MergeResult> |
| 757 | mergeTilingResults(RewriterBase &rewriter, TilingInterface op, |
| 758 | ReductionTilingStrategy reductionStrategy, |
| 759 | const SetVector<unsigned> &reductionDims, |
| 760 | ValueRange partialResults) { |
| 761 | assert(reductionStrategy != ReductionTilingStrategy::FullReduction && |
| 762 | "expected merge to be called for only partial reduction cases" ); |
| 763 | |
| 764 | auto redOp = dyn_cast<PartialReductionOpInterface>(Val: op.getOperation()); |
| 765 | if (!redOp) { |
| 766 | return rewriter.notifyMatchFailure( |
| 767 | arg&: op, msg: "PartialReductionOuterReduction tiling strategy is only " |
| 768 | "supported for operations " |
| 769 | "implementing PartialReductionOpInterface" ); |
| 770 | } |
| 771 | return redOp.mergeReductions(b&: rewriter, loc: op.getLoc(), partialReduce: partialResults, |
| 772 | reductionDims); |
| 773 | } |
| 774 | |
| 775 | /// Append the specified additional `newInitOperands` operands to the |
| 776 | /// loops existing `init` operands (or similar), and replace `loopOp` with |
| 777 | /// the new loop that has the additional init operands. The loop body of |
| 778 | /// this loop is moved over to the new loop. `yieldTiledValuesFn` |
| 779 | /// is called to get the new tiled values returned, and the offset |
| 780 | /// and sizes at which the tiled value is inserted into the |
| 781 | /// new region iter_args that correspond to the newly added init operands. |
| 782 | template <typename LoopType> |
| 783 | FailureOr<LoopLikeOpInterface> |
| 784 | yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, |
| 785 | ValueRange newInitOperands, |
| 786 | YieldTiledValuesFn yieldTiledValuesFn) { |
| 787 | return rewriter.notifyMatchFailure(loopOp, "unhandled loop type" ); |
| 788 | } |
| 789 | |
| 790 | /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`. |
| 791 | template <> |
| 792 | FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>( |
| 793 | scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, |
| 794 | YieldTiledValuesFn yieldTiledValuesFn) { |
| 795 | OpBuilder::InsertionGuard g(rewriter); |
| 796 | Location loc = loopOp.getLoc(); |
| 797 | rewriter.setInsertionPoint(loopOp); |
| 798 | |
| 799 | auto inits = llvm::to_vector(Range: loopOp.getInitArgs()); |
| 800 | inits.append(in_start: newInitOperands.begin(), in_end: newInitOperands.end()); |
| 801 | auto newLoop = rewriter.create<scf::ForOp>( |
| 802 | location: loc, args: loopOp.getLowerBound(), args: loopOp.getUpperBound(), args: loopOp.getStep(), |
| 803 | args&: inits, args: [](OpBuilder &, Location, Value, ValueRange) {}); |
| 804 | |
| 805 | // Move the loop body to the new op. |
| 806 | Block *loopBody = loopOp.getBody(); |
| 807 | Block *newLoopBody = newLoop.getBody(); |
| 808 | rewriter.mergeBlocks( |
| 809 | source: loopBody, dest: newLoopBody, |
| 810 | argValues: newLoopBody->getArguments().take_front(N: loopBody->getNumArguments())); |
| 811 | |
| 812 | auto yieldOp = cast<scf::YieldOp>(Val: newLoopBody->getTerminator()); |
| 813 | rewriter.setInsertionPoint(yieldOp); |
| 814 | |
| 815 | SmallVector<Value> tiledValues; |
| 816 | SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; |
| 817 | ValueRange newRegionIterArgs = |
| 818 | newLoop.getRegionIterArgs().take_back(N: newInitOperands.size()); |
| 819 | if (failed(Result: yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(), |
| 820 | newRegionIterArgs, tiledValues, resultOffsets, |
| 821 | resultSizes))) { |
| 822 | rewriter.eraseOp(op: newLoop); |
| 823 | return rewriter.notifyMatchFailure(arg&: loopOp, msg: "failed to get tiled values" ); |
| 824 | } |
| 825 | |
| 826 | SmallVector<Value> newYieldValues = llvm::to_vector(Range: yieldOp.getOperands()); |
| 827 | for (auto [tiledValue, regionIterArg, resultOffset, resultSize] : |
| 828 | llvm::zip_equal(t&: tiledValues, u&: newRegionIterArgs, args&: resultOffsets, |
| 829 | args&: resultSizes)) { |
| 830 | SmallVector<OpFoldResult> resultStride(resultOffset.size(), |
| 831 | rewriter.getIndexAttr(value: 1)); |
| 832 | Value insert = rewriter.create<tensor::InsertSliceOp>( |
| 833 | location: yieldOp->getLoc(), args&: tiledValue, args&: regionIterArg, args&: resultOffset, args&: resultSize, |
| 834 | args&: resultStride); |
| 835 | newYieldValues.push_back(Elt: insert); |
| 836 | } |
| 837 | |
| 838 | rewriter.replaceOpWithNewOp<scf::YieldOp>(op: yieldOp, args&: newYieldValues); |
| 839 | rewriter.replaceOp(op: loopOp, |
| 840 | newValues: newLoop->getResults().take_front(n: loopOp.getNumResults())); |
| 841 | return cast<LoopLikeOpInterface>(Val: newLoop.getOperation()); |
| 842 | } |
| 843 | |
| 844 | /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall` |
| 845 | template <> |
| 846 | FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>( |
| 847 | scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, |
| 848 | YieldTiledValuesFn yieldTiledValuesFn) { |
| 849 | OpBuilder::InsertionGuard g(rewriter); |
| 850 | Location loc = loopOp.getLoc(); |
| 851 | rewriter.setInsertionPoint(loopOp); |
| 852 | auto inits = llvm::to_vector(Range: loopOp.getOutputs()); |
| 853 | inits.append(in_start: newInitOperands.begin(), in_end: newInitOperands.end()); |
| 854 | auto newLoop = rewriter.create<scf::ForallOp>( |
| 855 | location: loc, args: loopOp.getMixedLowerBound(), args: loopOp.getMixedUpperBound(), |
| 856 | args: loopOp.getMixedStep(), args&: inits, args: loopOp.getMapping(), |
| 857 | args: [](OpBuilder &, Location, ValueRange) {}); |
| 858 | |
| 859 | // Move the region of the current block to the newly created op. |
| 860 | Block *loopBody = loopOp.getBody(); |
| 861 | Block *newLoopBody = newLoop.getBody(); |
| 862 | rewriter.mergeBlocks( |
| 863 | source: loopBody, dest: newLoopBody, |
| 864 | argValues: newLoopBody->getArguments().take_front(N: loopBody->getNumArguments())); |
| 865 | |
| 866 | auto terminator = cast<scf::InParallelOp>(Val: newLoopBody->getTerminator()); |
| 867 | rewriter.setInsertionPoint(terminator); |
| 868 | SmallVector<Value> tiledValues; |
| 869 | SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes; |
| 870 | ValueRange regionIterArgs = |
| 871 | newLoop.getRegionIterArgs().take_back(N: newInitOperands.size()); |
| 872 | if (failed(Result: yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(), |
| 873 | regionIterArgs, tiledValues, resultOffsets, |
| 874 | resultSizes))) { |
| 875 | rewriter.eraseOp(op: newLoop); |
| 876 | return rewriter.notifyMatchFailure(arg&: loopOp, |
| 877 | msg: "failed to get yielded tiled values" ); |
| 878 | } |
| 879 | |
| 880 | // Update the terminator. |
| 881 | rewriter.setInsertionPointToEnd(terminator.getBody()); |
| 882 | |
| 883 | for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal( |
| 884 | t&: tiledValues, u&: regionIterArgs, args&: resultOffsets, args&: resultSizes)) { |
| 885 | SmallVector<OpFoldResult> resultStride(resultOffset.size(), |
| 886 | rewriter.getIndexAttr(value: 1)); |
| 887 | rewriter.create<tensor::ParallelInsertSliceOp>( |
| 888 | location: terminator.getLoc(), args&: tiledValue, args&: iterArg, args&: resultOffset, args&: resultSize, |
| 889 | args&: resultStride); |
| 890 | } |
| 891 | |
| 892 | rewriter.replaceOp(op: loopOp, |
| 893 | newValues: newLoop->getResults().take_front(n: loopOp.getNumResults())); |
| 894 | return cast<LoopLikeOpInterface>(Val: newLoop.getOperation()); |
| 895 | } |
| 896 | |
| 897 | /// Implementation of `yieldTiledValuesAndReplaceLoop` for |
| 898 | /// `LoopLikeOpInterface`, that just dispatches to the implementation for each |
| 899 | /// supported loop type. |
| 900 | FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop( |
| 901 | LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter, |
| 902 | ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) { |
| 903 | return TypeSwitch<Operation *, FailureOr<LoopLikeOpInterface>>( |
| 904 | loopLikeOp.getOperation()) |
| 905 | .Case<scf::ForOp, scf::ForallOp>( |
| 906 | caseFn: [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> { |
| 907 | return yieldTiledValuesAndReplaceLoop( |
| 908 | loopOp, rewriter, newInitOperands, yieldTiledValuesFn); |
| 909 | }) |
| 910 | .Default(defaultFn: [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> { |
| 911 | return rewriter.notifyMatchFailure(loopOp, "unhandled loop type" ); |
| 912 | }); |
| 913 | } |
| 914 | |
| 915 | /// Method to add new init values to a loop nest. Updates `loops` in-place |
| 916 | /// with new loops that use the `newInitValues`. The outer-loops are updated |
| 917 | /// to yield the new result values of the inner loop. For the innermost loop, |
| 918 | /// the call back `getNewYields` is invoked to get the additional values to |
| 919 | /// yield form the innermost loop. |
| 920 | static LogicalResult addInitOperandsToLoopNest( |
| 921 | RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops, |
| 922 | ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) { |
| 923 | if (loops.empty()) |
| 924 | return success(); |
| 925 | OpBuilder::InsertionGuard g(rewriter); |
| 926 | rewriter.setInsertionPoint(loops.front()); |
| 927 | |
| 928 | SmallVector<Value> ivs; |
| 929 | for (auto &loop : loops.drop_back()) { |
| 930 | rewriter.setInsertionPoint(loop); |
| 931 | |
| 932 | // if loops.size() > 1 we assume that scf.for is used for the loops. |
| 933 | auto forLoop = cast<scf::ForOp>(Val: loop.getOperation()); |
| 934 | |
| 935 | // Create a new loop with the new init values for this loop. |
| 936 | SmallVector<Value> newInits = llvm::to_vector(Range: forLoop.getInitArgs()); |
| 937 | newInits.append(in_start: newInitValues.begin(), in_end: newInitValues.end()); |
| 938 | auto newLoop = rewriter.create<scf::ForOp>( |
| 939 | location: forLoop.getLoc(), args: forLoop.getLowerBound(), args: forLoop.getUpperBound(), |
| 940 | args: forLoop.getStep(), args&: newInits, |
| 941 | args: [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {}); |
| 942 | |
| 943 | // Merge the body of the new loop with the body of the old loops. |
| 944 | SmallVector<Value> sourceBlockArgs; |
| 945 | sourceBlockArgs.push_back(Elt: newLoop.getInductionVar()); |
| 946 | auto newRegionIterArgs = newLoop.getRegionIterArgs(); |
| 947 | sourceBlockArgs.append( |
| 948 | in_start: newRegionIterArgs.begin(), |
| 949 | in_end: std::next(x: newRegionIterArgs.begin(), n: forLoop.getNumResults())); |
| 950 | rewriter.mergeBlocks(source: forLoop.getBody(), dest: newLoop.getBody(), argValues: sourceBlockArgs); |
| 951 | rewriter.replaceOp( |
| 952 | op: forLoop, newValues: newLoop.getResults().take_front(n: forLoop.getNumResults())); |
| 953 | loop = newLoop; |
| 954 | ivs.push_back(Elt: newLoop.getInductionVar()); |
| 955 | newInitValues = newLoop.getRegionIterArgs().take_back(N: newInitValues.size()); |
| 956 | } |
| 957 | |
| 958 | // Update the loop body of the innermost loop to get new yield values. |
| 959 | LoopLikeOpInterface innerMostLoop = loops.back(); |
| 960 | FailureOr<LoopLikeOpInterface> newInnerMostLoop = |
| 961 | yieldTiledValuesAndReplaceLoop(loopLikeOp: innerMostLoop, rewriter, newInitOperands: newInitValues, |
| 962 | yieldTiledValuesFn: getNewTiledYieldsFn); |
| 963 | |
| 964 | if (failed(Result: newInnerMostLoop)) |
| 965 | return innerMostLoop.emitOpError(message: "failed to return additional yields" ); |
| 966 | loops.back() = newInnerMostLoop.value(); |
| 967 | |
| 968 | // Make all other loops except the innermost loops yield the values returned |
| 969 | // by the inner loop. |
| 970 | for (auto [outerLoop, innerLoop] : |
| 971 | llvm::zip_equal(t: loops.drop_back(), u: loops.drop_front())) { |
| 972 | // Again assume that all the outer loops are scf.for operations. |
| 973 | auto outerForLoop = cast<scf::ForOp>(Val&: outerLoop); |
| 974 | auto outerLoopYield = |
| 975 | cast<scf::YieldOp>(Val: outerForLoop.getBody()->getTerminator()); |
| 976 | SmallVector<Value> newYields = |
| 977 | llvm::to_vector(Range: outerLoopYield.getOperands()); |
| 978 | ValueRange additionalYields = |
| 979 | innerLoop->getResults().take_back(n: newInitValues.size()); |
| 980 | newYields.append(in_start: additionalYields.begin(), in_end: additionalYields.end()); |
| 981 | rewriter.setInsertionPoint(outerLoopYield); |
| 982 | rewriter.replaceOpWithNewOp<scf::YieldOp>(op: outerLoopYield, args&: newYields); |
| 983 | } |
| 984 | return success(); |
| 985 | } |
| 986 | |
| 987 | /// Implementation of tiling transformation of `op` that implements the |
| 988 | /// `TilingInterface` using `scf.for` to iterate over the tiles. |
| 989 | FailureOr<scf::SCFTilingResult> |
| 990 | mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, |
| 991 | const scf::SCFTilingOptions &options) { |
| 992 | if (failed(Result: verifyOptions(rewriter, loc: op.getLoc(), options))) { |
| 993 | return failure(); |
| 994 | } |
| 995 | |
| 996 | OpBuilder::InsertionGuard guard(rewriter); |
| 997 | rewriter.setInsertionPointAfter(op); |
| 998 | |
| 999 | // 1. Get the range of the loops that are represented by the operation. |
| 1000 | SmallVector<Range> iterationDomain = op.getIterationDomain(b&: rewriter); |
| 1001 | |
| 1002 | // 2. Materialize the tile sizes and/or number of threads; |
| 1003 | SmallVector<OpFoldResult> tileSizes, numThreads; |
| 1004 | std::tie(args&: tileSizes, args&: numThreads) = |
| 1005 | getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options); |
| 1006 | |
| 1007 | // Check if it is safe to tile. This is hold over from previous iterations |
| 1008 | // of tile to for-all. Consider dropping it. |
| 1009 | if (failed(Result: checkTileSizes(op, loopType: options.loopType, reductionStrategy: options.reductionStrategy, |
| 1010 | tileSizes, numThreads))) { |
| 1011 | return failure(); |
| 1012 | } |
| 1013 | |
| 1014 | // Get the reduction dims |
| 1015 | SetVector<unsigned> reductionDims = |
| 1016 | getSanitizedReductionDims(tileSizes, options); |
| 1017 | |
| 1018 | // 3. If there is an interchange specified, permute the iteration domain and |
| 1019 | // the tile sizes. |
| 1020 | SmallVector<int64_t> interchangeVector; |
| 1021 | if (!options.interchangeVector.empty()) { |
| 1022 | interchangeVector = fillInterchangeVector(interchangeVector: options.interchangeVector, |
| 1023 | iterationDomainSize: iterationDomain.size()); |
| 1024 | assert(isPermutationVector(interchangeVector) && |
| 1025 | "expected interchange vector to be a permutation" ); |
| 1026 | |
| 1027 | applyPermutationToVector(inVec&: iterationDomain, permutation: interchangeVector); |
| 1028 | applyPermutationToVector(inVec&: tileSizes, permutation: interchangeVector); |
| 1029 | if (!numThreads.empty()) |
| 1030 | applyPermutationToVector(inVec&: numThreads, permutation: interchangeVector); |
| 1031 | } |
| 1032 | |
| 1033 | FailureOr<TilingResult> tilingResult; |
| 1034 | // 4. Define the lambda function used later to generate the body of the |
| 1035 | // innermost tiled loop. |
| 1036 | YieldTiledValuesFn innerYieldTiledValuesFn = |
| 1037 | [&](RewriterBase &rewriter, Location loc, ValueRange ivs, |
| 1038 | ValueRange regionIterArgs, SmallVector<Value> &tiledResults, |
| 1039 | SmallVector<SmallVector<OpFoldResult>> &resultOffsets, |
| 1040 | SmallVector<SmallVector<OpFoldResult>> &resultSizes) |
| 1041 | -> LogicalResult { |
| 1042 | // 4a. Compute the `offsets` and `sizes` to use for tiling. |
| 1043 | SmallVector<OpFoldResult> offsets, sizes; |
| 1044 | std::tie(args&: offsets, args&: sizes) = getTileOffsetAndSizes( |
| 1045 | rewriter, loc, strategy: options.reductionStrategy, ivs, iterationDomain, |
| 1046 | tileSizes, numThreads, reductionDims); |
| 1047 | |
| 1048 | // 4b. If interchange was provided, apply inverse of the interchange |
| 1049 | // to get back the offsets/sizes in the order to be specified. |
| 1050 | if (!interchangeVector.empty()) { |
| 1051 | auto inversePermutation = invertPermutationVector(permutation: interchangeVector); |
| 1052 | applyPermutationToVector(inVec&: offsets, permutation: inversePermutation); |
| 1053 | applyPermutationToVector(inVec&: sizes, permutation: inversePermutation); |
| 1054 | } |
| 1055 | |
| 1056 | // 5. Generate the tiled implementation within the inner most loop. |
| 1057 | |
| 1058 | // 5a. Clone the operation within the loop body. |
| 1059 | auto clonedOp = cast<TilingInterface>( |
| 1060 | Val: cloneOpAndUpdateDestinationArgs(rewriter, op, newDestArgs: regionIterArgs)); |
| 1061 | |
| 1062 | // 5b. Early return cloned op if tiling is not happening. We can not |
| 1063 | // return the original op because it could lead to `rewriter.replaceOp(op, |
| 1064 | // op->getResults())` and users would get crash. |
| 1065 | if (llvm::all_of(Range&: tileSizes, P: isZeroInteger)) { |
| 1066 | tiledResults.append(in_start: clonedOp->result_begin(), in_end: clonedOp->result_end()); |
| 1067 | tilingResult = |
| 1068 | TilingResult{/*tiledOps=*/{clonedOp}, .tiledValues: clonedOp->getResults(), |
| 1069 | /*generatedSlices=*/{}}; |
| 1070 | return success(); |
| 1071 | } |
| 1072 | |
| 1073 | // 5c. Tile the cloned operation. |
| 1074 | tilingResult = getTiledImplementation( |
| 1075 | rewriter, op: clonedOp, reductionStrategy: options.reductionStrategy, regionIterArg: regionIterArgs, offsets, |
| 1076 | sizes, ivs, numThreads, tileSizes, reductionDims); |
| 1077 | if (failed(Result: tilingResult)) { |
| 1078 | rewriter.eraseOp(op: clonedOp); |
| 1079 | return op.emitOpError(message: "faild to tile operation" ); |
| 1080 | } |
| 1081 | |
| 1082 | // 5d. Delete the cloned operation. |
| 1083 | rewriter.eraseOp(op: clonedOp); |
| 1084 | |
| 1085 | // 5e. Compute the offsets at which the result values are to be inserted |
| 1086 | // back into its destinations. |
| 1087 | for (auto [index, tiledValue] : |
| 1088 | llvm::enumerate(First&: tilingResult->tiledValues)) { |
| 1089 | tiledResults.push_back(Elt: tiledValue); |
| 1090 | SmallVector<OpFoldResult> resultOffset, resultSize; |
| 1091 | if (failed(Result: getResultTilePosition( |
| 1092 | rewriter, reductionStrategy: options.reductionStrategy, index, tiledResult: tiledValue, op, |
| 1093 | offsets, sizes, ivs, numThreads, tileSizes, reductionDims, |
| 1094 | resultOffset, resultSize))) { |
| 1095 | for (auto op : tilingResult->tiledOps) { |
| 1096 | rewriter.eraseOp(op); |
| 1097 | } |
| 1098 | return rewriter.notifyMatchFailure( |
| 1099 | arg&: op, msg: "failed to get slice of result produced" ); |
| 1100 | } |
| 1101 | resultOffsets.emplace_back(Args: std::move(resultOffset)); |
| 1102 | resultSizes.emplace_back(Args: std::move(resultSize)); |
| 1103 | } |
| 1104 | |
| 1105 | return success(); |
| 1106 | }; |
| 1107 | |
| 1108 | // 6. Find the destination tensors to use for the operation. |
| 1109 | FailureOr<SmallVector<Value>> maybeInits = createInitialTensorsForTiling( |
| 1110 | rewriter, op, reductionStrategy: options.reductionStrategy, iterationDomain, numThreads, |
| 1111 | tileSizes, reductionDims); |
| 1112 | if (failed(Result: maybeInits)) { |
| 1113 | return rewriter.notifyMatchFailure( |
| 1114 | arg&: op, msg: "unable to create initial tensors for tiling" ); |
| 1115 | } |
| 1116 | SmallVector<Value> &initTensors = maybeInits.value(); |
| 1117 | |
| 1118 | // 7. Generate the tiled loops nest using the callback defined above. |
| 1119 | SmallVector<LoopLikeOpInterface> loops; |
| 1120 | if (failed(Result: generateLoopNest(rewriter, loc: op.getLoc(), loopType: options.loopType, |
| 1121 | loopRanges: iterationDomain, tileSizes, numThreads, |
| 1122 | destinationTensors: initTensors, mappingVector: options.mappingVector, |
| 1123 | tiledBodyFn: innerYieldTiledValuesFn, loops))) |
| 1124 | return op.emitOpError(message: "failed to generate tiling loops" ); |
| 1125 | assert(succeeded(tilingResult) && |
| 1126 | "expected tiling result to be computed after loop generation" ); |
| 1127 | |
| 1128 | if (loops.empty()) { |
| 1129 | // If loops are empty, the tiled op is used as the replacement for the |
| 1130 | // untiled op. |
| 1131 | return scf::SCFTilingResult{.tiledOps: tilingResult->tiledOps, |
| 1132 | .initialValues: initTensors, |
| 1133 | .loops: loops, |
| 1134 | .replacements: tilingResult->tiledValues, |
| 1135 | .generatedSlices: tilingResult->generatedSlices, |
| 1136 | .mergeOps: {}}; |
| 1137 | } |
| 1138 | |
| 1139 | auto loopResults = llvm::map_to_vector(C: loops.front()->getResults(), |
| 1140 | F: [](OpResult r) -> Value { return r; }); |
| 1141 | |
| 1142 | // For the full reduction case, there is nothing more to do. |
| 1143 | if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) { |
| 1144 | return scf::SCFTilingResult{ |
| 1145 | .tiledOps: tilingResult->tiledOps, .initialValues: initTensors, .loops: loops, .replacements: loopResults, |
| 1146 | .generatedSlices: tilingResult->generatedSlices, .mergeOps: {}}; |
| 1147 | } |
| 1148 | |
| 1149 | // The results of the loop needs to be merged. |
| 1150 | FailureOr<MergeResult> mergeResult = mergeTilingResults( |
| 1151 | rewriter, op, reductionStrategy: options.reductionStrategy, reductionDims, partialResults: loopResults); |
| 1152 | if (failed(Result: mergeResult)) { |
| 1153 | return rewriter.notifyMatchFailure( |
| 1154 | arg&: op, msg: "Failed to merge partial results from tiling" ); |
| 1155 | } |
| 1156 | return scf::SCFTilingResult{.tiledOps: tilingResult->tiledOps, |
| 1157 | .initialValues: initTensors, |
| 1158 | .loops: loops, |
| 1159 | .replacements: mergeResult->replacements, |
| 1160 | .generatedSlices: tilingResult->generatedSlices, |
| 1161 | .mergeOps: mergeResult->mergeOps}; |
| 1162 | } |
| 1163 | |
| 1164 | FailureOr<scf::SCFTilingResult> |
| 1165 | mlir::scf::tileReductionUsingScf(RewriterBase &b, |
| 1166 | PartialReductionOpInterface op, |
| 1167 | ArrayRef<OpFoldResult> tileSize) { |
| 1168 | scf::SCFTilingOptions options; |
| 1169 | options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp); |
| 1170 | options.setReductionTilingStrategy( |
| 1171 | ReductionTilingStrategy::PartialReductionOuterReduction); |
| 1172 | options.setTileSizes(tileSize); |
| 1173 | SmallVector<unsigned> reductionDims; |
| 1174 | for (auto [index, iteratorType] : llvm::enumerate(First: op.getLoopIteratorTypes())) |
| 1175 | if (iteratorType == utils::IteratorType::reduction) |
| 1176 | reductionDims.push_back(Elt: index); |
| 1177 | options.setReductionDims(reductionDims); |
| 1178 | return tileUsingSCF(rewriter&: b, op, options); |
| 1179 | } |
| 1180 | |
| 1181 | //===----------------------------------------------------------------------===// |
| 1182 | // tileConsumerAndFuseProducersUsingSCF implementation. |
| 1183 | //===----------------------------------------------------------------------===// |
| 1184 | |
| 1185 | /// Return the untiled producer whose slice is used in a tiled consumer. The |
| 1186 | /// method traverses the tile loop nest (`loops`) if needed, and returns the |
| 1187 | /// `iter_args` of the outer most that is encountered. Traversing the |
| 1188 | /// iter_args indicates that this is a destination operand of the consumer. If |
| 1189 | /// there was no loop traversal needed, the second value of the returned tuple |
| 1190 | /// is empty. |
| 1191 | static std::tuple<OpResult, std::optional<OpOperand *>> |
| 1192 | getUntiledProducerFromSliceSource(OpOperand *source, |
| 1193 | ArrayRef<LoopLikeOpInterface> loops) { |
| 1194 | std::optional<OpOperand *> destinationIterArg; |
| 1195 | assert(!loops.empty() && "expected non empty loops container" ); |
| 1196 | auto loopIt = loops.rbegin(); |
| 1197 | while (loopIt != loops.rend() && isa<BlockArgument>(Val: source->get())) { |
| 1198 | auto iterArg = cast<BlockArgument>(Val: source->get()); |
| 1199 | auto loop = *loopIt; |
| 1200 | if (iterArg.getOwner()->getParentOp() != loop) |
| 1201 | break; |
| 1202 | source = loop.getTiedLoopInit(bbArg: iterArg); |
| 1203 | loopIt++; |
| 1204 | } |
| 1205 | if (loopIt == loops.rend()) |
| 1206 | destinationIterArg = source; |
| 1207 | return {dyn_cast<OpResult>(Val: source->get()), destinationIterArg}; |
| 1208 | } |
| 1209 | |
| 1210 | /// Implementation of fusing producer of a single slice by computing the |
| 1211 | /// slice of the producer in-place. |
| 1212 | std::optional<scf::SCFFuseProducerOfSliceResult> |
| 1213 | mlir::scf::tileAndFuseProducerOfSlice( |
| 1214 | RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, |
| 1215 | MutableArrayRef<LoopLikeOpInterface> loops) { |
| 1216 | // 1. Get the producer of the source (potentially walking through |
| 1217 | // `iter_args` of nested `scf.for`) |
| 1218 | auto [fusableProducer, destinationInitArg] = |
| 1219 | getUntiledProducerFromSliceSource(source: &candidateSliceOp.getSourceMutable(), |
| 1220 | loops); |
| 1221 | if (!fusableProducer) |
| 1222 | return std::nullopt; |
| 1223 | unsigned resultNumber = fusableProducer.getResultNumber(); |
| 1224 | |
| 1225 | OpBuilder::InsertionGuard g(rewriter); |
| 1226 | rewriter.setInsertionPoint(candidateSliceOp); |
| 1227 | |
| 1228 | // 2. Clone the fused producer |
| 1229 | // 2a. Compute the destination operands to use for the cloned operation. |
| 1230 | SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors; |
| 1231 | Operation *fusableProducerOp = fusableProducer.getOwner(); |
| 1232 | if (isa<DestinationStyleOpInterface>(Val: fusableProducerOp) && |
| 1233 | failed(Result: tensor::getOrCreateDestinations( |
| 1234 | b&: rewriter, loc: fusableProducerOp->getLoc(), op: fusableProducerOp, |
| 1235 | result&: origDestinationTensors))) |
| 1236 | return std::nullopt; |
| 1237 | |
| 1238 | clonedOpDestinationTensors = origDestinationTensors; |
| 1239 | if (destinationInitArg && |
| 1240 | isa<DestinationStyleOpInterface>(Val: fusableProducerOp)) { |
| 1241 | // 2b. If the producer is also destination style, then to maintain the |
| 1242 | // destination passing style, update the destination of the producer to be |
| 1243 | // the source of the slice. |
| 1244 | clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource(); |
| 1245 | } |
| 1246 | // 2c. Clone the fused producer. |
| 1247 | Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs( |
| 1248 | rewriter, op: fusableProducerOp, newDestArgs: clonedOpDestinationTensors); |
| 1249 | // 2d. Update the source of the candidateSlice to be the cloned producer. |
| 1250 | // Easier to just clone the slice with different source since |
| 1251 | // replacements and DCE of cloned ops becomes easier |
| 1252 | SmallVector<Value> candidateSliceOpOperands = |
| 1253 | llvm::to_vector(Range: candidateSliceOp->getOperands()); |
| 1254 | candidateSliceOpOperands[0] = clonedProducerOp->getResult(idx: resultNumber); |
| 1255 | tensor::ExtractSliceOp clonedCandidateSliceOp = |
| 1256 | mlir::clone(b&: rewriter, op: candidateSliceOp, |
| 1257 | newResultTypes: candidateSliceOp->getResultTypes(), newOperands: candidateSliceOpOperands); |
| 1258 | |
| 1259 | // 3. Generate the tiled implementation of the producer of the source |
| 1260 | FailureOr<TilingResult> tileAndFuseResult = |
| 1261 | tensor::replaceExtractSliceWithTiledProducer( |
| 1262 | builder&: rewriter, sliceOp: clonedCandidateSliceOp, |
| 1263 | producerOp: clonedProducerOp->getResult(idx: resultNumber)); |
| 1264 | if (failed(Result: tileAndFuseResult)) |
| 1265 | return std::nullopt; |
| 1266 | // Note: Do not delete the candidateSliceOp, since its passed in from the |
| 1267 | // caller. |
| 1268 | rewriter.replaceAllUsesWith(from: candidateSliceOp, |
| 1269 | to: tileAndFuseResult->tiledValues[0]); |
| 1270 | rewriter.eraseOp(op: clonedCandidateSliceOp); |
| 1271 | rewriter.eraseOp(op: clonedProducerOp); |
| 1272 | |
| 1273 | // 3. If the slice is for a destination operand, for example, |
| 1274 | // |
| 1275 | // ```mlir |
| 1276 | // %0 = linalg.init |
| 1277 | // %1 = linalg.fill .. outs(%0 : ) |
| 1278 | // %2 = scf.for .. iter_args(%arg0 = %1) { |
| 1279 | // %3 = scf.for .. iter_args(%arg1 = %arg0) { |
| 1280 | // %4 = tensor.extract_slice %arg1 [..] |
| 1281 | // .. = linalg.matmul .. outs(%4 : ) |
| 1282 | // } |
| 1283 | // } |
| 1284 | // ``` |
| 1285 | // |
| 1286 | // the IR is currently |
| 1287 | // |
| 1288 | // ``` |
| 1289 | // %0 = linalg.init |
| 1290 | // %1 = linalg.fill |
| 1291 | // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) { |
| 1292 | // %3 = scf.for .. iter_args(%arg1 = %arg0) { |
| 1293 | // %4 = tensor.extract_slice %arg1[..] |
| 1294 | // %5 = linalg.fill .. outs(%4 : ) |
| 1295 | // .. = linalg.matmul .. outs(%5 : ) |
| 1296 | // } |
| 1297 | // } |
| 1298 | // ``` |
| 1299 | // |
| 1300 | // The untiled `linalg.fill` is still used as the `init_value` since it |
| 1301 | // was originally a destination operand of the untiled `linalg.matmul`. |
| 1302 | // When fusing an operand that is a destination operand, the iter_arg of |
| 1303 | // the outer most loop should be changed to use the destination of the |
| 1304 | // fused operation. With this the IR will be. |
| 1305 | // |
| 1306 | // ``` |
| 1307 | // %0 = linalg.init |
| 1308 | // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) { |
| 1309 | // %2 = scf.for .. iter_args(%arg1 = %arg0) { |
| 1310 | // %3 = tensor.extract_slice %arg1[..] |
| 1311 | // %4 = linalg.fill .. outs(%3 : ) |
| 1312 | // .. = linalg.matmul .. outs(%4 : ) |
| 1313 | // } |
| 1314 | // } |
| 1315 | // ``` |
| 1316 | if (destinationInitArg && |
| 1317 | isa<DestinationStyleOpInterface>(Val: fusableProducerOp) && !loops.empty()) { |
| 1318 | loops.front() |
| 1319 | ->getOpOperands()[destinationInitArg.value()->getOperandNumber()] |
| 1320 | .set(origDestinationTensors[resultNumber]); |
| 1321 | } |
| 1322 | return scf::SCFFuseProducerOfSliceResult{ |
| 1323 | .origProducer: fusableProducer, .tiledAndFusedProducer: tileAndFuseResult->tiledValues[0], |
| 1324 | .tiledOps: tileAndFuseResult->tiledOps, .generatedSlices: tileAndFuseResult->generatedSlices}; |
| 1325 | } |
| 1326 | |
| 1327 | /// Reconstruct the fused producer from within the tiled-and-fused code. |
| 1328 | FailureOr<SmallVector<Operation *>> mlir::scf::( |
| 1329 | RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, |
| 1330 | scf::SCFFuseProducerOfSliceResult fusedProducerInfo, |
| 1331 | MutableArrayRef<LoopLikeOpInterface> loops, |
| 1332 | ArrayRef<unsigned> yieldResultNumber) { |
| 1333 | if (loops.empty()) |
| 1334 | return success(); |
| 1335 | |
| 1336 | Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(), |
| 1337 | *tiledOwner = fusedProducerInfo.tiledOps[0]; |
| 1338 | |
| 1339 | Location loc = originalOwner->getLoc(); |
| 1340 | // a. collect all init Value to be appended |
| 1341 | SmallVector<unsigned> initNumberList = |
| 1342 | yieldResultNumber.empty() ? llvm::to_vector(Range: llvm::seq<unsigned>( |
| 1343 | Begin: 0, End: originalOwner->getNumResults())) |
| 1344 | : llvm::to_vector(Range&: yieldResultNumber); |
| 1345 | SmallVector<Value> initValueList; |
| 1346 | for (const auto &resultNumber : initNumberList) { |
| 1347 | FailureOr<Value> initValue = tensor::getOrCreateDestination( |
| 1348 | b&: rewriter, loc, opResult: originalOwner->getResult(idx: resultNumber)); |
| 1349 | if (succeeded(Result: initValue)) { |
| 1350 | initValueList.push_back(Elt: initValue.value()); |
| 1351 | } else { |
| 1352 | return failure(); |
| 1353 | } |
| 1354 | } |
| 1355 | |
| 1356 | SmallVector<Operation *> generatedSlices; |
| 1357 | YieldTiledValuesFn newYieldValuesFn = |
| 1358 | [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/, |
| 1359 | ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult, |
| 1360 | SmallVector<SmallVector<OpFoldResult>> &tiledOffset, |
| 1361 | SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult { |
| 1362 | OpBuilder::InsertionGuard g(innerRewriter); |
| 1363 | |
| 1364 | // get sliceOp tile information |
| 1365 | SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(), |
| 1366 | sliceSizes = sliceOp.getMixedSizes(); |
| 1367 | |
| 1368 | // expect all strides of sliceOp being 1 |
| 1369 | if (!llvm::all_of(Range: sliceOp.getMixedStrides(), P: isOneInteger)) |
| 1370 | return failure(); |
| 1371 | |
| 1372 | unsigned sliceResultNumber = |
| 1373 | fusedProducerInfo.origProducer.getResultNumber(); |
| 1374 | |
| 1375 | auto tilableOp = cast<TilingInterface>(Val: originalOwner); |
| 1376 | // b. get iterDomain Offset and Sizes based on sliceOp tile |
| 1377 | SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes; |
| 1378 | // skip tensor.pack/unpack/pad, which expects single opResult |
| 1379 | if (tilableOp->getNumResults() > 1 && |
| 1380 | failed(Result: tilableOp.getIterationDomainTileFromResultTile( |
| 1381 | b&: rewriter, resultNumber: sliceResultNumber, offsets: sliceOffset, sizes: sliceSizes, |
| 1382 | iterDomainOffsets&: iterDomainOffset, iterDomainSizes))) { |
| 1383 | // In theory, it is unnecessary to raise an error here. Actually |
| 1384 | // although it fails to reconstruct the result tensor, it should not |
| 1385 | // broke current fusion anyway. The reason why we must return failure |
| 1386 | // currently is that the callback function `newYieldValuesFn` will be |
| 1387 | // called after new init operand(s) has already been appended. It will |
| 1388 | // take more refactoring to make sure the init operands are added |
| 1389 | // consistently in the future. For more details, please refer to: |
| 1390 | // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814 |
| 1391 | return failure(); |
| 1392 | } |
| 1393 | |
| 1394 | // c. calculate offsets and sizes info of all OpResults respectively based |
| 1395 | // on iteration Domain Tile |
| 1396 | SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList; |
| 1397 | for (const auto &resultNumber : initNumberList) { |
| 1398 | if (resultNumber == sliceResultNumber) { |
| 1399 | offsetList.push_back(Elt: sliceOffset); |
| 1400 | sizesList.push_back(Elt: sliceSizes); |
| 1401 | } else { |
| 1402 | assert(!iterDomainOffset.empty() && !iterDomainSizes.empty()); |
| 1403 | // infer result tile according to the iteration domain tile |
| 1404 | SmallVector<OpFoldResult> offset, sizes; |
| 1405 | if (failed(Result: tilableOp.getResultTilePosition( |
| 1406 | b&: rewriter, resultNumber, offsets: iterDomainOffset, sizes: iterDomainSizes, |
| 1407 | resultOffsets&: offset, resultSizes&: sizes))) { |
| 1408 | return failure(); |
| 1409 | } |
| 1410 | offsetList.push_back(Elt: offset); |
| 1411 | sizesList.push_back(Elt: sizes); |
| 1412 | } |
| 1413 | } |
| 1414 | |
| 1415 | // d. create `extract_slice` for `iter_args` for DPS operation if |
| 1416 | // necessary |
| 1417 | if (auto tiledDestStyleOp = |
| 1418 | dyn_cast<DestinationStyleOpInterface>(Val: tiledOwner)) { |
| 1419 | rewriter.setInsertionPoint(tiledDestStyleOp); |
| 1420 | for (const auto &&[index, newRegionArg] : |
| 1421 | llvm::enumerate(First&: newRegionIterArgs)) { |
| 1422 | auto destSlice = rewriter.create<tensor::ExtractSliceOp>( |
| 1423 | location: loc, args&: newRegionArg, args&: offsetList[index], args&: sizesList[index], |
| 1424 | args: SmallVector<OpFoldResult>(offsetList[index].size(), |
| 1425 | rewriter.getIndexAttr(value: 1))); |
| 1426 | generatedSlices.push_back(Elt: destSlice); |
| 1427 | unsigned resultNumber = initNumberList[index]; |
| 1428 | rewriter.modifyOpInPlace(root: tiledDestStyleOp, callable: [&]() { |
| 1429 | tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice); |
| 1430 | }); |
| 1431 | } |
| 1432 | } |
| 1433 | |
| 1434 | // e. prepare tiled offset and sizes for later `insert_slice` creation by |
| 1435 | // caller |
| 1436 | Block *block = rewriter.getInsertionPoint()->getBlock(); |
| 1437 | rewriter.setInsertionPoint(block->getTerminator()); |
| 1438 | for (const auto &&[index, resultNumber] : llvm::enumerate(First&: initNumberList)) { |
| 1439 | tiledResult.push_back(Elt: tiledOwner->getResult(idx: resultNumber)); |
| 1440 | tiledOffset.emplace_back(Args&: offsetList[index]); |
| 1441 | tiledSizes.emplace_back(Args&: sizesList[index]); |
| 1442 | } |
| 1443 | return success(); |
| 1444 | }; |
| 1445 | |
| 1446 | if (failed(Result: addInitOperandsToLoopNest(rewriter, loops, newInitValues: initValueList, |
| 1447 | getNewTiledYieldsFn: newYieldValuesFn))) { |
| 1448 | return failure(); |
| 1449 | } |
| 1450 | return generatedSlices; |
| 1451 | } |
| 1452 | |
| 1453 | namespace { |
| 1454 | |
| 1455 | //===----------------------------------------------------------------------===// |
| 1456 | // SliceTrackingListener |
| 1457 | //===----------------------------------------------------------------------===// |
| 1458 | |
| 1459 | /// This class is a listener for tracking the insertion and removal of |
| 1460 | /// `tensor.extract_slice` ops in a worklist. This can be used in a greedy |
| 1461 | /// fusion algorithm to apply cleanup patterns in between fusion steps. |
| 1462 | class SliceTrackingListener : public RewriterBase::Listener { |
| 1463 | public: |
| 1464 | explicit SliceTrackingListener( |
| 1465 | std::optional<FrozenRewritePatternSet> patterns); |
| 1466 | SliceTrackingListener() = default; |
| 1467 | |
| 1468 | /// Adds the given list of operations to the worklist, and if present, |
| 1469 | /// applies the list of `patterns` to the newly added operations. This only |
| 1470 | /// processes the given operations and any newly inserted ones by the |
| 1471 | /// pattern set. |
| 1472 | LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps); |
| 1473 | |
| 1474 | /// Add to the new operation worklist if it is an extract_slice. |
| 1475 | void notifyOperationInserted(Operation *op, |
| 1476 | OpBuilder::InsertPoint previous) override; |
| 1477 | |
| 1478 | /// Shared helper for operation removal from the worklist. |
| 1479 | void removeOp(Operation *op); |
| 1480 | |
| 1481 | /// Remove the operation from the worklist. |
| 1482 | void notifyOperationErased(Operation *op) override; |
| 1483 | |
| 1484 | /// Remove the operation from the worklist. |
| 1485 | void notifyOperationReplaced(Operation *op, ValueRange replacement) override; |
| 1486 | |
| 1487 | /// The worklist for this transformation keeps track of the slices to visit |
| 1488 | /// next for fusion. |
| 1489 | std::deque<tensor::ExtractSliceOp> worklist; |
| 1490 | |
| 1491 | private: |
| 1492 | /// Optional pattern set to apply when adding new operations to the |
| 1493 | /// worklist. |
| 1494 | std::optional<FrozenRewritePatternSet> patterns = std::nullopt; |
| 1495 | }; |
| 1496 | |
| 1497 | SliceTrackingListener::SliceTrackingListener( |
| 1498 | std::optional<FrozenRewritePatternSet> p) { |
| 1499 | patterns = std::move(p); |
| 1500 | } |
| 1501 | |
| 1502 | LogicalResult |
| 1503 | SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) { |
| 1504 | for (Operation *op : ops) { |
| 1505 | if (auto slice = dyn_cast<tensor::ExtractSliceOp>(Val: op)) |
| 1506 | worklist.push_back(x: slice); |
| 1507 | } |
| 1508 | |
| 1509 | if (!patterns) |
| 1510 | return success(); |
| 1511 | |
| 1512 | return applyOpPatternsGreedily( |
| 1513 | ops, patterns: patterns.value(), |
| 1514 | config: GreedyRewriteConfig().setListener(this).setStrictness( |
| 1515 | GreedyRewriteStrictness::ExistingAndNewOps)); |
| 1516 | } |
| 1517 | |
| 1518 | void SliceTrackingListener::notifyOperationInserted( |
| 1519 | Operation *op, OpBuilder::InsertPoint previous) { |
| 1520 | auto slice = dyn_cast<tensor::ExtractSliceOp>(Val: op); |
| 1521 | if (!slice) |
| 1522 | return; |
| 1523 | worklist.push_back(x: slice); |
| 1524 | } |
| 1525 | |
| 1526 | // Scan the worklist for the given op and remove it if present. The |
| 1527 | // expectation is for the worklist to be small and for removal to be |
| 1528 | // relatively rare. |
| 1529 | void SliceTrackingListener::removeOp(Operation *op) { |
| 1530 | if (!isa<tensor::ExtractSliceOp>(Val: op)) |
| 1531 | return; |
| 1532 | auto iter = worklist.begin(); |
| 1533 | while (iter != worklist.end()) { |
| 1534 | if (*iter == op) |
| 1535 | break; |
| 1536 | iter++; |
| 1537 | } |
| 1538 | if (iter == worklist.end()) |
| 1539 | return; |
| 1540 | |
| 1541 | worklist.erase(position: iter); |
| 1542 | } |
| 1543 | |
| 1544 | void SliceTrackingListener::notifyOperationErased(Operation *op) { |
| 1545 | removeOp(op); |
| 1546 | } |
| 1547 | |
| 1548 | void SliceTrackingListener::notifyOperationReplaced(Operation *op, |
| 1549 | ValueRange replacement) { |
| 1550 | removeOp(op); |
| 1551 | } |
| 1552 | |
| 1553 | //===----------------------------------------------------------------------===// |
| 1554 | // ReplacementListener |
| 1555 | //===----------------------------------------------------------------------===// |
| 1556 | |
| 1557 | /// Listener that tracks updates replacements for values which can be mutated. |
| 1558 | /// This listener runs on top of the existing listener for the rewriter, |
| 1559 | /// to make sure external users can still run listeners. |
| 1560 | class ReplacementListener : public RewriterBase::ForwardingListener { |
| 1561 | public: |
| 1562 | ReplacementListener(DenseMap<Value, Value> &replacements, |
| 1563 | OpBuilder::Listener *listener) |
| 1564 | : ForwardingListener(listener), replacements(replacements) {} |
| 1565 | |
| 1566 | void updateReplacementValues(ValueRange origValues, |
| 1567 | ValueRange replaceValues) { |
| 1568 | // This can probably be written better, but just iterates over the map |
| 1569 | // and the new replacements for now. |
| 1570 | for (auto &[key, val] : replacements) { |
| 1571 | for (auto [orig, replace] : llvm::zip_equal(t&: origValues, u&: replaceValues)) { |
| 1572 | if (val == orig) { |
| 1573 | val = replace; |
| 1574 | } |
| 1575 | } |
| 1576 | } |
| 1577 | } |
| 1578 | |
| 1579 | void notifyOperationReplaced(Operation *op, Operation *newOp) override { |
| 1580 | ForwardingListener::notifyOperationReplaced(op, newOp); |
| 1581 | updateReplacementValues(origValues: op->getResults(), replaceValues: newOp->getResults()); |
| 1582 | } |
| 1583 | |
| 1584 | void notifyOperationReplaced(Operation *op, ValueRange values) override { |
| 1585 | ForwardingListener::notifyOperationReplaced(op, replacement: values); |
| 1586 | updateReplacementValues(origValues: op->getResults(), replaceValues: values); |
| 1587 | } |
| 1588 | |
| 1589 | private: |
| 1590 | DenseMap<Value, Value> &replacements; |
| 1591 | }; |
| 1592 | |
| 1593 | } // namespace |
| 1594 | |
| 1595 | /// Implementation of tile consumer and fuse producer greedily. |
| 1596 | FailureOr<scf::SCFTileAndFuseResult> |
| 1597 | mlir::scf::tileConsumerAndFuseProducersUsingSCF( |
| 1598 | RewriterBase &rewriter, TilingInterface consumer, |
| 1599 | const scf::SCFTileAndFuseOptions &options) { |
| 1600 | // This transformation is only valid for ops that return values (i.e. not |
| 1601 | // valid to use with operations that have memref operands). |
| 1602 | if (!consumer->getNumResults()) { |
| 1603 | return rewriter.notifyMatchFailure( |
| 1604 | arg&: consumer, msg: "invalid pattern for op with no results" ); |
| 1605 | } |
| 1606 | |
| 1607 | // 1. First tile the consumer. |
| 1608 | SetVector<Operation *> fusedProducers, tiledAndFusedOps; |
| 1609 | |
| 1610 | FailureOr<scf::SCFTilingResult> tilingResult = |
| 1611 | tileUsingSCF(rewriter, op: consumer, options: options.tilingOptions); |
| 1612 | |
| 1613 | if (failed(Result: tilingResult)) |
| 1614 | return rewriter.notifyMatchFailure(arg&: consumer, msg: "failed to tile consumer" ); |
| 1615 | tiledAndFusedOps.insert_range(R&: tilingResult->tiledOps); |
| 1616 | |
| 1617 | DenseMap<Value, Value> replacements; |
| 1618 | for (auto [origVal, replacement] : |
| 1619 | llvm::zip_equal(t: consumer->getResults(), u&: tilingResult->replacements)) { |
| 1620 | replacements[origVal] = replacement; |
| 1621 | } |
| 1622 | |
| 1623 | // If there are no loops generated, fusion is immaterial. |
| 1624 | auto &loops = tilingResult->loops; |
| 1625 | if (loops.empty()) { |
| 1626 | return scf::SCFTileAndFuseResult{.fusedProducers: fusedProducers, .tiledAndFusedOps: tiledAndFusedOps, .loops: loops, |
| 1627 | .replacements: replacements}; |
| 1628 | } |
| 1629 | |
| 1630 | // Since the loop gets potentially replaced during fusion, we need to track |
| 1631 | // the mutation of replacement values. To do this, we attach a listener to |
| 1632 | // update the replacements as they happen. |
| 1633 | OpBuilder::Listener *previousListener = rewriter.getListener(); |
| 1634 | auto resetListener = |
| 1635 | llvm::make_scope_exit(F: [&]() { rewriter.setListener(previousListener); }); |
| 1636 | ReplacementListener replaceListener(replacements, previousListener); |
| 1637 | rewriter.setListener(&replaceListener); |
| 1638 | |
| 1639 | // 2. Typically, the operands of the tiled operation are slices of the |
| 1640 | // operands of the untiled operation. These are expressed in IR using |
| 1641 | // `tensor.extract_slice` operations with source being the operands of |
| 1642 | // the untiled operation. Create a worklist of these |
| 1643 | // `tensor.extract_slice` operations. If the producers of the source of |
| 1644 | // the `tensor.extract_slice` can be tiled such that the tiled value is |
| 1645 | // generated in-place, that effectively tiles + fuses the operations. |
| 1646 | struct WorklistItem { |
| 1647 | tensor::ExtractSliceOp candidateSlice; |
| 1648 | SCFTileAndFuseOptions::ControlFnResult controlFnResult; |
| 1649 | }; |
| 1650 | |
| 1651 | SliceTrackingListener sliceTracker = |
| 1652 | SliceTrackingListener(options.cleanupPatterns); |
| 1653 | |
| 1654 | if (failed( |
| 1655 | Result: sliceTracker.insertAndApplyPatterns(ops: tilingResult->generatedSlices))) { |
| 1656 | return rewriter.notifyMatchFailure(arg&: consumer, msg: "cleanup patterns failed" ); |
| 1657 | } |
| 1658 | OpBuilder::InsertionGuard g(rewriter); |
| 1659 | while (!sliceTracker.worklist.empty()) { |
| 1660 | auto candidateSlice = sliceTracker.worklist.front(); |
| 1661 | sliceTracker.worklist.pop_front(); |
| 1662 | |
| 1663 | auto [fusableProducer, destinationInitArg] = |
| 1664 | getUntiledProducerFromSliceSource(source: &candidateSlice.getSourceMutable(), |
| 1665 | loops); |
| 1666 | if (!fusableProducer) |
| 1667 | continue; |
| 1668 | |
| 1669 | std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult = |
| 1670 | options.fusionControlFn(candidateSlice, fusableProducer, |
| 1671 | destinationInitArg.has_value()); |
| 1672 | if (!controlFnResult) |
| 1673 | continue; |
| 1674 | |
| 1675 | WorklistItem worklistItem = {.candidateSlice: candidateSlice, .controlFnResult: controlFnResult.value()}; |
| 1676 | |
| 1677 | // The operands of the fused producer might themselved be slices of |
| 1678 | // values produced by operations that implement the `TilingInterface`. |
| 1679 | // Add these operations to the worklist. |
| 1680 | std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult = |
| 1681 | tileAndFuseProducerOfSlice(rewriter, candidateSliceOp: worklistItem.candidateSlice, |
| 1682 | loops); |
| 1683 | if (!fusedResult) |
| 1684 | continue; |
| 1685 | |
| 1686 | SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices; |
| 1687 | |
| 1688 | if (worklistItem.controlFnResult.yieldProducerReplacement) { |
| 1689 | // Reconstruct and yield all opResult of fusableProducerOp by default. |
| 1690 | // The caller can specific which one to yield by designating optional |
| 1691 | // argument named `yieldResultNumber` of |
| 1692 | // `yieldReplacementForFusedProducer`. |
| 1693 | Operation *fusableProducerOp = fusedResult->origProducer.getOwner(); |
| 1694 | FailureOr<SmallVector<Operation *>> newSlices = |
| 1695 | yieldReplacementForFusedProducer(rewriter, |
| 1696 | sliceOp: worklistItem.candidateSlice, |
| 1697 | fusedProducerInfo: fusedResult.value(), loops); |
| 1698 | if (failed(Result: newSlices)) { |
| 1699 | return rewriter.notifyMatchFailure( |
| 1700 | arg&: fusableProducerOp, msg: "failed to replacement value for this " |
| 1701 | "operation from within the tiled loop" ); |
| 1702 | } |
| 1703 | worklistCandidates.append(RHS: newSlices.value()); |
| 1704 | for (auto [index, result] : |
| 1705 | llvm::enumerate(First: fusableProducerOp->getResults())) { |
| 1706 | replacements[result] = loops.front()->getResult( |
| 1707 | idx: loops.front()->getNumResults() - |
| 1708 | fusableProducerOp->getNumResults() + index); |
| 1709 | } |
| 1710 | } |
| 1711 | if (Operation *tiledAndFusedOp = |
| 1712 | fusedResult->tiledAndFusedProducer.getDefiningOp()) { |
| 1713 | fusedProducers.insert(X: fusedResult->origProducer.getDefiningOp()); |
| 1714 | tiledAndFusedOps.insert(X: tiledAndFusedOp); |
| 1715 | } |
| 1716 | |
| 1717 | if (failed(Result: sliceTracker.insertAndApplyPatterns(ops: worklistCandidates))) { |
| 1718 | return rewriter.notifyMatchFailure(arg&: consumer, msg: "cleanup patterns failed" ); |
| 1719 | } |
| 1720 | } |
| 1721 | |
| 1722 | return scf::SCFTileAndFuseResult{.fusedProducers: fusedProducers, .tiledAndFusedOps: tiledAndFusedOps, .loops: loops, |
| 1723 | .replacements: replacements}; |
| 1724 | } |
| 1725 | |
| 1726 | //===----------------------------------------------------------------------===// |
| 1727 | // tileAndFuseConsumerUsingSCF implementation. |
| 1728 | //===----------------------------------------------------------------------===// |
| 1729 | |
| 1730 | /// A utility function that checks whether the only use of the result of a |
| 1731 | /// tensor.insert_slice op is in a scf.yield op. |
| 1732 | static LogicalResult |
| 1733 | checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) { |
| 1734 | Value result = candidateSliceOp.getResult(); |
| 1735 | Value::use_range uses = result.getUses(); |
| 1736 | if (!llvm::hasSingleElement(C&: uses)) { |
| 1737 | LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n" ); |
| 1738 | return failure(); |
| 1739 | } |
| 1740 | OpOperand &operandUse = (*uses.begin()); |
| 1741 | Operation *userOp = operandUse.getOwner(); |
| 1742 | if (!isa<scf::YieldOp>(Val: userOp)) { |
| 1743 | LLVM_DEBUG(llvm::dbgs() |
| 1744 | << "Expected scf.yield to be the only user, but got -> " |
| 1745 | << (*userOp)); |
| 1746 | return failure(); |
| 1747 | } |
| 1748 | if (result.getDefiningOp()->getBlock() != userOp->getBlock()) { |
| 1749 | LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to " |
| 1750 | "be in the same block\n" ); |
| 1751 | return failure(); |
| 1752 | } |
| 1753 | return success(); |
| 1754 | } |
| 1755 | |
| 1756 | /// An utility to get the first user of the given loopOp. If any of user stay |
| 1757 | /// in different block of loopOp, return failure. |
| 1758 | static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) { |
| 1759 | if (!isa<LoopLikeOpInterface>(Val: loopOp)) |
| 1760 | return failure(); |
| 1761 | Operation *firstUserOfLoop = nullptr; |
| 1762 | for (Operation *userOp : loopOp->getUsers()) { |
| 1763 | // `ParallelInsertSlice` located inside `InParallelOp` has no same parent |
| 1764 | // block with any other types of operation. Thus, just redirecting to its |
| 1765 | // parent `InParallelOp`. E.g. |
| 1766 | // |
| 1767 | // ``` |
| 1768 | // %1 = scf.for { |
| 1769 | // ... |
| 1770 | // } |
| 1771 | // %2 = consumerOp ins(%1, ...) |
| 1772 | // scf.forall.in_parallel { |
| 1773 | // tensor.parallel_insert_slice %1 |
| 1774 | // } |
| 1775 | // ``` |
| 1776 | // where `InParallelOp` but not `ParallelInsertSlice` stays in the same |
| 1777 | // same block with `consumerOp`. |
| 1778 | if (isa<tensor::ParallelInsertSliceOp>(Val: userOp)) |
| 1779 | userOp = userOp->getParentOfType<scf::InParallelOp>(); |
| 1780 | |
| 1781 | if (loopOp->getBlock() != userOp->getBlock()) |
| 1782 | return failure(); |
| 1783 | |
| 1784 | if (!firstUserOfLoop || userOp->isBeforeInBlock(other: firstUserOfLoop)) |
| 1785 | firstUserOfLoop = userOp; |
| 1786 | } |
| 1787 | return firstUserOfLoop; |
| 1788 | } |
| 1789 | |
| 1790 | /// This utility currently checks whether the first userOp of loop is NOT |
| 1791 | /// before the last defineOp of consumer operand. Because that we need to move |
| 1792 | /// the whole loop structure right before the `firstUserOfLoop`. This utility |
| 1793 | /// thus helps ensuring that no invalid IR is formed, i.e. no backward slice |
| 1794 | /// of consumerOp is dominated by the `firstUserOfLoop`. Saying that: |
| 1795 | /// |
| 1796 | /// ``` |
| 1797 | /// %0 = scf.for() { |
| 1798 | /// ... |
| 1799 | /// } |
| 1800 | /// ... |
| 1801 | /// %1 = firstUserOfLoop(%0) |
| 1802 | /// ... |
| 1803 | /// %2 = lastDefOfConsumerOperand |
| 1804 | /// ... |
| 1805 | /// %3 = consumerOp(%2) |
| 1806 | /// ``` |
| 1807 | /// |
| 1808 | /// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it |
| 1809 | /// would be invalid to move the `loopOp` right before the `firstUserOfLoop`, |
| 1810 | /// a.k.a. use-def chain violation: |
| 1811 | /// |
| 1812 | /// ``` |
| 1813 | /// %0:2 = scf.for() { |
| 1814 | /// // use before define error |
| 1815 | /// %3 = tiledConsumerOp(%2) |
| 1816 | /// } |
| 1817 | /// %1 = firstUserOfLoop(%0) |
| 1818 | /// ... |
| 1819 | /// %2 = lastDefOfConsumerOperand |
| 1820 | /// ``` |
| 1821 | /// |
| 1822 | /// @param loopOp: loop operation |
| 1823 | /// @param consumerOp: consumer operation |
| 1824 | /// @param reorderOperations: the flag controls whether to reorder the |
| 1825 | /// backward slice w.r.t. the defineOp of `consumerOp` operands. |
| 1826 | /// @return: computed backward slice of consumerOp, but excluding those |
| 1827 | /// already dominates `firstUserOfLoop`. |
| 1828 | static FailureOr<llvm::SetVector<Operation *>> |
| 1829 | checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp, |
| 1830 | bool reorderOperations) { |
| 1831 | FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp); |
| 1832 | if (failed(Result: firstUserOfLoop)) |
| 1833 | return failure(); |
| 1834 | |
| 1835 | BackwardSliceOptions options; |
| 1836 | DominanceInfo dominanceInfo; |
| 1837 | options.inclusive = true; |
| 1838 | options.omitBlockArguments = true; |
| 1839 | bool includeLoopOp = false; |
| 1840 | options.filter = [&](Operation *op) { |
| 1841 | if (op == loopOp) { |
| 1842 | includeLoopOp = true; |
| 1843 | return false; |
| 1844 | } |
| 1845 | // Cut off the slice to not include any operation that already dominates |
| 1846 | // firstUserOfLoop. |
| 1847 | return !dominanceInfo.properlyDominates(a: op, b: *firstUserOfLoop); |
| 1848 | }; |
| 1849 | llvm::SetVector<Operation *> slice; |
| 1850 | for (auto operand : consumerOp->getOperands()) { |
| 1851 | LogicalResult result = getBackwardSlice(root: operand, backwardSlice: &slice, options); |
| 1852 | assert(result.succeeded() && "expected a backward slice" ); |
| 1853 | (void)result; |
| 1854 | } |
| 1855 | |
| 1856 | if (!slice.empty()) { |
| 1857 | // If consumerOp has one producer, which is also the user of loopOp. |
| 1858 | // E.g. |
| 1859 | // ``` |
| 1860 | // %0 = %loopOp |
| 1861 | // %1 = consumerOp1 ins(%0) |
| 1862 | // %2 = consumerOp2 ins(%0, %1) |
| 1863 | // ``` |
| 1864 | // We can not fuse consumerOp2 into loopOp due to UD chain, unless |
| 1865 | // consumerOp1 has already been fused into loopOp before. |
| 1866 | if (includeLoopOp || !reorderOperations) |
| 1867 | return failure(); |
| 1868 | } |
| 1869 | |
| 1870 | return slice; |
| 1871 | } |
| 1872 | |
| 1873 | /// Fetches the OpOperand of the first valid user (and use) of the value `val` |
| 1874 | /// which implements `TilingInterface` and `DestinationStyleOpInterface`. |
| 1875 | /// Returns failure otherwise. |
| 1876 | static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter, |
| 1877 | Operation *loopOp, |
| 1878 | unsigned resultNumber) { |
| 1879 | if (!isa<LoopLikeOpInterface>(Val: loopOp)) |
| 1880 | return failure(); |
| 1881 | Value val = loopOp->getResult(idx: resultNumber); |
| 1882 | Block *loopBlock = loopOp->getBlock(); |
| 1883 | for (OpOperand &opOperand : val.getUses()) { |
| 1884 | Operation *consumerOp = opOperand.getOwner(); |
| 1885 | // Step 1. Check if the user is tilable. |
| 1886 | if (!isa<TilingInterface>(Val: consumerOp) || |
| 1887 | !isa<DestinationStyleOpInterface>(Val: consumerOp)) { |
| 1888 | // TODO: We have to init result of consumer before scf.for, use |
| 1889 | // DestinationStyleOpInterface to get result shape from init for now. |
| 1890 | // Add support for other op such as op has InferTypeOpInterface. |
| 1891 | continue; |
| 1892 | } |
| 1893 | // Step 2. Check if user stay in the same block. |
| 1894 | if (loopBlock != consumerOp->getBlock()) |
| 1895 | continue; |
| 1896 | // Step 3. Check if user has succeeding user. Otherwise, it usually |
| 1897 | // represents already tiled. |
| 1898 | if (consumerOp->use_empty()) |
| 1899 | continue; |
| 1900 | // Step 4. Check assumption for loop with `reorderOperations` enabled. |
| 1901 | FailureOr<llvm::SetVector<Operation *>> slice = |
| 1902 | checkAssumptionForLoop(loopOp, consumerOp, reorderOperations: true); |
| 1903 | if (failed(Result: slice)) |
| 1904 | continue; |
| 1905 | // Step 5. If backward sice is not empty, move them before |
| 1906 | // firstUserOfLoop. |
| 1907 | if (!slice->empty()) { |
| 1908 | mlir::topologicalSort(toSort: *slice); |
| 1909 | FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp); |
| 1910 | assert(succeeded(firstUserOfLoop) && "First user of loop is not found" ); |
| 1911 | for (auto op : *slice) { |
| 1912 | rewriter.moveOpBefore(op, existingOp: *firstUserOfLoop); |
| 1913 | } |
| 1914 | } |
| 1915 | return &opOperand; |
| 1916 | } |
| 1917 | return failure(); |
| 1918 | } |
| 1919 | |
| 1920 | /// Check that the loop is perfectly nested. |
| 1921 | /// The loops are expected to be ordered from outer most to inner most. |
| 1922 | /// For example: |
| 1923 | /// ``` |
| 1924 | /// %0 = scf.for() |
| 1925 | /// %1 = scf.for() |
| 1926 | /// %2 = scf.for() |
| 1927 | /// %3 = ... |
| 1928 | /// yield %3 |
| 1929 | /// yield %2 |
| 1930 | /// yield %1 |
| 1931 | /// ``` |
| 1932 | /// Here loops should be [%0, %1]. |
| 1933 | static bool |
| 1934 | isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) { |
| 1935 | assert(!loops.empty() && "unexpected empty loop nest" ); |
| 1936 | if (loops.size() == 1) { |
| 1937 | return isa_and_nonnull<scf::ForOp>(Val: loops.front().getOperation()); |
| 1938 | } |
| 1939 | for (auto [outerLoop, innerLoop] : |
| 1940 | llvm::zip_equal(t: loops.drop_back(), u: loops.drop_front())) { |
| 1941 | auto outerFor = dyn_cast_or_null<scf::ForOp>(Val: outerLoop.getOperation()); |
| 1942 | auto innerFor = dyn_cast_or_null<scf::ForOp>(Val: innerLoop.getOperation()); |
| 1943 | if (!outerFor || !innerFor) { |
| 1944 | return false; |
| 1945 | } |
| 1946 | auto outerBBArgs = outerFor.getRegionIterArgs(); |
| 1947 | auto innerIterArgs = innerFor.getInitArgs(); |
| 1948 | if (outerBBArgs.size() != innerIterArgs.size()) { |
| 1949 | return false; |
| 1950 | } |
| 1951 | |
| 1952 | for (auto [outerBBArg, innerIterArg] : |
| 1953 | llvm::zip_equal(t&: outerBBArgs, u&: innerIterArgs)) { |
| 1954 | if (!llvm::hasSingleElement(C: outerBBArg.getUses()) || |
| 1955 | innerIterArg != outerBBArg) { |
| 1956 | return false; |
| 1957 | } |
| 1958 | } |
| 1959 | |
| 1960 | ValueRange outerYields = |
| 1961 | cast<scf::YieldOp>(Val: outerFor.getBody()->getTerminator())->getOperands(); |
| 1962 | ValueRange innerResults = innerFor.getResults(); |
| 1963 | if (outerYields.size() != innerResults.size()) { |
| 1964 | return false; |
| 1965 | } |
| 1966 | for (auto [outerYield, innerResult] : |
| 1967 | llvm::zip_equal(t&: outerYields, u&: innerResults)) { |
| 1968 | if (!llvm::hasSingleElement(C: innerResult.getUses()) || |
| 1969 | outerYield != innerResult) { |
| 1970 | return false; |
| 1971 | } |
| 1972 | } |
| 1973 | } |
| 1974 | return true; |
| 1975 | } |
| 1976 | |
| 1977 | /// Fetch the untiled consumer of the outermost scf.for's result which is |
| 1978 | /// yielded by a tensor.insert_slice from the innermost scf.for. This function |
| 1979 | /// makes the following assumptions : |
| 1980 | /// 1. tensor.insert_slice has scf.yield as its only user. |
| 1981 | /// 2. scf.for's corresponding result has only one use. |
| 1982 | /// 3. The `loops` passed in are perfectly nested `scf.for` operations. |
| 1983 | static FailureOr<OpOperand *> |
| 1984 | getUntiledConsumerFromSlice(RewriterBase &rewriter, |
| 1985 | tensor::InsertSliceOp candidateSliceOp, |
| 1986 | MutableArrayRef<LoopLikeOpInterface> loops) { |
| 1987 | assert(!loops.empty() && "unexpected loops to be empty" ); |
| 1988 | // 1. Expect slice to be part of the body of the inner most loop. |
| 1989 | Operation *containingOp = candidateSliceOp->getParentOp(); |
| 1990 | if (containingOp != loops.back()) { |
| 1991 | return rewriter.notifyMatchFailure( |
| 1992 | arg&: candidateSliceOp, |
| 1993 | msg: "expected slice to be within body of inner-most loop" ); |
| 1994 | } |
| 1995 | |
| 1996 | // 2. Check that the loop is perfectly nested. |
| 1997 | if (!isPerfectlyNestedForLoops(loops)) { |
| 1998 | return rewriter.notifyMatchFailure( |
| 1999 | arg&: candidateSliceOp, msg: "expected passed loops to be perfectly nested." ); |
| 2000 | } |
| 2001 | |
| 2002 | if (failed(Result: checkAssumptionForFusingConsumer(candidateSliceOp))) |
| 2003 | return failure(); |
| 2004 | Value sliceResult = candidateSliceOp.getResult(); |
| 2005 | |
| 2006 | // 3. Fetch the corresponding output. |
| 2007 | OpOperand &yieldOpOperand = (*sliceResult.getUses().begin()); |
| 2008 | unsigned resultNumber = yieldOpOperand.getOperandNumber(); |
| 2009 | |
| 2010 | scf::ForOp topLevelForOp = cast<scf::ForOp>(Val: loops.front().getOperation()); |
| 2011 | |
| 2012 | return getConsumerFromLoopUses(rewriter, loopOp: topLevelForOp, resultNumber); |
| 2013 | } |
| 2014 | |
| 2015 | /// Fetch the first untiled consumer of a scf.forall's result which is yielded |
| 2016 | /// by a tensor.parallel_insert_slice. |
| 2017 | static FailureOr<OpOperand *> |
| 2018 | getUntiledConsumerFromSlice(RewriterBase &rewriter, |
| 2019 | tensor::ParallelInsertSliceOp candidateSliceOp, |
| 2020 | MutableArrayRef<LoopLikeOpInterface> loops) { |
| 2021 | assert(!loops.empty() && "unexpected loops to be empty" ); |
| 2022 | // 1. Check that the surrounding loop is a single scf.forall loop. |
| 2023 | if (loops.size() != 1) { |
| 2024 | return rewriter.notifyMatchFailure( |
| 2025 | arg&: candidateSliceOp, msg: "expected single surrounding scf.forall" ); |
| 2026 | } |
| 2027 | auto forallOp = dyn_cast<scf::ForallOp>(Val: loops.front().getOperation()); |
| 2028 | if (!forallOp) { |
| 2029 | return rewriter.notifyMatchFailure( |
| 2030 | arg&: candidateSliceOp, msg: "expected single surrounding scf.forall" ); |
| 2031 | } |
| 2032 | |
| 2033 | // 2. Fetch the corresponding output |
| 2034 | Value sliceDest = candidateSliceOp.getDest(); |
| 2035 | auto iterArg = dyn_cast<BlockArgument>(Val&: sliceDest); |
| 2036 | if (!iterArg) |
| 2037 | return failure(); |
| 2038 | if (iterArg.getOwner()->getParentOp() != forallOp) |
| 2039 | return failure(); |
| 2040 | |
| 2041 | unsigned resultNumber = |
| 2042 | forallOp.getTiedOpResult(opOperand: forallOp.getTiedOpOperand(bbArg: iterArg)) |
| 2043 | .getResultNumber(); |
| 2044 | |
| 2045 | return getConsumerFromLoopUses(rewriter, loopOp: forallOp, resultNumber); |
| 2046 | } |
| 2047 | |
| 2048 | /// A utility to fetch an untiled consumer of |
| 2049 | /// tensor.insert_slice/tensor.parallel_insert_slice. |
| 2050 | static FailureOr<SmallVector<OpOperand *>> getUntiledConsumerOperandsFromSlices( |
| 2051 | RewriterBase &rewriter, ArrayRef<Operation *> sliceOps, |
| 2052 | MutableArrayRef<LoopLikeOpInterface> loops) { |
| 2053 | assert(!loops.empty() && "unexpected empty loops" ); |
| 2054 | assert(!sliceOps.empty() && "unexpected empty list of candidate slices" ); |
| 2055 | SmallVector<OpOperand *> fusedOperands; |
| 2056 | for (auto sliceOp : sliceOps) { |
| 2057 | FailureOr<OpOperand *> fusedOperand = |
| 2058 | TypeSwitch<Operation *, FailureOr<OpOperand *>>(sliceOp) |
| 2059 | .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>( |
| 2060 | caseFn: [&](auto op) { |
| 2061 | return getUntiledConsumerFromSlice(rewriter, op, loops); |
| 2062 | }) |
| 2063 | .Default(defaultFn: [&](Operation *op) { |
| 2064 | return rewriter.notifyMatchFailure(arg&: op, msg: "unhandled slice type" ); |
| 2065 | }); |
| 2066 | if (failed(Result: fusedOperand)) { |
| 2067 | return failure(); |
| 2068 | } |
| 2069 | if (!fusedOperands.empty() && |
| 2070 | fusedOperand.value()->getOwner() != fusedOperands.front()->getOwner()) { |
| 2071 | return rewriter.notifyMatchFailure( |
| 2072 | arg: fusedOperand.value()->getOwner(), |
| 2073 | msg: "all candidate slices must be to the same consumer" ); |
| 2074 | } |
| 2075 | fusedOperands.push_back(Elt: fusedOperand.value()); |
| 2076 | } |
| 2077 | return fusedOperands; |
| 2078 | } |
| 2079 | |
| 2080 | template <typename InsertSliceOpTy> |
| 2081 | static tensor::InsertSliceOp cloneAsInsertSlice(RewriterBase &rewriter, |
| 2082 | InsertSliceOpTy sliceOp); |
| 2083 | |
| 2084 | template <> |
| 2085 | tensor::InsertSliceOp |
| 2086 | cloneAsInsertSlice<tensor::InsertSliceOp>(RewriterBase &rewriter, |
| 2087 | tensor::InsertSliceOp insertSliceOp) { |
| 2088 | return cast<tensor::InsertSliceOp>( |
| 2089 | Val: rewriter.clone(op&: *insertSliceOp.getOperation())); |
| 2090 | } |
| 2091 | |
| 2092 | template <> |
| 2093 | tensor::InsertSliceOp cloneAsInsertSlice<tensor::ParallelInsertSliceOp>( |
| 2094 | RewriterBase &rewriter, tensor::ParallelInsertSliceOp insertSliceOp) { |
| 2095 | return rewriter.create<tensor::InsertSliceOp>( |
| 2096 | location: insertSliceOp->getLoc(), args: insertSliceOp.getSource(), |
| 2097 | args: insertSliceOp.getDest(), args: insertSliceOp.getMixedOffsets(), |
| 2098 | args: insertSliceOp.getMixedSizes(), args: insertSliceOp.getMixedStrides()); |
| 2099 | } |
| 2100 | |
| 2101 | static SmallVector<tensor::InsertSliceOp> |
| 2102 | cloneAsInsertSlices(RewriterBase &rewriter, |
| 2103 | ArrayRef<Operation *> candidateSlices) { |
| 2104 | assert(!candidateSlices.empty() && |
| 2105 | "unexpected empty list of slices to clone" ); |
| 2106 | SmallVector<tensor::InsertSliceOp> clonedSlices; |
| 2107 | for (auto sliceOp : candidateSlices) { |
| 2108 | TypeSwitch<Operation *>(sliceOp) |
| 2109 | .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>( |
| 2110 | caseFn: [&](auto op) { |
| 2111 | auto clonedOp = cloneAsInsertSlice(rewriter, op); |
| 2112 | clonedSlices.push_back(Elt: clonedOp); |
| 2113 | }) |
| 2114 | .Default(defaultFn: [&](Operation *op) { |
| 2115 | // Assert here assuming this has already been checked. |
| 2116 | assert(0 && "unexpected slice type while cloning as insert slice" ); |
| 2117 | }); |
| 2118 | } |
| 2119 | return clonedSlices; |
| 2120 | } |
| 2121 | |
| 2122 | /// Implementation of fusing consumer of a single slice by computing the |
| 2123 | /// slice of the consumer in-place for scf loop. |
| 2124 | FailureOr<scf::SCFFuseConsumerOfSliceResult> |
| 2125 | mlir::scf::tileAndFuseConsumerOfSlices( |
| 2126 | RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices, |
| 2127 | MutableArrayRef<LoopLikeOpInterface> loops) { |
| 2128 | if (candidateSlices.empty()) { |
| 2129 | return rewriter.notifyMatchFailure( |
| 2130 | arg: rewriter.getUnknownLoc(), |
| 2131 | msg: "no candidate slices provided for consumer fusion" ); |
| 2132 | } |
| 2133 | // Return if `loops` is empty, return an error for now. Caller is expected |
| 2134 | // to handle this case. |
| 2135 | if (loops.empty()) { |
| 2136 | return rewriter.notifyMatchFailure( |
| 2137 | arg: candidateSlices.front(), |
| 2138 | msg: "cannot call tile and fuse consumer with an empty loop nest" ); |
| 2139 | } |
| 2140 | |
| 2141 | if (!(llvm::all_of(Range&: candidateSlices, P: llvm::IsaPred<tensor::InsertSliceOp>) || |
| 2142 | llvm::all_of(Range&: candidateSlices, |
| 2143 | P: llvm::IsaPred<tensor::ParallelInsertSliceOp>))) { |
| 2144 | return rewriter.notifyMatchFailure( |
| 2145 | arg: candidateSlices.front(), |
| 2146 | msg: "candidates slices need to be all `tensor.extract_slice`s or " |
| 2147 | "`tensor.parallel_insert_slice`s" ); |
| 2148 | } |
| 2149 | |
| 2150 | // 1. Get the consumer of scf.for for the result yielded by |
| 2151 | // tensor.insert_slice/parallel_insert_slice. |
| 2152 | SmallVector<OpOperand *> consumerOpOperands; |
| 2153 | Operation *consumerOp; |
| 2154 | { |
| 2155 | FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand = |
| 2156 | getUntiledConsumerOperandsFromSlices(rewriter, sliceOps: candidateSlices, loops); |
| 2157 | if (failed(Result: maybeConsumerOpOperand)) { |
| 2158 | return rewriter.notifyMatchFailure(arg: candidateSlices.front(), |
| 2159 | msg: "could not fetch consumer to fuse" ); |
| 2160 | } |
| 2161 | std::swap(LHS&: consumerOpOperands, RHS&: maybeConsumerOpOperand.value()); |
| 2162 | consumerOp = consumerOpOperands.front()->getOwner(); |
| 2163 | } |
| 2164 | |
| 2165 | LoopLikeOpInterface outerMostLoop = loops.front(); |
| 2166 | LoopLikeOpInterface innerMostLoop = loops.back(); |
| 2167 | |
| 2168 | // Check assumption for loop with `reorderOperations` disabled. |
| 2169 | if (failed(Result: checkAssumptionForLoop(loopOp: outerMostLoop, consumerOp, reorderOperations: false))) { |
| 2170 | return rewriter.notifyMatchFailure( |
| 2171 | arg&: outerMostLoop, msg: "the first user of loop should not dominate any define " |
| 2172 | "of consumer operand(s)" ); |
| 2173 | } |
| 2174 | |
| 2175 | OpBuilder::InsertionGuard g(rewriter); |
| 2176 | |
| 2177 | // 2. Check consumer is not using scf loop's output as init. |
| 2178 | auto dstOp = dyn_cast<DestinationStyleOpInterface>(Val: consumerOp); |
| 2179 | if (!dstOp) |
| 2180 | return rewriter.notifyMatchFailure(arg&: consumerOp, |
| 2181 | msg: "consumer op is not DPS operation" ); |
| 2182 | if (llvm::any_of(Range&: consumerOpOperands, P: [&](OpOperand *opOperand) { |
| 2183 | return dstOp.isDpsInit(opOperand); |
| 2184 | })) { |
| 2185 | return rewriter.notifyMatchFailure( |
| 2186 | arg&: consumerOp, |
| 2187 | msg: "consumer op taking the result of scf.for as init is not supported" ); |
| 2188 | } |
| 2189 | SmallVector<Value> newInits = llvm::to_vector(Range: dstOp.getDpsInits()); |
| 2190 | |
| 2191 | // 3. Move the whole loop structure right before firstUserOfLoop, the |
| 2192 | // dominance should be already ensured by `checkAssumptionForLoop`. |
| 2193 | FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp: outerMostLoop); |
| 2194 | if (failed(Result: firstUserOfLoop)) { |
| 2195 | return rewriter.notifyMatchFailure( |
| 2196 | arg&: outerMostLoop, msg: "could not find the first user of outer most loop" ); |
| 2197 | } |
| 2198 | rewriter.moveOpBefore(op: outerMostLoop, existingOp: *firstUserOfLoop); |
| 2199 | |
| 2200 | // 4. Set insertion point before terminator op of the loop and create a new |
| 2201 | // tensor.insert_slice. In the scf.for case this is a clone of the |
| 2202 | // candidateSliceOp whereas in the scf.forall case this is created from the |
| 2203 | // operands of tensor.parallel_insert_slice. |
| 2204 | if (auto sliceOp = |
| 2205 | dyn_cast<tensor::ParallelInsertSliceOp>(Val: candidateSlices.front())) { |
| 2206 | auto newForallOp = cast<scf::ForallOp>(Val: innerMostLoop.getOperation()); |
| 2207 | rewriter.setInsertionPoint(newForallOp.getTerminator()); |
| 2208 | } else { |
| 2209 | rewriter.setInsertionPoint(candidateSlices.front()); |
| 2210 | } |
| 2211 | // 5.a. Clone all the candidate slices as equivalent insert slice ops. |
| 2212 | SmallVector<tensor::InsertSliceOp> clonedInsertSlices = |
| 2213 | cloneAsInsertSlices(rewriter, candidateSlices); |
| 2214 | |
| 2215 | // 5.b. Clone consumer op. |
| 2216 | auto clonedConsumerOp = cast<TilingInterface>(Val: rewriter.clone(op&: *consumerOp)); |
| 2217 | SmallVector<unsigned> operandNumbers = |
| 2218 | llvm::map_to_vector(C&: consumerOpOperands, F: [](OpOperand *opOperand) { |
| 2219 | return opOperand->getOperandNumber(); |
| 2220 | }); |
| 2221 | SmallVector<OpOperand *> clonedOpFusedOperandsList = |
| 2222 | llvm::map_to_vector(C&: operandNumbers, F: [&](unsigned operandNum) { |
| 2223 | return &clonedConsumerOp->getOpOperand(idx: operandNum); |
| 2224 | }); |
| 2225 | |
| 2226 | // 5.c. Replace all uses of the loop result with the result of the cloned |
| 2227 | // tensor.insert_slice. |
| 2228 | rewriter.modifyOpInPlace(root: clonedConsumerOp, callable: [&]() { |
| 2229 | for (auto [operandToReplace, clonedSliceOp] : |
| 2230 | llvm::zip_equal(t&: clonedOpFusedOperandsList, u&: clonedInsertSlices)) { |
| 2231 | operandToReplace->set(clonedSliceOp.getResult()); |
| 2232 | } |
| 2233 | }); |
| 2234 | |
| 2235 | // 6. Perform tiling of the cloned consumer and replace the operand at |
| 2236 | // `operandNumber` with the source of the cloned tensor.insert_slice op. |
| 2237 | FailureOr<TilingResult> tileAndFuseResult = |
| 2238 | tensor::replaceInsertSlicesWithTiledConsumer(builder&: rewriter, sliceOps: clonedInsertSlices, |
| 2239 | consumerOperands: clonedOpFusedOperandsList); |
| 2240 | if (failed(Result: tileAndFuseResult)) { |
| 2241 | return failure(); |
| 2242 | } |
| 2243 | |
| 2244 | auto tiledConsumerOp = cast<TilingInterface>(Val: tileAndFuseResult->tiledOps[0]); |
| 2245 | for (auto [operandNum, clonedSliceOp] : |
| 2246 | llvm::zip_equal(t&: operandNumbers, u&: clonedInsertSlices)) { |
| 2247 | rewriter.replaceAllUsesWith(from: tiledConsumerOp->getOperand(idx: operandNum), |
| 2248 | to: clonedSliceOp.getSource()); |
| 2249 | } |
| 2250 | |
| 2251 | // 7. Reconstruct [nested] loop with new inits. |
| 2252 | YieldTiledValuesFn newYieldValuesFn = |
| 2253 | [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/, |
| 2254 | ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult, |
| 2255 | SmallVector<SmallVector<OpFoldResult>> &tiledOffset, |
| 2256 | SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult { |
| 2257 | OpBuilder::InsertionGuard g(innerRewriter); |
| 2258 | // 8. Set inner insertPoint right before tiled consumer op. |
| 2259 | innerRewriter.setInsertionPoint(tiledConsumerOp); |
| 2260 | |
| 2261 | SmallVector<SmallVector<OpFoldResult>> allOffsets, allSizes; |
| 2262 | for (auto candidateSliceOp : clonedInsertSlices) { |
| 2263 | SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets(); |
| 2264 | SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes(); |
| 2265 | SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides(); |
| 2266 | |
| 2267 | // 9. Check all insert stride is 1. |
| 2268 | if (!llvm::all_of(Range&: strides, P: isOneInteger)) { |
| 2269 | return rewriter.notifyMatchFailure( |
| 2270 | arg&: candidateSliceOp, msg: "containingOp's result yield with stride" ); |
| 2271 | } |
| 2272 | |
| 2273 | allOffsets.emplace_back(Args: std::move(offsets)); |
| 2274 | allSizes.emplace_back(Args: std::move(sizes)); |
| 2275 | } |
| 2276 | |
| 2277 | // 10. Try to get iter domain position from input position. Use |
| 2278 | // clonedConsumerOp instead of tiledConsumerOp, because the iteration |
| 2279 | // domain may require index computation based on the result size. The |
| 2280 | // sizes and offsets should be the same either way, but using |
| 2281 | // tiledConsumerOp could lead to some chained unnecessary extra index |
| 2282 | // computation. |
| 2283 | SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes; |
| 2284 | if (failed(Result: clonedConsumerOp.getIterationDomainTileFromOperandTiles( |
| 2285 | b&: rewriter, operandNumbers, allOffsets, allSizes, iterDomainOffsets, |
| 2286 | iterDomainSizes))) { |
| 2287 | return rewriter.notifyMatchFailure( |
| 2288 | arg&: clonedConsumerOp, |
| 2289 | msg: "can't get iter domain position from input position" ); |
| 2290 | } |
| 2291 | |
| 2292 | // 11. Try to fetch the offset and size for all results of the cloned |
| 2293 | // consumer. This would then be used to form the corresponding |
| 2294 | // tensor.insert_slice/parallel_insert_slice later. |
| 2295 | unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults(); |
| 2296 | SmallVector<SmallVector<OpFoldResult>> resultOffsets( |
| 2297 | totalNumResultsOfConsumer); |
| 2298 | SmallVector<SmallVector<OpFoldResult>> resultSizes( |
| 2299 | totalNumResultsOfConsumer); |
| 2300 | for (auto [idx, v] : llvm::enumerate(First: tiledConsumerOp->getResults())) { |
| 2301 | if (failed(Result: tiledConsumerOp.getResultTilePosition( |
| 2302 | b&: rewriter, resultNumber: idx, offsets: iterDomainOffsets, sizes: iterDomainSizes, |
| 2303 | resultOffsets&: resultOffsets[idx], resultSizes&: resultSizes[idx]))) { |
| 2304 | return rewriter.notifyMatchFailure( |
| 2305 | arg&: tiledConsumerOp, |
| 2306 | msg: "can't get result domain position from iter domain position" ); |
| 2307 | } |
| 2308 | } |
| 2309 | |
| 2310 | // 12. Create `extract_slice` for `iter_args` for DPS operation if |
| 2311 | // necessary. |
| 2312 | if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>( |
| 2313 | Val: tiledConsumerOp.getOperation())) { |
| 2314 | rewriter.setInsertionPoint(tiledDestStyleOp); |
| 2315 | for (const auto &&[index, newRegionArg] : |
| 2316 | llvm::enumerate(First&: newRegionIterArgs)) { |
| 2317 | auto destSlice = rewriter.create<tensor::ExtractSliceOp>( |
| 2318 | location: loc, args&: newRegionArg, args&: resultOffsets[index], args&: resultSizes[index], |
| 2319 | args: SmallVector<OpFoldResult>(resultOffsets[index].size(), |
| 2320 | rewriter.getIndexAttr(value: 1))); |
| 2321 | // Make a copy of index to avoid a capturing structured binding, which |
| 2322 | // is a C++20 extension. |
| 2323 | auto dstNumber = index; |
| 2324 | rewriter.modifyOpInPlace(root: tiledDestStyleOp, callable: [&]() { |
| 2325 | tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice); |
| 2326 | }); |
| 2327 | } |
| 2328 | } |
| 2329 | |
| 2330 | // 13. Prepare tiled offset and sizes for later `insert_slice` creation by |
| 2331 | // caller. |
| 2332 | Block *block = rewriter.getInsertionPoint()->getBlock(); |
| 2333 | rewriter.setInsertionPoint(block->getTerminator()); |
| 2334 | for (const auto &&[index, result] : |
| 2335 | llvm::enumerate(First: tiledConsumerOp->getResults())) { |
| 2336 | tiledResult.push_back(Elt: result); |
| 2337 | tiledOffset.emplace_back(Args&: resultOffsets[index]); |
| 2338 | tiledSizes.emplace_back(Args&: resultSizes[index]); |
| 2339 | } |
| 2340 | return success(); |
| 2341 | }; |
| 2342 | // 14. Add new inits to [nested] loops. |
| 2343 | if (failed(Result: addInitOperandsToLoopNest(rewriter, loops, newInitValues: newInits, |
| 2344 | getNewTiledYieldsFn: newYieldValuesFn))) { |
| 2345 | return rewriter.notifyMatchFailure(arg&: tiledConsumerOp, |
| 2346 | msg: "unable to add new inits to nest loop" ); |
| 2347 | } |
| 2348 | |
| 2349 | // 15. Replace the result of scf loop and consumer op with new loop's |
| 2350 | // results. |
| 2351 | |
| 2352 | for (auto &&[oldResult, newResult] : |
| 2353 | llvm::zip(t: consumerOp->getResults(), |
| 2354 | u: loops.front()->getResults().take_back(n: newInits.size()))) { |
| 2355 | rewriter.replaceAllUsesWith(from: oldResult, to: newResult); |
| 2356 | } |
| 2357 | |
| 2358 | // 16. Need to erase the old scf loop and the cloned consumer op. |
| 2359 | rewriter.eraseOp(op: clonedConsumerOp); |
| 2360 | |
| 2361 | SmallVector<OpOperand *> tiledAndFusedOpOperands = |
| 2362 | llvm::map_to_vector(C&: operandNumbers, F: [&](unsigned operandNum) { |
| 2363 | return &tileAndFuseResult->tiledOps[0]->getOpOperand(idx: operandNum); |
| 2364 | }); |
| 2365 | return scf::SCFFuseConsumerOfSliceResult{ |
| 2366 | .origConsumerOperands: std::move(consumerOpOperands), .tiledAndFusedConsumerOperands: std::move(tiledAndFusedOpOperands), |
| 2367 | .tiledOps: std::move(tileAndFuseResult->tiledOps)}; |
| 2368 | } |
| 2369 | |
| 2370 | //===----------------------------------------------------------------------===// |
| 2371 | // lowerToLoopsUsingSCFForOp implementation. |
| 2372 | //===----------------------------------------------------------------------===// |
| 2373 | |
| 2374 | FailureOr<SmallVector<scf::ForOp>> |
| 2375 | mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, |
| 2376 | TilingInterface op) { |
| 2377 | // TODO: Handle cases where the op has results if needed. |
| 2378 | if (op->getNumResults() > 0) { |
| 2379 | return rewriter.notifyMatchFailure( |
| 2380 | arg&: op, msg: "unable to lower to loops operations with return values" ); |
| 2381 | } |
| 2382 | |
| 2383 | SmallVector<Range> domain = op.getIterationDomain(b&: rewriter); |
| 2384 | SmallVector<Value> ivs; |
| 2385 | SmallVector<scf::ForOp> loops; |
| 2386 | Location loc = op.getLoc(); |
| 2387 | for (auto loopRange : domain) { |
| 2388 | Value offsetVal = |
| 2389 | getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: loopRange.offset); |
| 2390 | Value sizeVal = |
| 2391 | getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: loopRange.size); |
| 2392 | Value strideVal = |
| 2393 | getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: loopRange.stride); |
| 2394 | auto loop = rewriter.create<scf::ForOp>(location: op.getLoc(), args&: offsetVal, args&: sizeVal, |
| 2395 | args&: strideVal, args: ValueRange{}); |
| 2396 | loops.push_back(Elt: loop); |
| 2397 | ivs.push_back(Elt: loop.getInductionVar()); |
| 2398 | rewriter.setInsertionPoint(loop.getBody()->getTerminator()); |
| 2399 | } |
| 2400 | if (failed(Result: op.generateScalarImplementation(b&: rewriter, loc: op.getLoc(), ivs))) { |
| 2401 | return failure(); |
| 2402 | } |
| 2403 | return loops; |
| 2404 | } |
| 2405 | |