//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#include <utility>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")

//===----------------------------------------------------------------------===//
// Transformations exposed as functional-style API calls.
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// peelLoop transformation.
//===----------------------------------------------------------------------===//

/// Try to peel and canonicalize loop `op` and return the new result.
/// Also applies affine_min/max bounds simplification on the fly where relevant.
// TODO: Add support for scf.parallel and affine.for loops.
SmallVector<Value> mlir::linalg::peelLoop(RewriterBase &rewriter,
                                          Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp,
                                                        partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Peel `loops` and apply affine_min/max bounds simplification on the fly
/// where relevant.
void mlir::linalg::peelLoops(RewriterBase &rewriter,
                             ArrayRef<scf::ForOp> loops) {
  for (auto loopOp : loops)
    peelLoop(rewriter, loopOp);
}
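
// For intuition, a hand-written sketch of the effect on an scf.for (this is
// illustrative only, not verified output): a loop over [0, 10) with step 4,
// whose body clamps the tile size with an affine.min, is split into a main
// loop over the evenly divisible range [0, 8) plus a partial iteration for the
// remainder; the affine.min in the main loop then simplifies to the constant
// step.
//
//   // Before peeling.
//   scf.for %iv = %c0 to %c10 step %c4 {
//     %sz = affine.min affine_map<(d0) -> (4, 10 - d0)>(%iv)
//     "use"(%sz) : (index) -> ()
//   }
//
//   // After peelLoop + bounds simplification (sketch).
//   scf.for %iv = %c0 to %c8 step %c4 {
//     "use"(%c4) : (index) -> ()
//   }
//   scf.for %iv = %c8 to %c10 step %c4 {
//     %sz = affine.min affine_map<(d0) -> (4, 10 - d0)>(%iv)
//     "use"(%sz) : (index) -> ()
//   }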

//===----------------------------------------------------------------------===//
// pack transformation.
//===----------------------------------------------------------------------===//

#ifndef NDEBUG
/// Return true if `map` has 0 or 1 result function of AffineDimExpr(dim).
static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) {
  bool found = false;
  for (AffineExpr e : map.getResults()) {
    if (!e.isFunctionOfDim(dim))
      continue;
    if (found)
      return false;
    found = true;
  }
  return true;
}

static std::string stringifyReassocIndices(ReassociationIndicesRef ri) {
  return llvm::interleaved(ri, ", ", /*Prefix=*/"|", /*Suffix=*/"");
}
#endif // NDEBUG

/// Return the index of the first result of `map` that is a function of
/// AffineDimExpr(dim), std::nullopt otherwise.
static std::optional<int64_t> getFirstResultIndexFunctionOf(AffineMap map,
                                                            int64_t dim) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    AffineExpr expr = map.getResult(i);
    if (!expr.isFunctionOfDim(dim))
      continue;
    return i;
  }
  return std::nullopt;
}

/// Perform one step of packing of a LinalgOp's metadata along `dim` into the
/// `newDim` at `iteratorTypes.size()` by:
/// 1. Appending `iteratorTypes[newDim]`, equal to `iteratorTypes[dim]`.
/// 2. Appending a `newDim` to the domain of every indexing map.
/// 3. For each operand (i.e. for each map in `indexingMaps`), perform packing
///    by potentially adding a `newDim` result to `map`.
/// The preserved invariant is that `iteratorTypes.size()` is always equal to
/// `map.getNumDims()` for every map in `indexingMaps`.
///
/// Update `indexingMaps` and `iteratorTypes` in place as one step of the
/// update. Return a vector that records the optional packing for each operand.
/// Return failure if the packed indexing cannot be represented with a LinalgOp.
///
/// Further details:
/// ================
/// The current implementation of packing (i.e. data tiling) consists of
/// rewriting a linearized strip-mined form into a higher-dimensional access.
/// E.g. consider an access `A[I][f(j, k, l)]` and packing by 4; we rewrite
/// `I` into `4 * i + ii`, where `0 <= ii < 4`.
/// The access is further rewritten as `A[i][f(j, k, l)][ii]`.
///
/// This rewrite into a higher-dimensional access is not possible for a general
/// AffineExpr in Linalg at the moment; it is restricted to an AffineDimExpr:
/// e.g. consider an access `A[I + J][f(j, k, l)]` and packing by 4; we
/// rewrite `I + J` into `4 * i + ii + J`, where `0 <= ii < 4`.
/// The rewrite of the access would be a form not representable in Linalg:
///   `A[i + (ii + J) / 4][f(j, k, l)][(ii + J) % 4]`.
/// Note however that as `J` and `ii` iterate, the accesses do not have a
/// particular alignment, so packing does not achieve alignment in this case.
///
/// In the future, we may want to consider a mixed-form that allows some
/// alignment in the presence of multiple accesses:
///   `A[I][f(j, k, l)]` and `B[I + J][f(j, k, l)]`
/// And would rewrite accesses as:
///   `A[i][f(j, k, l)][ii]` and `B[4 * i + ii + J][f(j, k, l)]`
static FailureOr<SmallVector<std::optional<int64_t>>>
packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps,
                       SmallVectorImpl<utils::IteratorType> &iteratorTypes,
                       int64_t dim) {
  int64_t newDim = iteratorTypes.size();
  iteratorTypes.push_back(iteratorTypes[dim]);

  SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
      indexingMaps.size(), std::nullopt);
  SmallVector<AffineMap> newMaps;
  for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
       ++operandIdx) {
    AffineMap map = indexingMaps[operandIdx];

    // Add the `newDim` to map whatever the case.
    assert(map.getNumDims() == newDim && "num dims invariant violation");
    map = map.shiftDims(1, newDim);

    // Get the at-most-1 index of the result that is a function of `dim`.
    // If we can find one, we insert `AffineDimExpr(newDim)` to the map, which
    // logically chunks dimension `dim` into `K * dim + newDim`, where the
    // packing factor `K` is specified separately.
    assert(hasAtMostOneResultFunctionOfDim(map, dim) &&
           "num results invariant violation");
    auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim);
    if (!maybeOperandDimensionToPack.has_value()) {
      newMaps.push_back(map);
      continue;
    }

    // We can only pack AffineDimExpr atm.
    if (!isa<AffineDimExpr>(map.getResult(maybeOperandDimensionToPack.value())))
      return failure();

    // Add `newDim` to the results of the map.
    map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim),
                           map.getNumResults());
    newMaps.push_back(map);

    // Record that `operandIdx` is packed.
    packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
  }
  indexingMaps = newMaps;

  return packedDimPerIndexingMap;
}
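
// A worked example of one step (derived by hand, illustrative only): starting
// from matmul metadata
//   indexingMaps  = [(m, n, k) -> (m, k), (m, n, k) -> (k, n),
//                    (m, n, k) -> (m, n)]
//   iteratorTypes = [parallel, parallel, reduction]
// calling packLinalgMetadataOnce with dim = 2 (i.e. k) appends a new reduction
// iterator kk and a trailing result to every map that uses k:
//   indexingMaps  = [(m, n, k, kk) -> (m, k, kk),
//                    (m, n, k, kk) -> (k, n, kk),
//                    (m, n, k, kk) -> (m, n)]
//   iteratorTypes = [parallel, parallel, reduction, reduction]
// and returns {1, 0, std::nullopt}: the index of the packed result in each
// map, with std::nullopt for the output map that does not use k.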

namespace {

/// Helper struct to encode packing along one dimension of a LinalgOp.
struct PackedOperandsDim {
  OpFoldResult packedSize;
  SmallVector<std::optional<int64_t>> packedDimForEachOperand;
};

/// Helper struct to encode packing along all dimensions of a LinalgOp.
struct PackedOperandsDimList {
  void pushBack(PackedOperandsDim &&packedOperandsDims) {
    spec.emplace_back(packedOperandsDims);
  }
  /// Return all the dims that have been packed for operand @ `operandPos`.
  SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
  /// Return all the pack sizes by which an operand @ `operandPos` is packed.
  SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);

private:
  SmallVector<PackedOperandsDim> spec;
};

} // namespace

FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
                                             linalg::PackOp packOp,
                                             bool lowerPadLikeWithInsertSlice) {
  // 1. Filter out NYI cases.
  auto packedTensorType =
      cast<RankedTensorType>(packOp->getResultTypes().front());
  if (llvm::any_of(packOp.getStaticInnerTiles(), ShapedType::isDynamic)) {
    return rewriter.notifyMatchFailure(
        packOp,
        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
  }

  Location loc = packOp->getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(packOp);

  // 2. Compute the permutation vector to shuffle packed shape into the shape
  // before any outer or inner permutations have been applied.
  PackingMetadata packingMetadata = computePackingMetadata(
      packedTensorType.getRank(), packOp.getInnerDimsPos());
  SmallVector<int64_t> packedToStripMinedShapePerm =
      getPackInverseDestPerm(packOp);

  // 3. Compute the stripMinedShape: this is the packed shape before any outer
  // or inner permutations have been applied.
  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

  // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
  SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
                                 rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
                                  rewriter.getIndexAttr(0));
  for (auto [pos, innerSize] :
       llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
    int outerPos =
        packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
    OpFoldResult origSize =
        tensor::getMixedSize(rewriter, loc, packOp.getSource(), pos);
    OpFoldResult outerSize =
        tensor::getMixedSize(rewriter, loc, packOp.getDest(), outerPos);
    AffineExpr s0, d0, d1;
    bindDims(rewriter.getContext(), d0, d1);
    bindSymbols(rewriter.getContext(), s0);
    auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, d0 * s0 - d1);
    highs[pos] = affine::makeComposedFoldedAffineApply(
        rewriter, loc, map, {outerSize, origSize, innerSize});
  }
  RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
      packingMetadata.reassociations);
  Value paddingValue = packOp.getPaddingValue();
  if (!paddingValue) {
    paddingValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
  }
  auto padOp =
      rewriter.create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows,
                                     highs, paddingValue, /*nofold=*/false);

  LLVM_DEBUG(
      DBGSNL(); DBGSNL();
      DBGS() << "insertPositions: "
             << llvm::interleaved(packingMetadata.insertPositions);
      DBGSNL(); DBGS() << "outerPositions: "
                       << llvm::interleaved(packingMetadata.outerPositions);
      DBGSNL(); DBGS() << "packedShape: "
                       << llvm::interleaved(packedTensorType.getShape());
      DBGSNL(); DBGS() << "packedToStripMinedShapePerm: "
                       << llvm::interleaved(packedToStripMinedShapePerm);
      DBGSNL();
      DBGS() << "reassociations: "
             << llvm::interleaved(llvm::map_range(
                    packingMetadata.reassociations, stringifyReassocIndices));
      DBGSNL();
      DBGS() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
      DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););

  if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
    // Pack ops which operate as simple pads may not produce legal
    // tensor.insert_slice operations when the packed type does not rank reduce
    // to the padded type.
    SliceVerificationResult rankReduces =
        isRankReducedType(packedTensorType, padOp.getResultType());

    if (rankReduces == SliceVerificationResult::Success) {
      // This pack is just a plain pad.
      // Just insert the pad in the higher ranked tensor.
      // Offsets.
      SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
                                      rewriter.getIndexAttr(0));
      // Strides.
      SmallVector<OpFoldResult> ones(packOp.getDestRank(),
                                     rewriter.getIndexAttr(1));
      SmallVector<OpFoldResult> sizes =
          tensor::getMixedSizes(rewriter, loc, packOp.getDest());

      auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
          loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
          /*offsets=*/zeros, sizes, /*strides=*/ones);

      LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););

      rewriter.replaceOp(packOp, insertSliceOp->getResults());

      return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
                             /*transposeOp=*/nullptr};
    }
  }

  // 5. Expand from the padded result to the stripMinedShape.
  auto expandShapeResultType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
      loc, expandShapeResultType, padOp.getResult(),
      packingMetadata.reassociations);

  // 6. Transpose stripMinedShape to packedShape.
  SmallVector<int64_t> transpPerm =
      invertPermutationVector(packedToStripMinedShapePerm);
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);

  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
             DBGS() << "reshape op: " << reshapeOp; DBGSNL();
             DBGS() << "transpPerm: " << llvm::interleaved(transpPerm);
             DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););

  // 7. Replace packOp by transposeOp.
  rewriter.replaceOp(packOp, transposeOp->getResults());

  return LowerPackResult{padOp, reshapeOp, transposeOp};
}
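
// A hand-written sketch of the decomposition (illustrative only, types and pad
// body abbreviated): a pad-requiring pack
//
//   %p = linalg.pack %src padding_value(%cst : f32)
//          inner_dims_pos = [0, 1] inner_tiles = [8, 16]
//          into %dst : tensor<13x15xf32> -> tensor<2x1x8x16xf32>
//
// lowers to pad + expand_shape + transpose:
//
//   %padded = tensor.pad %src low[0, 0] high[3, 1] { ... }
//       : tensor<13x15xf32> to tensor<16x16xf32>
//   %e = tensor.expand_shape %padded [[0, 1], [2, 3]]
//       : tensor<16x16xf32> into tensor<2x8x1x16xf32>
//   %p = linalg.transpose ins(%e : tensor<2x8x1x16xf32>)
//          outs(%dst : tensor<2x1x8x16xf32>) permutation = [0, 2, 1, 3]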

FailureOr<LowerUnPackOpResult>
linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
                    bool lowerUnpadLikeWithExtractSlice) {
  Location loc = unPackOp->getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(unPackOp);

  RankedTensorType packedTensorType = unPackOp.getSourceType();
  int64_t packedRank = packedTensorType.getRank();

  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
  auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
  if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
    // This unpack is just a plain unpad.
    // Just extract the slice from the higher ranked tensor.
    ArrayRef<int64_t> destShape = destTensorType.getShape();
    // The inner dimensions stay the same as the destination tensor, but the
    // outer ones are additional 1s.
    SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
    sizes.append(tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()));

    auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        loc, destTensorType, unPackOp.getSource(),
        SmallVector<OpFoldResult>(packedRank, zero), sizes,
        SmallVector<OpFoldResult>(packedRank, one));

    rewriter.replaceOp(unPackOp, extractSliceOp->getResults());

    return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
                               /*reshapeOp=*/nullptr, extractSliceOp};
  }

  // 1. Compute the permutation vector to shuffle packed shape into the shape
  // before any outer or inner permutations have been applied.
  PackingMetadata packingMetadata;
  SmallVector<int64_t> packedToStripMinedShapePerm =
      getUnPackInverseSrcPerm(unPackOp, packingMetadata);

  // 2. Compute the stripMinedShape: this is the packed shape without outer and
  // inner permutations.
  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

  // 3. Transpose packedShape to stripMinedShape.
  RankedTensorType stripMinedTensorType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
      stripMinedTensorType, packingMetadata.reassociations);

  // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
  // permutation.
  SmallVector<OpFoldResult, 4> dims =
      tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
  applyPermutationToVector(dims, packedToStripMinedShapePerm);
  auto emptyOp = rewriter.create<tensor::EmptyOp>(
      loc, dims, stripMinedTensorType.getElementType());
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);

  LLVM_DEBUG(
      DBGSNL(); DBGSNL();
      DBGS() << "insertPositions: "
             << llvm::interleaved(packingMetadata.insertPositions);
      DBGSNL(); DBGS() << "packedShape: "
                       << llvm::interleaved(packedTensorType.getShape());
      DBGSNL(); DBGS() << "packedToStripMinedShapePerm: "
                       << llvm::interleaved(packedToStripMinedShapePerm);
      DBGSNL();
      DBGS() << "reassociations: "
             << llvm::interleaved(llvm::map_range(
                    packingMetadata.reassociations, stringifyReassocIndices));
      DBGSNL();
      DBGS() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
      DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););

  // 4. Collapse from the stripMinedShape to the padded result.
  auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
      loc, collapsedType, transposeOp->getResult(0),
      packingMetadata.reassociations);

  // 5. ExtractSlice.
  int64_t destRank = destTensorType.getRank();
  auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
      loc, destTensorType, reshapeOp->getResult(0),
      SmallVector<OpFoldResult>(destRank, zero),
      tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
      SmallVector<OpFoldResult>(destRank, one));

  // 6. Inject a copy to preserve DPS.
  auto copyOp = rewriter.create<linalg::CopyOp>(
      loc, extractSliceOp->getResult(0), unPackOp.getDest());

  // 7. Replace unPackOp by copyOp.
  rewriter.replaceOp(unPackOp, copyOp->getResults());

  return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
}
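
// Mirroring the pack example above (again a hand-written sketch, not verified
// output): unpacking tensor<2x1x8x16xf32> back into tensor<13x15xf32> becomes
// transpose + collapse_shape + extract_slice + copy:
//
//   %t = linalg.transpose ins(%src : tensor<2x1x8x16xf32>)
//          outs(%empty : tensor<2x8x1x16xf32>) permutation = [0, 2, 1, 3]
//   %c = tensor.collapse_shape %t [[0, 1], [2, 3]]
//       : tensor<2x8x1x16xf32> into tensor<16x16xf32>
//   %s = tensor.extract_slice %c[0, 0] [13, 15] [1, 1]
//       : tensor<16x16xf32> to tensor<13x15xf32>
//   linalg.copy ins(%s : tensor<13x15xf32>) outs(%dst : tensor<13x15xf32>)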

SmallVector<int64_t>
PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
  SmallVector<int64_t> res;
  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())
      continue;
    res.push_back(i.packedDimForEachOperand[operandPos].value());
  }
  return res;
}

SmallVector<OpFoldResult>
PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
  SmallVector<OpFoldResult> res;
  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())
      continue;
    res.push_back(i.packedSize);
  }
  return res;
}

/// Implement packing of a single LinalgOp by performing packing by
/// `packedSizes`. There must be one packedSizes entry per `linalgOp` iterator.
/// Return the packed Linalg op on success, failure otherwise.
FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
                                   linalg::LinalgOp linalgOp,
                                   ArrayRef<OpFoldResult> packedSizes) {
  if (packedSizes.size() != linalgOp.getNumLoops()) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "incorrect number of pack sizes");
  }

  Location loc = linalgOp->getLoc();
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n"
                    << "maps: " << llvm::interleaved(indexingMaps) << "\n"
                    << "iterators: " << llvm::interleaved(iteratorTypes)
                    << "\n");

  SmallVector<linalg::PackOp> packOps;
  SmallVector<linalg::UnPackOp> unPackOps;
  // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
  PackedOperandsDimList listOfPackedOperandsDim;
  for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
    std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
    // Skip tile sizes explicitly set to 0.
    if (maybeConstant.has_value() && maybeConstant.value() == 0)
      continue;

    PackedOperandsDim packedOperandsDims;
    packedOperandsDims.packedSize = packedSizes[i];
    FailureOr<SmallVector<std::optional<int64_t>>>
        maybePackedDimForEachOperand =
            packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
    if (failed(maybePackedDimForEachOperand))
      return failure();
    packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
    listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));

    LLVM_DEBUG(
        DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i]
               << "\n"
               << "maps: " << llvm::interleaved(indexingMaps) << "\n"
               << "iterators: " << llvm::interleaved(iteratorTypes) << "\n"
               << "packedDimForEachOperand: "
               << llvm::interleaved(packedOperandsDims.packedDimForEachOperand)
               << "\n");
  }

  // Step 2. Propagate packing to all LinalgOp operands.
  SmallVector<Value> inputsAndInits, results;
  SmallVector<OpOperand *> initOperands =
      llvm::to_vector(llvm::make_pointer_range(linalgOp.getDpsInitsMutable()));
  SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
  for (const auto &operandsList : {inputOperands, initOperands}) {
    for (OpOperand *opOperand : operandsList) {
      int64_t pos = opOperand->getOperandNumber();
      Value operand = opOperand->get();
      SmallVector<int64_t> innerPos =
          listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
      SmallVector<OpFoldResult> innerPackSizes =
          listOfPackedOperandsDim.extractPackSizesForOperand(pos);
      LLVM_DEBUG(DBGS() << "operand: " << operand << "\n"
                        << "innerPos: " << llvm::interleaved(innerPos) << "\n"
                        << "innerPackSizes: "
                        << llvm::interleaved(innerPackSizes) << "\n");
      if (innerPackSizes.empty()) {
        inputsAndInits.push_back(operand);
        continue;
      }
      Value dest = linalg::PackOp::createDestinationTensor(
          rewriter, loc, operand, innerPackSizes, innerPos,
          /*outerDimsPerm=*/{});
      ShapedType operandType = cast<ShapedType>(operand.getType());
      bool areConstantTiles =
          llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
            return getConstantIntValue(tile).has_value();
          });
      if (areConstantTiles && operandType.hasStaticShape() &&
          !linalg::PackOp::requirePaddingValue(
              operandType.getShape(), innerPos,
              cast<ShapedType>(dest.getType()).getShape(), {},
              innerPackSizes)) {
        packOps.push_back(rewriter.create<linalg::PackOp>(
            loc, operand, dest, innerPos, innerPackSizes));
      } else {
        // TODO: value of the padding attribute should be determined by
        // consumers.
        auto zeroAttr =
            rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
        Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
        packOps.push_back(rewriter.create<linalg::PackOp>(
            loc, operand, dest, innerPos, innerPackSizes, zero));
      }
      inputsAndInits.push_back(packOps.back());
    }
  }

  // Step 3. Build the packed op, use the type of `inits` as result types.
  ValueRange inputs =
      ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
  ValueRange inits =
      ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
  auto packedLinalgOp = rewriter.create<linalg::GenericOp>(
      linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps,
      iteratorTypes);
  packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));

  // Step 4. Propagate packing to all the op results.
  for (OpResult result : packedLinalgOp->getResults()) {
    int64_t resultNum = result.getResultNumber();
    linalg::PackOp maybePackedInit =
        inits[resultNum].getDefiningOp<linalg::PackOp>();
    if (!maybePackedInit) {
      results.push_back(result);
      continue;
    }
    // Build the symmetrical UnPackOp to the existing PackOp.
    unPackOps.push_back(rewriter.create<linalg::UnPackOp>(
        packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
        maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
    results.push_back(unPackOps.back());
  }

  // Step 5. Replace `linalgOp`.
  rewriter.replaceOp(linalgOp, results);

  // Return packedLinalgOp.
  return PackResult{packOps,
                    cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
                    unPackOps};
}
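
// A minimal usage sketch (the helper below is hypothetical, not part of this
// file): pack a 3-loop matmul-like op with tile sizes (8, 16, 32) for
// (m, n, k). Each packed operand gets a producing linalg.pack, the body is
// moved into a higher-dimensional linalg.generic, and each packed init gets a
// trailing linalg.unpack.
//
//   static FailureOr<PackResult> packBy8x16x32(RewriterBase &rewriter,
//                                              linalg::LinalgOp op) {
//     SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(8),
//                                        rewriter.getIndexAttr(16),
//                                        rewriter.getIndexAttr(32)};
//     return linalg::pack(rewriter, op, sizes);
//   }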

//===----------------------------------------------------------------------===//
// packTranspose transformation.
//===----------------------------------------------------------------------===//

/// Return a copy of `tensorType` after permutation by `permutationVector`.
// Note: Should be a new method of MemRef/RankedTensor/VectorType::Builder
// but this would introduce a dependence on Dialect in IR.
// TODO: Restructure.
static RankedTensorType permuteShape(RankedTensorType tensorType,
                                     ArrayRef<int64_t> permutationVector) {
  SmallVector<int64_t> shape(tensorType.getShape());
  applyPermutationToVector(shape, permutationVector);
  return RankedTensorType::Builder(tensorType).setShape(shape);
}

/// Return a new GenericOp obtained by transposing opOperand by the permutation
/// vector:
/// - the corresponding indexing map is transposed by `permutation`
/// - the corresponding operand value is replaced by `transposedValue`
/// `linalgOp` is replaced by the return op in the process.
/// Asserts that `transposedValue` is of the proper transposed ShapedType.
static LinalgOp transposeOneLinalgOperandAndReplace(
    RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
    ArrayRef<int64_t> permutation, Value transposedValue) {
  // Sanity check the operand.
  assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");

  // Sanity check of the expected transposed tensor type.
  auto tensorType = permuteShape(
      cast<RankedTensorType>(opOperand.get().getType()), permutation);
  (void)tensorType;
  assert(tensorType == transposedValue.getType() &&
         "unexpected tensor type mismatch");

  // Compute the transposed indexing map.
  // Sigh unsigned pollution.
  SmallVector<unsigned> tmpTransposition = llvm::to_vector(
      llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
  AffineMap permutationMap =
      AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
  AffineMap transposedMap =
      permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));

  // Set the transposed indexing map in the proper position.
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
  // Set the transposedValue in the proper operand position.
  SmallVector<Value> operands = linalgOp->getOperands();
  operands[opOperand.getOperandNumber()] = transposedValue;

  ValueRange operandsRef(operands);
  auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
      /*location=*/linalgOp->getLoc(),
      /*resultTensorTypes=*/
      operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
      /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
      /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
      /*indexingMaps=*/indexingMaps,
      /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
  transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
  rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());

  return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
}

FailureOr<PackTransposeResult>
linalg::packTranspose(RewriterBase &rewriter, linalg::PackOp packOp,
                      linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp,
                      ArrayRef<int64_t> outerPerm,
                      ArrayRef<int64_t> innerPerm) {
  Location loc = linalgOp.getLoc();

  // Step 1. Transpose packOp.
  rewriter.setInsertionPoint(packOp);
  linalg::PackOp transposedPackOp =
      packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);

  if (!packOp.getResult().hasOneUse())
    return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");

  OpOperand &packUse = *packOp->getUses().begin();
  if (packUse.getOwner() != linalgOp) {
    return rewriter.notifyMatchFailure(
        linalgOp, "not a single use by the LinalgOp target");
  }
  if (maybeUnPackOp &&
      (!linalgOp.isDpsInit(&packUse) ||
       maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "not produced by the LinalgOp target");
  }

  // Step 2. Transpose linalgOp.
  // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
  // identity. Don't rely on it.
  int64_t numLeadingDims = packOp.getSourceRank();
  int64_t numTrailingDims = packOp.getInnerDimsPos().size();
  // Step 2.a. Compute the permutation on the whole operand.
  // Leading part just reuses the outerPerm.
  SmallVector<int64_t> permutation(outerPerm);
  if (permutation.empty())
    llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
  // Trailing part needs to reindex positions by `numLeadingDims`.
  if (innerPerm.empty()) {
    llvm::append_range(
        permutation,
        llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
  } else {
    llvm::append_range(permutation,
                       llvm::map_range(innerPerm, [&](int64_t pos) {
                         return numLeadingDims + pos;
                       }));
  }
  if (!isPermutationVector(permutation))
    return rewriter.notifyMatchFailure(linalgOp, "invalid permutation");

  // Step 2.b. Save the transposedPackUse operand number in case we need to
  // get the tied OpResult after `linalgOp` has been replaced.
  int64_t packUseOperandNumber = packUse.getOperandNumber();
  // Step 2.c. Actually perform the transposition.
  rewriter.setInsertionPoint(linalgOp);
  linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
      rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());

  // Step 3. Maybe transpose unPackOp.
  linalg::UnPackOp transposedUnPackOp;
  if (maybeUnPackOp) {
    OpOperand &opOperand =
        transposedLinalgOp->getOpOperand(packUseOperandNumber);
    OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
    rewriter.setInsertionPoint(maybeUnPackOp);
    transposedUnPackOp = maybeUnPackOp.createTransposedClone(
        rewriter, loc, transposedResult, innerPerm, outerPerm);

    rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
  }

  // Step 4. Finally, replace packOp now that we don't need it anymore.
  rewriter.replaceOp(packOp, transposedPackOp->getResults());

  return PackTransposeResult{transposedPackOp, transposedLinalgOp,
                             transposedUnPackOp};
}
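
// For example (sketch only): with innerPerm = [1, 0] and outerPerm = [1, 0], a
// pack with inner_dims_pos = [0, 1] and inner_tiles = [8, 16] is cloned with
// inner_dims_pos = [1, 0], inner_tiles = [16, 8] and outer_dims_perm = [1, 0];
// the consuming LinalgOp's indexing map for that operand is composed with the
// corresponding whole-operand permutation, and a matching unpack, if provided,
// is transposed the same way.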

//===----------------------------------------------------------------------===//
// packMatmulGreedily transformation.
//===----------------------------------------------------------------------===//

/// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
/// and n are proper parallel dimensions and k is a proper reduction
/// dimension. Packing occurs by rewriting the op as a linalg.generic and
/// calling linalg::pack with `mnkPackedSizes`. The order of the packed
/// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2}
/// to reorder {m, n, k} into one of the 6 possible forms. The outer
/// dimensions of the operands are not permuted at this time; this is left for
/// future work.
FailureOr<PackResult>
linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<OpFoldResult> mnkPackedSizes,
                           ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
                           ArrayRef<int64_t> mnkOrder) {
  assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
  assert((mnkPaddedSizesNextMultipleOf.empty() ||
          mnkPaddedSizesNextMultipleOf.size() == 3) &&
         "num of packing sizes next multiple should be empty or of size 3");
  assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
  assert(isPermutationVector(mnkOrder) && "expected a permutation");

  int64_t numLoops = linalgOp.getNumLoops();
  if (numLoops <= 2) {
    LLVM_DEBUG(DBGS() << "need 3+ loops to find a matmul to pack, got "
                      << numLoops << "\nin: " << linalgOp << "\n");
    return rewriter.notifyMatchFailure(
        linalgOp, "need 3+ loops to find a matmul to pack");
  }

  // Locally adjust the desired iterator position of mnk and packing sizes.
  int64_t numPackedDims = mnkPackedSizes.size();
  SmallVector<int64_t> mmnnkkPos(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
  SmallVector<OpFoldResult> packedSizes(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
  SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    paddedSizesNextMultipleOf[mnkOrder[i]] =
        mnkPaddedSizesNextMultipleOf.empty() ? 0
                                             : mnkPaddedSizesNextMultipleOf[i];
  }

  // 1. Infer dims that are important for matmul.
  FailureOr<ContractionDimensions> maybeDimensions =
      inferContractionDims(linalgOp);
  if (failed(maybeDimensions)) {
    LLVM_DEBUG(DBGS() << "couldn't infer matmul iterators in: " << linalgOp
                      << "\n");
    return rewriter.notifyMatchFailure(linalgOp,
                                       "couldn't infer matmul iterators");
  }

  // 2. Normalize linalgOp to a kmn-matmul-like with [red, par, par] most
  // minor iterators. In cases with multiple options for m, n, k bias towards
  // the most minor embedding.
  // If we wanted a different normalization order, this is where it would have
  // to plug a heuristic.
  int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
          kPos = maybeDimensions->k.back();
  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
             DBGS() << "Start packing generic op greedily with (m@" << mPos
                    << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
                    << "\n";);

  // 2.a. Rewrite as a generic.
  auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
  if (!genericOp) {
    FailureOr<GenericOp> generalizeResult =
        generalizeNamedOp(rewriter, linalgOp);
    assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
    genericOp = *generalizeResult;
  }

  // 2.b. Interchange to move the dimensions (k, m, n) as most-minor
  // iterators. Note that this only normalizes the iteration order and does
  // not change the indexings of any operand.
  SmallVector<int64_t> permutation =
      computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
  LLVM_DEBUG(DBGS() << "perm: " << llvm::interleaved(permutation) << "\n");
  // Sign .. unsigned pollution.
  SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
  FailureOr<GenericOp> interchangeResult =
      interchangeGenericOp(rewriter, genericOp, unsignedPerm);
  assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
  genericOp = *interchangeResult;
  LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";);

  // At this point, the op iterators are normalized to {leading, k, m, n}.
  // The layouts induced by packing will always be:
  //   - LHS{leading_lhs, kk, mm}
  //   - RHS{leading_rhs, kk, nn}
  //   - RES{leading_res, mm, nn}
  // If we wanted to change the packed order, we would reorder (k, m, n) to
  // something else above.
  //
  // Additional permutations of the outer dims of the operands (i.e.
  // leading_lhs, leading_rhs and leading_res) could follow by computing the
  // desired outerPerm for each operand.
  // This is left for future work.

  // TODO: this creates too much IR, go use reifyResultShapes.
  SmallVector<Range, 4> loopRanges =
      cast<LinalgOp>(genericOp.getOperation())
          .createLoopRanges(rewriter, genericOp.getLoc());

  // Add leading zeros to match numLoops, we only pack the last 3 dimensions
  // post interchange.
  LLVM_DEBUG(DBGS() << "paddedSizesNextMultipleOf: "
                    << llvm::interleaved(paddedSizesNextMultipleOf) << "\n"
                    << "loopRanges: "
                    << llvm::interleaved(llvm::map_range(
                           loopRanges, [](Range r) { return r.size; }))
                    << "\n");
  SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
                                                rewriter.getIndexAttr(0));
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    if (paddedSizesNextMultipleOf[i] == 0) {
      adjustedPackedSizes.push_back(packedSizes[i]);
      continue;
    }
    AffineExpr d0, s0;
    bindDims(rewriter.getContext(), d0);
    bindSymbols(rewriter.getContext(), s0);
    adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
        rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
        {loopRanges[adjustedPackedSizes.size()].size,
         rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
  }
  LLVM_DEBUG(DBGS() << "adjustedPackedSizes: "
                    << llvm::interleaved(adjustedPackedSizes) << "\n");

  // TODO: If we wanted to give the genericOp a name after packing, after
  // calling `pack` would be a good time. One would still need to check that
  // `containsMostMinorMatmul(packingRes->packedLinalgOp)` is true, since we
  // also allow degenerate matmul cases (i.e. matvec, dot).
  return pack(rewriter, genericOp, adjustedPackedSizes);
}
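
// A minimal calling sketch (hypothetical helper, illustrative only): pack any
// contraction-like op with mm = 8, nn = 16, kk = 32, keeping the default
// (m, n, k) order and requesting no extra padding of the loop ranges.
//
//   static FailureOr<PackResult> greedilyPack(RewriterBase &rewriter,
//                                             linalg::LinalgOp op) {
//     SmallVector<OpFoldResult> mnk = {rewriter.getIndexAttr(8),
//                                      rewriter.getIndexAttr(16),
//                                      rewriter.getIndexAttr(32)};
//     return linalg::packMatmulGreedily(rewriter, op, mnk,
//                                       /*mnkPaddedSizesNextMultipleOf=*/{},
//                                       /*mnkOrder=*/{0, 1, 2});
//   }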

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts);
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<func::FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}
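
// Typical usage (sketch): when tiling later runs, the sizes are materialized
// as arith.constant index ops at the start of the enclosing function.
//
//   LinalgTilingOptions options;
//   options.setTileSizes({8, 16, 0}); // A size of 0 leaves that loop untiled.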

LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
    memref::CopyOp copyOp, PatternRewriter &rewriter) const {
  return vectorizeCopy(rewriter, copyOp);
}

/// Fill `dest` using a FillOp with the constant padding value if possible.
/// Otherwise, generate a tensor::GenerateOp.
Value DecomposePadOpPattern::createFillOrGenerateOp(
    RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue) {
    // Move the padding value defined inside the PadOp block to outside.
    if (padValue.getParentBlock() == &padOp.getRegion().front())
      rewriter.moveOpBefore(padValue.getDefiningOp(), padOp);
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();
  }

  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy region to new op.
  IRMapping bvm;
  padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
  return generateOp;
}

LogicalResult
DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                       PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute size of EmptyOp. Any combination of static/dynamic is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = getIdxValue(tensor::getMixedSize(rewriter, padOp.getLoc(),
                                                      padOp.getSource(), dim));
      // Add low and high padding value.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
      padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes);
  Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);

  // Generate an InsertSliceOp for copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute size of source of tensor::PadOp.
  SmallVector<OpFoldResult> srcSizes =
      tensor::getMixedSizes(rewriter, padOp.getLoc(), padOp.getSource());
  // Strides of InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
      strides);

  return success();
}
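
// The resulting decomposition on a static example (hand-written sketch,
// illustrative only): a pad with a constant padding value
//
//   %0 = tensor.pad %src low[1, 2] high[2, 3] {
//   ^bb0(%i: index, %j: index):
//     tensor.yield %cst : f32
//   } : tensor<4x5xf32> to tensor<7x10xf32>
//
// becomes empty + fill + insert_slice:
//
//   %empty = tensor.empty() : tensor<7x10xf32>
//   %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<7x10xf32>)
//           -> tensor<7x10xf32>
//   %0 = tensor.insert_slice %src into %fill[1, 2] [4, 5] [1, 1]
//       : tensor<4x5xf32> into tensor<7x10xf32>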

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  bool zeroSliceGuard = true;
  if (controlFn) {
    if (std::optional<bool> control = controlFn(sliceOp))
      zeroSliceGuard = *control;
    else
      return failure();
  }

  FailureOr<TilingResult> tilingResult =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  if (failed(tilingResult))
    return failure();

  RankedTensorType sourceType = sliceOp.getSourceType();
  RankedTensorType resultType = sliceOp.getResultType();

  // If the extract_slice is not rank-reduced, all shapes are static and the
  // data source is actually used. Rewrite into pad(extract_slice(x)).
  if (sourceType.getRank() == resultType.getRank()) {
    rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
    return success();
  }

  // Handle rank-reduced slice by creating another extract_slice op.
  Value rankReduced = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, sliceOp.getLoc(), tilingResult->tiledValues[0], resultType);

  rewriter.replaceOp(sliceOp, rankReduced);
  return success();
}
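
// In essence (sketch only; the actual rewrite also guards against slices that
// read only padding), the pattern swaps the two ops so that only the accessed
// part of the source is padded:
//
//   %p = tensor.pad %src low[0, 0] high[5, 5] { ... }
//       : tensor<10x10xf32> to tensor<15x15xf32>
//   %s = tensor.extract_slice %p[8, 8] [4, 4] [1, 1]
//       : tensor<15x15xf32> to tensor<4x4xf32>
//
// is rewritten, roughly, into:
//
//   %src_slice = tensor.extract_slice %src[8, 8] [2, 2] [1, 1]
//       : tensor<10x10xf32> to tensor<2x2xf32>
//   %s = tensor.pad %src_slice low[0, 0] high[2, 2] { ... }
//       : tensor<2x2xf32> to tensor<4x4xf32>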

/// If padding value is set, returns a tensor.pad Op for the source tensor,
/// with the output shape matching the output of `packOp`. Otherwise, returns
/// the source directly.
///
/// This method assumes that all outer dims for this pack Op are 1.
static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                           linalg::PackOp packOp) {
  Value input = packOp.getSource();
  if (!packOp.getPaddingValue()) {
    return input;
  }

  assert(llvm::all_of(packOp.getAllOuterDims(),
                      [](int64_t val) { return val == 1; }) &&
         "some outer dims are != 1");

  Location loc = packOp.getLoc();
  ShapedType inputType = packOp.getSourceType();
  int64_t inputRank = inputType.getRank();

  DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
      packOp.getDimAndTileMapping();

  // The sizes of dynamic tiles.
  SmallVector<Value> dynamicTileSizes;

  // Collect dims for the padded shape.
  SmallVector<int64_t> paddedShape;
  for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
    // 1. Non-tiled outer dims.
    // These dims should be 1 and we simply preserve them.
    if (!tileAndPosMapping.count(dimIdx)) {
      int64_t inputDimSize = inputType.getDimSize(dimIdx);
      assert(inputDimSize == 1 &&
             "with all outer dims == 1, this non-tiled input dim should be 1!");
      paddedShape.push_back(inputDimSize);
      continue;
    }

    // 2. Tiled outer dims.
    // As all outer dims == 1, it is safe to use the tile size for the padded
    // shape.
    OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);

    // 2.1 Static tile sizes.
    std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
    if (cstTileSize.has_value()) {
      paddedShape.push_back(cstTileSize.value());
      continue;
    }

    // 2.2 Dynamic tile sizes.
    paddedShape.push_back(ShapedType::kDynamic);

    // Get the value that holds the dynamic size.
    dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
  }
  auto resultType =
      RankedTensorType::get(paddedShape, inputType.getElementType());
  return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
                                 /*nofold=*/false, loc, builder,
                                 dynamicTileSizes);
}
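
// For instance (hand-written sketch): given
//
//   %p = linalg.pack %src padding_value(%cst : f32)
//          inner_dims_pos = [1] inner_tiles = [8]
//          into %dst : tensor<1x5xf32> -> tensor<1x1x8xf32>
//
// the helper returns the result of padding %src up to tensor<1x8xf32>; when no
// padding value is set, it returns %src unchanged.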
| 1101 | |
| 1102 | // Normalizes a permutation on a higher rank space to its actual size, e.g. |
| 1103 | // perm = [1, 4, 2] |
| 1104 | // becomes |
| 1105 | // norm = [0, 2, 1] |
| 1106 | static SmallVector<int64_t> |
| 1107 | getPackUnpackNormalizedPerm(int rank, ArrayRef<int64_t> perm) { |
| 1108 | constexpr int64_t kNonTiledMarker = -1; |
| 1109 | SmallVector<int64_t> vec(rank, kNonTiledMarker); |
| 1110 | for (auto [index, value] : llvm::enumerate(First&: perm)) |
| 1111 | vec[value] = index; |
| 1112 | SmallVector<int64_t> normalizedPerm = llvm::filter_to_vector( |
| 1113 | C&: vec, Pred: [&](int64_t v) { return v != kNonTiledMarker; }); |
| 1114 | // This inverts the permutation in addition to normalizing so invert back. |
| 1115 | return invertPermutationVector(permutation: normalizedPerm); |
| 1116 | } |
| 1117 | |
| 1118 | // Gets the normalized permutation implied by innerDimsPos and outerDimsPerm |
| 1119 | // assuming rank reduction of unit outer dims. |
| 1120 | static SmallVector<int64_t> |
| 1121 | getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape, |
| 1122 | ArrayRef<int64_t> innerDimsPos, |
| 1123 | ArrayRef<int64_t> outerDimsPerm) { |
| 1124 | SmallVector<int64_t> rankReducedOuterDimsPerm; |
| 1125 | SmallVector<int64_t> outerDims; |
| 1126 | SmallVector<int64_t> innerDims; |
| 1127 | int64_t dim = 0; |
| 1128 | int64_t unpackedRank = shape.size(); |
| 1129 | for (auto i : llvm::seq<unsigned>(Begin: 0, End: unpackedRank)) { |
| 1130 | if (llvm::is_contained(Range&: innerDimsPos, Element: i)) { |
| 1131 | innerDims.push_back(Elt: dim++); |
| 1132 | continue; |
| 1133 | } |
| 1134 | if (shape[i] == 1) |
| 1135 | continue; |
| 1136 | outerDims.push_back(Elt: dim++); |
| 1137 | if (!outerDimsPerm.empty()) |
| 1138 | rankReducedOuterDimsPerm.push_back(Elt: outerDimsPerm[i]); |
| 1139 | } |
| 1140 | |
| 1141 | // Get the position of the inner dims after permutation. |
| 1142 | SmallVector<int64_t> innerPerm = |
| 1143 | getPackUnpackNormalizedPerm(rank: unpackedRank, perm: innerDimsPos); |
| 1144 | applyPermutationToVector<int64_t>(inVec&: innerDims, permutation: innerPerm); |
| 1145 | |
| 1146 | // Ditto for the outer dims. |
| 1147 | SmallVector<int64_t> perm = outerDims; |
| 1148 | |
| 1149 | rankReducedOuterDimsPerm = |
| 1150 | getPackUnpackNormalizedPerm(rank: unpackedRank, perm: rankReducedOuterDimsPerm); |
| 1151 | if (!rankReducedOuterDimsPerm.empty()) |
| 1152 | applyPermutationToVector<int64_t>(inVec&: perm, permutation: rankReducedOuterDimsPerm); |
| 1153 | |
| 1154 | // The tile always ends up as the innermost dims after packing. |
| 1155 | perm.append(innerDims); |
| 1156 | |
| 1157 | return perm; |
| 1158 | } |
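| | |
| | // Worked example for the helper above (illustrative): for |
| | //   shape = [1, 1, 32], innerDimsPos = [1], outerDimsPerm = [] |
| | // the unit, untiled dim 0 is dropped, the tiled dim 1 is assigned |
| | // rank-reduced position 0 and the non-unit, untiled dim 2 is assigned |
| | // position 1. The result is perm = [1, 0]: the untiled dim leads and the |
| | // tile ends up innermost. |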
| 1159 | |
| 1160 | LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( |
| 1161 | linalg::PackOp packOp, PatternRewriter &rewriter) const { |
| 1162 | // TODO: support the case that outer dimensions are not all 1s. A |
| 1163 | // tensor.expand_shape will be generated in this case. |
| 1164 | if (llvm::any_of(packOp.getAllOuterDims(), |
| 1165 | [](int64_t dim) { return dim != 1; })) { |
| 1166 | return rewriter.notifyMatchFailure( |
| 1167 | packOp, "not all outer dimensions of the result are 1s" ); |
| 1168 | } |
| 1169 | |
| 1170 | Attribute zeroIdxAttr = rewriter.getIndexAttr(0); |
| 1171 | Attribute oneIdxAttr = rewriter.getIndexAttr(1); |
| 1172 | Location loc = packOp.getLoc(); |
| 1173 | |
| 1174 | Value input = getPackOpSourceOrPaddedSource(rewriter, packOp); |
| 1175 | DenseMap<int64_t, OpFoldResult> dimAndTileMapping = |
| 1176 | packOp.getDimAndTileMapping(); |
| 1177 | int64_t srcRank = packOp.getSourceRank(); |
| 1178 | int64_t destRank = packOp.getDestRank(); |
| 1179 | int64_t numTiles = destRank - srcRank; |
| 1180 | |
| 1181 | if (!llvm::all_of(packOp.getInnerDimsPos(), |
| 1182 | [&srcRank, &numTiles](int64_t dimPos) { |
| 1183 | return dimPos >= (srcRank - numTiles - 1); |
| 1184 | })) |
| 1185 | return rewriter.notifyMatchFailure( |
| 1186 | packOp, "Attempting to tile non-trailing source dims!" ); |
| 1187 | |
| 1188 | // 1. Extract the inner tile sizes. |
| 1189 | // Where possible, values are replaced with constant attributes (to match the |
| 1190 | // behaviour of `getPackOpSourceOrPaddedSource`). |
| 1191 | SmallVector<OpFoldResult> tileSizes; |
| 1192 | for (auto i : llvm::seq<unsigned>(0, srcRank)) { |
| 1193 | if (dimAndTileMapping.count(i)) { |
| 1194 | // Rather than taking the tile size as is, extract the actual constant |
| 1195 | // value Attribute where possible, e.g.: |
| 1196 | // [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8] |
| 1197 | auto [_, tileSize] = |
| 1198 | getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter); |
| 1199 | tileSizes.push_back(tileSize); |
| 1200 | } |
| 1201 | } |
| 1202 | |
| 1203 | // 2. Transpose the input to match the inner tile order: |
| 1204 | // %init = tensor.empty() |
| 1205 | // %transposed_tile = linalg.transpose ins(%source_or_padded_source), |
| 1206 | // outs(%init) |
| 1207 | // Two assumptions are made: |
| 1208 | // 1. All outer dims are 1 - the corresponding transposition doesn't matter. |
| 1209 | // 2. Inner dims position correspond to the trailing `numTiles` dims. |
| 1210 | SmallVector<int64_t> tilesPermNormalized = |
| 1211 | getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos()); |
| 1212 | SmallVector<int64_t> srcPermForTranspose; |
| 1213 | for (int64_t i = 0; i < (srcRank - numTiles); i++) |
| 1214 | srcPermForTranspose.push_back(i); |
| 1215 | |
| 1216 | srcPermForTranspose.append(SmallVector<int64_t>(packOp.getInnerDimsPos())); |
| 1217 | |
| 1218 | LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n" |
| 1219 | << "perm: " << llvm::interleaved(srcPermForTranspose) |
| 1220 | << "\n" ); |
| 1221 | |
| 1222 | // 2.1 Create tensor.empty (init value for TransposeOp) |
| 1223 | SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles, |
| 1224 | oneIdxAttr); |
| 1225 | transShapeForEmptyOp.append(tileSizes); |
| 1226 | |
| 1227 | applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp, |
| 1228 | srcPermForTranspose); |
| 1229 | Value empty = rewriter.create<tensor::EmptyOp>( |
| 1230 | loc, transShapeForEmptyOp, packOp.getSourceType().getElementType()); |
| 1231 | |
| 1232 | // 2.2 Create linalg.transpose |
| 1233 | auto transposedOp = rewriter.create<linalg::TransposeOp>(loc, input, empty, |
| 1234 | srcPermForTranspose); |
| 1235 | |
| 1236 | // 3. Insert the inner tile to the destination: |
| 1237 | // %inserted_tile = tensor.insert_slice(%transposed_tile) |
| 1238 | SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr); |
| 1239 | SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr); |
| 1240 | // Outer dims are all 1s! |
| 1241 | SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(), |
| 1242 | oneIdxAttr); |
| 1243 | SmallVector<int64_t> writeShape; |
| 1244 | |
| 1245 | for (auto tileSize : packOp.getMixedTiles()) { |
| 1246 | auto [tileSizeStatic, tileSizeOfr] = |
| 1247 | getSimplifiedOfrAndStaticSizePair(tileSize, rewriter); |
| 1248 | writeSizes.push_back(tileSizeOfr); |
| 1249 | writeShape.push_back(tileSizeStatic); |
| 1250 | } |
| 1251 | |
| 1252 | // 4. Replace the linalg.pack op with the tensor.insert_slice created above. |
| 1253 | auto insert = rewriter.create<tensor::InsertSliceOp>( |
| 1254 | loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets, |
| 1255 | writeSizes, writeStrides); |
| 1256 | rewriter.replaceOp(packOp, insert.getResult()); |
| 1257 | |
| 1258 | return success(); |
| 1259 | } |
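| | |
| | // Illustrative example of the decomposition above (a sketch; exact IR |
| | // spelling, SSA names and folding may differ): |
| | // |
| | //   %pack = linalg.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 16] |
| | //       into %dest : tensor<8x16xf32> -> tensor<1x1x8x16xf32> |
| | // |
| | // becomes roughly: |
| | // |
| | //   %empty = tensor.empty() : tensor<8x16xf32> |
| | //   %tr = linalg.transpose ins(%src : tensor<8x16xf32>) |
| | //             outs(%empty : tensor<8x16xf32>) permutation = [0, 1] |
| | //   %res = tensor.insert_slice %tr into %dest[0, 0, 0, 0] [1, 1, 8, 16] |
| | //       [1, 1, 1, 1] : tensor<8x16xf32> into tensor<1x1x8x16xf32> |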
| 1260 | |
| 1261 | LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( |
| 1262 | linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const { |
| 1263 | int64_t srcRank = unpackOp.getSourceRank(); |
| 1264 | int64_t destRank = unpackOp.getDestRank(); |
| 1265 | ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape(); |
| 1266 | ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos(); |
| 1267 | if (llvm::any_of(unpackOp.getTiledOuterDims(), |
| 1268 | [](int64_t dim) { return dim != 1; })) { |
| 1269 | return rewriter.notifyMatchFailure( |
| 1270 | unpackOp, |
| 1271 | "require the tiled outer dimensions of the result are all 1s" ); |
| 1272 | } |
| 1273 | |
| 1274 | // 1. Use rank-reduced tensor.extract_slice op to extract the tile: |
| 1275 | // %extracted_tile = tensor.extract_slice(%unpack_op_input) |
| 1276 | Location loc = unpackOp.getLoc(); |
| 1277 | Value source = unpackOp.getSource(); |
| 1278 | DenseMap<int64_t, OpFoldResult> dimAndTileMapping = |
| 1279 | unpackOp.getDimAndTileMapping(); |
| 1280 | Attribute zeroIdxAttr = rewriter.getIndexAttr(0); |
| 1281 | Attribute oneIdxAttr = rewriter.getIndexAttr(1); |
| 1282 | |
| 1283 | // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of |
| 1284 | // dims: |
| 1285 | // [ outer-untiled-dims, outer-tiled-dims, tile-sizes ] |
| 1286 | SmallVector<int64_t> readShapeForExtractSlice; |
| 1287 | // The sizes attribute for ExtractSliceOp. Due to rank-reducing (and |
| 1288 | // outer-tiled-dims being all 1), this will be |
| 1289 | // [ outer-untiled-dims, tile-sizes ] |
| 1290 | SmallVector<OpFoldResult> extractSliceSizes; |
| 1291 | // The offset and strides attributes for ExtractSliceOp. |
| 1292 | SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr); |
| 1293 | SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr); |
| 1294 | |
| 1295 | // Shape for EmptyOp that's used as the init value for TransposeOp below. |
| 1296 | // This should be: |
| 1297 | // [ outer-untiled-dims, tile-sizes ] |
| 1298 | // However, skip unit dims - TransposeOp (below) applies rank-reduced |
| 1299 | // permutation. |
| 1300 | SmallVector<OpFoldResult> shapeForEmptyOp; |
| 1301 | |
| 1302 | for (auto i : llvm::seq<unsigned>(0, destRank)) { |
| 1303 | // Compute sizes attribute for ExtractSliceOp - outer-tiled-dims. |
| 1304 | // |
| 1305 | // As all outer tiled dims are 1, the corresponding slice size to read |
| 1306 | // will also be 1. As this will be a rank-reducing "extract slice" (i.e. |
| 1307 | // the unit dims will be "collapsed"), there's no need to |
| 1308 | // update: |
| 1309 | // * the output shape for ExtractSliceOp, nor |
| 1310 | // * the shape for EmptyOp. |
| 1311 | if (dimAndTileMapping.count(i)) { |
| 1312 | extractSliceSizes.push_back(oneIdxAttr); |
| 1313 | continue; |
| 1314 | } |
| 1315 | |
| 1316 | // Compute sizes attribute for ExtractSliceOp + EmptyOp - |
| 1317 | // outer-untiled-dims |
| 1318 | if (ShapedType::isDynamic(srcShape[i])) { |
| 1319 | OpFoldResult dynamicDim = |
| 1320 | rewriter.create<tensor::DimOp>(loc, source, i).getResult(); |
| 1321 | extractSliceSizes.push_back(dynamicDim); |
| 1322 | shapeForEmptyOp.push_back(dynamicDim); |
| 1323 | } else { |
| 1324 | extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i])); |
| 1325 | if (srcShape[i] != 1) |
| 1326 | shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i])); |
| 1327 | } |
| 1328 | // Compute the output shape for ExtractSliceOp - outer-untiled-dims (take |
| 1329 | // into account rank-reducing) |
| 1330 | if (srcShape[i] != 1) { |
| 1331 | readShapeForExtractSlice.push_back(srcShape[i]); |
| 1332 | } |
| 1333 | } |
| 1334 | // Append the tile sizes to "sizes attribute" for ExtractSliceOp and the |
| 1335 | // shape for EmptyOp. |
| 1336 | auto mixedTiles = unpackOp.getMixedTiles(); |
| 1337 | extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end()); |
| 1338 | shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end()); |
| 1339 | |
| 1340 | // Explicitly create the type for extract_slice op because the inner tile |
| 1341 | // size could be 1. We want to represent the whole inner tile in this case. |
| 1342 | auto tileShape = srcShape.drop_front(destRank); |
| 1343 | // Append the inner tile shape to the permuted and rank-reduced outer shape. |
| 1344 | readShapeForExtractSlice.append(tileShape.begin(), tileShape.end()); |
| 1345 | Type elemType = unpackOp.getSourceType().getElementType(); |
| 1346 | auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType); |
| 1347 | Value innerTile = rewriter.create<tensor::ExtractSliceOp>( |
| 1348 | loc, readType, unpackOp.getSource(), extractSliceOffsets, |
| 1349 | extractSliceSizes, extractSliceStrides); |
| 1350 | |
| 1351 | // 2. Transpose the tile to match the outer corresponding tile order. |
| 1352 | SmallVector<int64_t> perm = getPackUnpackRankReducedPerm( |
| 1353 | srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm()); |
| 1354 | // Unpack is a transition out of packed space so we invert the permutation. |
| 1355 | perm = invertPermutationVector(perm); |
| 1356 | applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm); |
| 1357 | |
| 1358 | Value empty = |
| 1359 | rewriter.create<tensor::EmptyOp>(loc, shapeForEmptyOp, elemType); |
| 1360 | auto transposedOp = |
| 1361 | rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm); |
| 1362 | |
| 1363 | // 3. Handle incomplete tiles if needed. This truncates trailing data from the |
| 1364 | // transposed tile. |
| 1365 | int numLoops = shapeForEmptyOp.size(); |
| 1366 | SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr); |
| 1367 | SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr); |
| 1368 | SmallVector<OpFoldResult> tileSizes; |
| 1369 | ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape(); |
| 1370 | for (auto i : llvm::seq<unsigned>(0, destRank)) { |
| 1371 | if (dimAndTileMapping.count(i) || destShape[i] != 1) |
| 1372 | tileSizes.push_back( |
| 1373 | tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i)); |
| 1374 | } |
| 1375 | |
| 1376 | auto partialTile = rewriter.create<tensor::ExtractSliceOp>( |
| 1377 | loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides); |
| 1378 | |
| 1379 | // 4. Insert the result to the destination tensor. |
| 1380 | SmallVector<OpFoldResult> writeSizes; |
| 1381 | SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr); |
| 1382 | SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr); |
| 1383 | for (int i = 0, idx = 0; i < destRank; ++i) { |
| 1384 | if (dimAndTileMapping.count(i) || destShape[i] != 1) |
| 1385 | writeSizes.push_back(tileSizes[idx++]); |
| 1386 | else |
| 1387 | writeSizes.push_back(oneIdxAttr); |
| 1388 | } |
| 1389 | auto insert = rewriter.create<tensor::InsertSliceOp>( |
| 1390 | loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes, |
| 1391 | writeStrides); |
| 1392 | rewriter.replaceOp(unpackOp, insert.getResult()); |
| 1393 | |
| 1394 | return success(); |
| 1395 | } |
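| | |
| | // Illustrative example of the decomposition above (a sketch; exact IR |
| | // spelling, SSA names and folding may differ): |
| | // |
| | //   %unpack = linalg.unpack %src inner_dims_pos = [0, 1] |
| | //       inner_tiles = [8, 16] into %dest |
| | //       : tensor<1x1x8x16xf32> -> tensor<8x16xf32> |
| | // |
| | // becomes roughly: |
| | // |
| | //   %tile = tensor.extract_slice %src[0, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1] |
| | //       : tensor<1x1x8x16xf32> to tensor<8x16xf32> |
| | //   %empty = tensor.empty() : tensor<8x16xf32> |
| | //   %tr = linalg.transpose ins(%tile : tensor<8x16xf32>) |
| | //             outs(%empty : tensor<8x16xf32>) permutation = [0, 1] |
| | //   %full = tensor.extract_slice %tr[0, 0] [8, 16] [1, 1] |
| | //       : tensor<8x16xf32> to tensor<8x16xf32> |
| | //   %res = tensor.insert_slice %full into %dest[0, 0] [8, 16] [1, 1] |
| | //       : tensor<8x16xf32> into tensor<8x16xf32> |
| | // |
| | // The %full slice is the step-3 handling of incomplete tiles; for this |
| | // fully-tiled static shape it is a full slice that later folds away. |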
| 1396 | |
| 1397 | // The following are patterns for downscaling convolution ops with size-1 |
| 1398 | // window dimensions. |
| 1399 | // |
| 1400 | // Note that we'd eventually want to write such transformations in a generic |
| 1401 | // way, e.g., converting to linalg.generic, removing the size-1 dimensions, |
| 1402 | // and then turning back to named ops. But for now it's fine to have a few |
| 1403 | // patterns matching special ops to get started. |
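| | // |
| | // For example (illustrative; exact IR spelling and attributes omitted), a |
| | // linalg.conv_2d_nhwc_hwcf whose kernel height and output height are both 1: |
| | // |
| | //   linalg.conv_2d_nhwc_hwcf |
| | //     ins(%in, %ker : tensor<1x1x10x3xf32>, tensor<1x2x3x8xf32>) |
| | //     outs(%out : tensor<1x1x9x8xf32>) -> tensor<1x1x9x8xf32> |
| | // |
| | // is rewritten to a linalg.conv_1d_nwc_wcf on rank-reduced operands |
| | // (tensor<1x10x3xf32>, tensor<2x3x8xf32>, tensor<1x9x8xf32>), wrapped in |
| | // rank-reducing tensor.extract_slice / tensor.insert_slice ops. |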
| 1404 | |
| 1405 | template <typename Conv2DOp, typename Conv1DOp> |
| 1406 | FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>:: |
| 1407 | returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const { |
| 1408 | if (convOp.hasPureBufferSemantics()) |
| 1409 | return failure(); // To be implemented. |
| 1410 | |
| 1411 | Value input = convOp.getInputs().front(); |
| 1412 | Value kernel = convOp.getInputs().back(); |
| 1413 | Value output = convOp.getOutputs().front(); |
| 1414 | |
| 1415 | auto inputType = dyn_cast<RankedTensorType>(input.getType()); |
| 1416 | auto kernelType = dyn_cast<RankedTensorType>(kernel.getType()); |
| 1417 | auto outputType = dyn_cast<RankedTensorType>(output.getType()); |
| 1418 | |
| 1419 | auto kernelShape = kernelType.getShape(); |
| 1420 | auto outputShape = outputType.getShape(); |
| 1421 | |
| 1422 | // Get domain indices based on conv2D layout. |
| 1423 | auto [khIndex, kwIndex, ohIndex, owIndex] = |
| 1424 | TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>( |
| 1425 | convOp) |
| 1426 | .Case([&](linalg::Conv2DNhwcHwcfOp op) { |
| 1427 | return std::make_tuple(0, 1, 1, 2); |
| 1428 | }) |
| 1429 | .Case([&](linalg::Conv2DNchwFchwOp op) { |
| 1430 | return std::make_tuple(2, 3, 2, 3); |
| 1431 | }) |
| 1432 | .Case([&](linalg::PoolingNhwcSumOp op) { |
| 1433 | return std::make_tuple(0, 1, 1, 2); |
| 1434 | }) |
| 1435 | .Case([&](linalg::PoolingNchwSumOp op) { |
| 1436 | return std::make_tuple(0, 1, 2, 3); |
| 1437 | }) |
| 1438 | .Case([&](linalg::PoolingNhwcMaxOp op) { |
| 1439 | return std::make_tuple(0, 1, 1, 2); |
| 1440 | }) |
| 1441 | .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) { |
| 1442 | return std::make_tuple(0, 1, 1, 2); |
| 1443 | }) |
| 1444 | .Case([&](linalg::PoolingNhwcMinOp op) { |
| 1445 | return std::make_tuple(0, 1, 1, 2); |
| 1446 | }) |
| 1447 | .Case([&](linalg::PoolingNhwcMinUnsignedOp op) { |
| 1448 | return std::make_tuple(0, 1, 1, 2); |
| 1449 | }) |
| 1450 | .Case([&](linalg::PoolingNchwMaxOp op) { |
| 1451 | return std::make_tuple(0, 1, 2, 3); |
| 1452 | }) |
| 1453 | .Default([&](Operation *op) { |
| 1454 | llvm_unreachable("unexpected conv2d/pool2d operation."); |
| 1455 | return std::make_tuple(0, 0, 0, 0); |
| 1456 | }); |
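| | // E.g. for linalg.conv_2d_nhwc_hwcf the kernel is laid out as HWCF (so |
| | // kh = 0, kw = 1) and the output as NHWC (so oh = 1, ow = 2), matching the |
| | // (0, 1, 1, 2) tuple above. |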
| 1457 | |
| 1458 | // Only handle the case where at least one of the window dimensions is |
| 1459 | // of size 1. Other cases can rely on tiling to reduce to such cases. |
| 1460 | int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex]; |
| 1461 | int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex]; |
| 1462 | bool removeH = (khSize == 1 && ohSize == 1); |
| 1463 | bool removeW = (kwSize == 1 && owSize == 1); |
| 1464 | if (!removeH && !removeW) |
| 1465 | return failure(); |
| 1466 | |
| 1467 | // Get new shapes and types for all operands by removing the size-1 |
| 1468 | // dimension. |
| 1469 | using RTTBuilder = RankedTensorType::Builder; |
| 1470 | RankedTensorType newInputType = |
| 1471 | RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex)); |
| 1472 | RankedTensorType newKernelType = |
| 1473 | RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex)); |
| 1474 | RankedTensorType newOutputType = |
| 1475 | RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex)); |
| 1476 | |
| 1477 | // Rank-reduce operands. |
| 1478 | Location loc = convOp.getLoc(); |
| 1479 | Value newInput = tensor::createCanonicalRankReducingExtractSliceOp( |
| 1480 | rewriter, loc, input, newInputType); |
| 1481 | Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp( |
| 1482 | rewriter, loc, kernel, newKernelType); |
| 1483 | Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp( |
| 1484 | rewriter, loc, output, newOutputType); |
| 1485 | |
| 1486 | // Rank-reduce strides and dilations too. |
| 1487 | // TODO: dropDim 1-liner helper. |
| 1488 | auto strides = |
| 1489 | llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>()); |
| 1490 | strides.erase(strides.begin() + (removeH ? 0 : 1)); |
| 1491 | auto stridesAttr = rewriter.getI64VectorAttr(strides); |
| 1492 | |
| 1493 | auto dilations = |
| 1494 | llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>()); |
| 1495 | dilations.erase(dilations.begin() + (removeH ? 0 : 1)); |
| 1496 | auto dilationsAttr = rewriter.getI64VectorAttr(dilations); |
| 1497 | |
| 1498 | auto conv1DOp = rewriter.create<Conv1DOp>( |
| 1499 | loc, newOutputType, ValueRange{newInput, newKernel}, |
| 1500 | ValueRange{newOutput}, stridesAttr, dilationsAttr); |
| 1501 | |
| 1502 | // Insert back. |
| 1503 | Value inserted = tensor::createCanonicalRankReducingInsertSliceOp( |
| 1504 | rewriter, loc, conv1DOp.getResult(0), output); |
| 1505 | rewriter.replaceOp(convOp, inserted); |
| 1506 | |
| 1507 | return conv1DOp; |
| 1508 | } |
| 1509 | |
| 1510 | template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp, |
| 1511 | Conv1DNwcWcfOp>; |
| 1512 | template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp, |
| 1513 | Conv1DNcwFcwOp>; |
| 1514 | template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, |
| 1515 | PoolingNwcSumOp>; |
| 1516 | template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, |
| 1517 | PoolingNcwSumOp>; |
| 1518 | template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, |
| 1519 | PoolingNwcMaxOp>; |
| 1520 | template struct linalg::DownscaleSizeOneWindowed2DConvolution< |
| 1521 | PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>; |
| 1522 | template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, |
| 1523 | PoolingNwcMinOp>; |
| 1524 | template struct linalg::DownscaleSizeOneWindowed2DConvolution< |
| 1525 | PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>; |
| 1526 | template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, |
| 1527 | PoolingNcwMaxOp>; |
| 1528 | |
| 1529 | FailureOr<DepthwiseConv1DNwcWcOp> |
| 1530 | DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite( |
| 1531 | DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const { |
| 1532 | if (convOp.hasPureBufferSemantics()) |
| 1533 | return failure(); // To be implemented. |
| 1534 | |
| 1535 | Value input = convOp.getInputs().front(); |
| 1536 | Value kernel = convOp.getInputs().back(); |
| 1537 | Value output = convOp.getOutputs().front(); |
| 1538 | |
| 1539 | auto inputType = dyn_cast<RankedTensorType>(input.getType()); |
| 1540 | auto kernelType = dyn_cast<RankedTensorType>(kernel.getType()); |
| 1541 | auto outputType = dyn_cast<RankedTensorType>(output.getType()); |
| 1542 | |
| 1543 | auto kernelShape = kernelType.getShape(); |
| 1544 | auto outputShape = outputType.getShape(); |
| 1545 | |
| 1546 | // Only handle the case where at least one of the window dimensions is |
| 1547 | // of size 1. Other cases can rely on tiling to reduce to such cases. |
| 1548 | int64_t khSize = kernelShape[0], kwSize = kernelShape[1]; |
| 1549 | int64_t ohSize = outputShape[1], owSize = outputShape[2]; |
| 1550 | bool removeH = (khSize == 1 && ohSize == 1); |
| 1551 | bool removeW = (kwSize == 1 && owSize == 1); |
| 1552 | if (!removeH && !removeW) |
| 1553 | return failure(); |
| 1554 | |
| 1555 | // Get new shapes and types for all operands by removing the size-1 |
| 1556 | // dimension. |
| 1557 | using RTTBuilder = RankedTensorType::Builder; |
| 1558 | RankedTensorType newInputType = |
| 1559 | RTTBuilder(inputType).dropDim((removeH ? 1 : 2)); |
| 1560 | RankedTensorType newKernelType = |
| 1561 | RTTBuilder(kernelType).dropDim((removeH ? 0 : 1)); |
| 1562 | RankedTensorType newOutputType = |
| 1563 | RTTBuilder(outputType).dropDim(removeH ? 1 : 2); |
| 1564 | |
| 1565 | // Rank-reduce operands. |
| 1566 | Location loc = convOp.getLoc(); |
| 1567 | Value newInput = tensor::createCanonicalRankReducingExtractSliceOp( |
| 1568 | rewriter, loc, input, newInputType); |
| 1569 | Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp( |
| 1570 | rewriter, loc, kernel, newKernelType); |
| 1571 | Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp( |
| 1572 | rewriter, loc, output, newOutputType); |
| 1573 | |
| 1574 | // Rank-reduce strides and dilations too. |
| 1575 | // TODO: dropDim 1-liner helper. |
| 1576 | auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>()); |
| 1577 | strides.erase(strides.begin() + (removeH ? 0 : 1)); |
| 1578 | auto stridesAttr = rewriter.getI64VectorAttr(strides); |
| 1579 | |
| 1580 | auto dilations = |
| 1581 | llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>()); |
| 1582 | dilations.erase(dilations.begin() + (removeH ? 0 : 1)); |
| 1583 | auto dilationsAttr = rewriter.getI64VectorAttr(dilations); |
| 1584 | |
| 1585 | auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>( |
| 1586 | loc, newOutputType, ValueRange{newInput, newKernel}, |
| 1587 | ValueRange{newOutput}, stridesAttr, dilationsAttr); |
| 1588 | |
| 1589 | // Insert back. |
| 1590 | Value inserted = tensor::createCanonicalRankReducingInsertSliceOp( |
| 1591 | rewriter, loc, conv1DOp.getResult(0), output); |
| 1592 | rewriter.replaceOp(convOp, inserted); |
| 1593 | |
| 1594 | return conv1DOp; |
| 1595 | } |
| 1596 | |
| 1597 | FailureOr<Conv1DOp> |
| 1598 | DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp, |
| 1599 | PatternRewriter &rewriter) const { |
| 1600 | if (convOp.hasPureBufferSemantics()) |
| 1601 | return failure(); // To be implemented. |
| 1602 | |
| 1603 | Value input = convOp.getInputs().front(); |
| 1604 | Value kernel = convOp.getInputs().back(); |
| 1605 | Value output = convOp.getOutputs().front(); |
| 1606 | |
| 1607 | auto inputType = dyn_cast<RankedTensorType>(input.getType()); |
| 1608 | auto kernelType = dyn_cast<RankedTensorType>(kernel.getType()); |
| 1609 | auto outputType = dyn_cast<RankedTensorType>(output.getType()); |
| 1610 | |
| 1611 | auto kernelShape = kernelType.getShape(); |
| 1612 | auto outputShape = outputType.getShape(); |
| 1613 | |
| 1614 | // Only handle the case where at least one of the window dimensions is |
| 1615 | // of size 1. Other cases can rely on tiling to reduce to such cases. |
| 1616 | int64_t khSize = kernelShape[0], kwSize = kernelShape[1]; |
| 1617 | int64_t ohSize = outputShape[0], owSize = outputShape[1]; |
| 1618 | bool removeH = (khSize == 1 && ohSize == 1); |
| 1619 | bool removeW = (kwSize == 1 && owSize == 1); |
| 1620 | if (!removeH && !removeW) |
| 1621 | return failure(); |
| 1622 | |
| 1623 | // Get new shapes and types for all operands by removing the size-1 |
| 1624 | // dimension. |
| 1625 | using RTTBuilder = RankedTensorType::Builder; |
| 1626 | RankedTensorType newInputType = |
| 1627 | RTTBuilder(inputType).dropDim((removeH ? 0 : 1)); |
| 1628 | RankedTensorType newKernelType = |
| 1629 | RTTBuilder(kernelType).dropDim((removeH ? 0 : 1)); |
| 1630 | RankedTensorType newOutputType = |
| 1631 | RTTBuilder(outputType).dropDim(removeH ? 0 : 1); |
| 1632 | |
| 1633 | // Rank-reduce operands. |
| 1634 | Location loc = convOp.getLoc(); |
| 1635 | Value newInput = tensor::createCanonicalRankReducingExtractSliceOp( |
| 1636 | rewriter, loc, input, newInputType); |
| 1637 | Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp( |
| 1638 | rewriter, loc, kernel, newKernelType); |
| 1639 | Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp( |
| 1640 | rewriter, loc, output, newOutputType); |
| 1641 | |
| 1642 | auto conv1DOp = rewriter.create<Conv1DOp>(loc, newOutputType, |
| 1643 | ValueRange{newInput, newKernel}, |
| 1644 | ValueRange{newOutput}); |
| 1645 | |
| 1646 | // Insert back. |
| 1647 | Value inserted = tensor::createCanonicalRankReducingInsertSliceOp( |
| 1648 | rewriter, loc, conv1DOp.getResult(0), output); |
| 1649 | rewriter.replaceOp(convOp, inserted); |
| 1650 | |
| 1651 | return conv1DOp; |
| 1652 | } |
| 1653 | |
| 1654 | void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, |
| 1655 | PatternBenefit benefit) { |
| 1656 | patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp, |
| 1657 | Conv1DNwcWcfOp>, |
| 1658 | DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp, |
| 1659 | Conv1DNcwFcwOp>, |
| 1660 | DownscaleDepthwiseConv2DNhwcHwcOp, DownscaleConv2DOp>( |
| 1661 | patterns.getContext(), benefit); |
| 1662 | patterns.add< |
| 1663 | DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>, |
| 1664 | DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>, |
| 1665 | DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, PoolingNwcMaxOp>, |
| 1666 | DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp, |
| 1667 | PoolingNwcMaxUnsignedOp>, |
| 1668 | DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, PoolingNwcMinOp>, |
| 1669 | DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp, |
| 1670 | PoolingNwcMinUnsignedOp>, |
| 1671 | DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>( |
| 1672 | patterns.getContext(), benefit); |
| 1673 | } |
| 1674 | |
| 1675 | void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) { |
| 1676 | patterns.add<DecomposeOuterUnitDimsPackOpPattern>(patterns.getContext()); |
| 1677 | patterns.add<DecomposeOuterUnitDimsUnPackOpPattern>(patterns.getContext()); |
| 1678 | } |
| 1679 | |
| 1680 | void linalg::populateDecomposePadPatterns(RewritePatternSet &patterns) { |
| 1681 | patterns.add<DecomposePadOpPattern>(patterns.getContext()); |
| 1682 | } |
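| | |
| | // Minimal usage sketch for the populate* entry points above (illustrative; |
| | // "MyPass" is a placeholder, the pass boilerplate is assumed, and the greedy |
| | // driver entry point is applyPatternsGreedily in recent MLIR, |
| | // applyPatternsAndFoldGreedily in older releases): |
| | // |
| | //   void MyPass::runOnOperation() { |
| | //     RewritePatternSet patterns(&getContext()); |
| | //     linalg::populateDecomposeConvolutionPatterns(patterns); |
| | //     linalg::populateDecomposePackUnpackPatterns(patterns); |
| | //     linalg::populateDecomposePadPatterns(patterns); |
| | //     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) |
| | //       signalPassFailure(); |
| | //   } |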
| 1683 | |