//===- DataLayoutPropagation.cpp -----------------------------------------===//
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/Linalg/Passes.h" |
| 10 | |
| 11 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 12 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 13 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
| 14 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
| 15 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 16 | #include "mlir/Dialect/Tensor/Utils/Utils.h" |
| 17 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
| 18 | #include "mlir/IR/Dominance.h" |
| 19 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 20 | #include "llvm/ADT/SetOperations.h" |
| 21 | #include "llvm/ADT/SetVector.h" |
| 22 | #include "llvm/ADT/TypeSwitch.h" |
| 23 | #include "llvm/Support/Debug.h" |
| 24 | #include <optional> |
| 25 | |
| 26 | namespace mlir { |
| 27 | #define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION |
| 28 | #include "mlir/Dialect/Linalg/Passes.h.inc" |
| 29 | } // namespace mlir |
| 30 | |
| 31 | using namespace mlir; |
| 32 | using namespace mlir::linalg; |
| 33 | |
| 34 | #define DEBUG_TYPE "linalg-data-layout-propagation" |
| 35 | |
| 36 | namespace { |
| 37 | |
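/// Returns true if the generic op's body contains a tensor.extract or
/// linalg.index op, i.e. its results depend on concrete index values
/// (gather-like semantics).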
| 38 | static bool hasGatherSemantics(linalg::GenericOp genericOp) { |
| 39 | for (Operation &op : genericOp.getBody()->getOperations()) |
| 40 | if (isa<tensor::ExtractOp, linalg::IndexOp>(op)) |
| 41 | return true; |
| 42 | return false; |
| 43 | } |
| 44 | |
// This struct contains the information needed to map the packing of an
// operand onto the iteration domain of Linalg ops.
| 47 | struct PackInfo { |
| 48 | int64_t getNumTiledLoops() const { return tileToPointMapping.size(); }; |
| 49 | // InnerDimsPos on iteration domain, which follows the order in pack ops. |
| 50 | SmallVector<int64_t> tiledDimsPos; |
| 51 | // The sizes of tiling data dimensions on iteration domain. |
| 52 | llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping; |
| 53 | // The mapping from a dimension of iteration domain to the corresponding inner |
| 54 | // tiling dimension on iteration domain. |
| 55 | llvm::DenseMap<int64_t, int64_t> tileToPointMapping; |
| 56 | // The permutation of outer dims (on domain). |
| 57 | SmallVector<int64_t> outerDimsOnDomainPerm; |
| 58 | }; |
| 59 | |
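/// Builds a PackInfo describing, on the iteration domain of `genericOp`, the
/// packing carried by `packOrUnPackOp` attached to `opOperand`. Returns
/// failure when a tiled dimension does not map to a plain parallel
/// AffineDimExpr, or when the outer dims permutation would need to move a
/// non-AffineDimExpr result.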
| 60 | template <typename OpTy> |
| 61 | static FailureOr<PackInfo> |
| 62 | getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp, |
| 63 | OpTy packOrUnPackOp) { |
| 64 | static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value, |
| 65 | "applies to only pack or unpack operations" ); |
| 66 | LLVM_DEBUG( |
| 67 | { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n" ; }); |
| 68 | |
| 69 | AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); |
| 70 | SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray(); |
| 71 | SmallVector<utils::IteratorType> iterators = |
| 72 | genericOp.getIteratorTypesArray(); |
| 73 | |
| 74 | PackInfo packInfo; |
| 75 | int64_t origNumDims = indexingMap.getNumDims(); |
| 76 | SmallVector<AffineExpr> exprs(indexingMap.getResults()); |
| 77 | ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos(); |
| 78 | for (auto [index, innerDimPos, tileSize] : |
| 79 | llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()), |
| 80 | innerDimsPos, packOrUnPackOp.getMixedTiles())) { |
| 81 | auto expr = exprs[innerDimPos]; |
| 82 | if (!isa<AffineDimExpr>(expr)) |
| 83 | return failure(); |
| 84 | int64_t domainDimPos = |
| 85 | cast<AffineDimExpr>(exprs[innerDimPos]).getPosition(); |
| 86 | if (!isParallelIterator(iterators[domainDimPos])) |
| 87 | return failure(); |
| 88 | packInfo.tiledDimsPos.push_back(domainDimPos); |
| 89 | packInfo.domainDimAndTileMapping[domainDimPos] = tileSize; |
| 90 | packInfo.tileToPointMapping[domainDimPos] = origNumDims + index; |
| 91 | LLVM_DEBUG({ |
| 92 | llvm::dbgs() << "map innerDimPos=" << innerDimPos |
| 93 | << " to iteration dimension (d" << domainDimPos << ", d" |
| 94 | << packInfo.tileToPointMapping[domainDimPos] |
| 95 | << "), which has size=(" |
                   << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
| 97 | }); |
| 98 | } |
| 99 | |
| 100 | // Bail out if a tiled dimension is present in a map but not as an affine dim |
| 101 | // expression. |
| 102 | auto areAllAffineDimExpr = [&](int dim) { |
| 103 | for (AffineMap map : indexingMaps) { |
| 104 | if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) { |
| 105 | return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr); |
| 106 | })) { |
| 107 | return false; |
| 108 | } |
| 109 | } |
| 110 | return true; |
| 111 | }; |
| 112 | for (int64_t i : packInfo.tiledDimsPos) |
| 113 | if (!areAllAffineDimExpr(i)) |
| 114 | return failure(); |
| 115 | |
| 116 | // Get the outer dims perm on the iteration domain. Start by identifying the |
| 117 | // set of domain dims affected by the outer permutation along with the |
| 118 | // permuted ordering for those dims. Then the full outer dims permutation can |
| 119 | // be constructed by replacing the affected dims with the permuted result in a |
| 120 | // numLoops-rank identity. e.g. |
| 121 | // outerDimsPerm = [1, 2, 0] |
| 122 | // indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3) |
| 123 | // |
| 124 | // permutedOuterDims = [4, 3, 1] |
| 125 | // outerDimsOnDomainPerm = [0, 4, 2, 3, 1] |
| 126 | // |
| 127 | // Non-affine dim expressions must not be permuted by the outer dims |
| 128 | // permutation. |
| 129 | SmallVector<int64_t> permutedOuterDims; |
| 130 | for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) { |
    auto permutedExpr = indexingMap.getResult(dim);
| 132 | if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) { |
| 133 | permutedOuterDims.push_back(dimExpr.getPosition()); |
| 134 | continue; |
| 135 | } |
| 136 | |
| 137 | // TODO: Allow propagation with transposes on non affine dim expressions, |
| 138 | // e.g. d0 + d1 which implies transposing both dims simultaneously while |
| 139 | // maintaining the relative position between them. |
| 140 | if (static_cast<int64_t>(index) != dim) |
| 141 | return failure(); |
| 142 | } |
| 143 | if (!permutedOuterDims.empty()) { |
| 144 | int64_t outerDimIndex = 0; |
| 145 | llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(), |
| 146 | permutedOuterDims.end()); |
| 147 | for (int i = 0, e = indexingMap.getNumDims(); i < e; i++) |
| 148 | packInfo.outerDimsOnDomainPerm.push_back( |
| 149 | permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++] |
| 150 | : i); |
| 151 | LLVM_DEBUG({ |
| 152 | llvm::dbgs() << "map outer dimsDimsPerm to " ; |
| 153 | for (auto dim : packInfo.outerDimsOnDomainPerm) |
| 154 | llvm::dbgs() << dim << " " ; |
| 155 | llvm::dbgs() << "\n" ; |
| 156 | }); |
| 157 | } |
| 158 | |
| 159 | return packInfo; |
| 160 | } |
| 161 | |
| 162 | static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm, |
| 163 | ArrayRef<AffineExpr> exprs) { |
| 164 | // Compute `outer_dims_perm`. See example: |
| 165 | // current exprs : (d0, d1, d2, d3) -> (d2, d3) |
| 166 | // perm : [0, 3, 1, 2] |
| 167 | // First map d2, d3 with their position in the array as: |
| 168 | // currentPositionTileLoops: dim | pos |
| 169 | // d2 | 0 |
| 170 | // d3 | 1 |
| 171 | // then scan `perm` in order and get the `outer_dims_perm` |
| 172 | // to be used, here it would be [1, 0]. |
  assert(!perm.empty() && "expect perm not to be empty");
  assert(!exprs.empty() && "expect exprs not to be empty");
| 175 | if (exprs.size() == 1) |
| 176 | return {}; |
| 177 | SmallVector<int64_t> outerDimsPerm; |
| 178 | DenseMap<int64_t, int64_t> currentPositionTileLoops; |
| 179 | for (auto [pos, expr] : llvm::enumerate(exprs)) { |
| 180 | // Here we rely on the assumption that the outer dims permutation |
| 181 | // when propagating currently requires that non-affine dim expressions |
| 182 | // are not permuted, thus allowing the identity assignment below. |
| 183 | if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) |
| 184 | currentPositionTileLoops[dimExpr.getPosition()] = pos; |
| 185 | else |
| 186 | currentPositionTileLoops[pos] = pos; |
| 187 | } |
| 188 | for (int64_t loopIdx : perm) { |
| 189 | if (currentPositionTileLoops.count(loopIdx)) |
| 190 | outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx)); |
| 191 | } |
| 192 | return outerDimsPerm; |
| 193 | } |
| 194 | |
/// Returns a tuple of the packed operand and the updated indexing map, under
/// the assumptions:
///   1) The generic op is the producer of the pack op.
///   2) The generic op has only one result.
/// If the operand is a scalar or the packing dimensions are all irrelevant to
/// the operand, the original operand and the updated indexing map are
/// returned. Otherwise, the packed operand and the updated indexing map are
/// returned. E.g.,
| 201 | /// |
| 202 | /// #map0 = affine_map<(d0, d1) -> (d0, d1)> |
| 203 | /// #map1 = affine_map<(d0, d1) -> (d0)> |
| 204 | /// #map2 = affine_map<(d0, d1) -> (d1)> |
| 205 | /// %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0], |
| 206 | /// iterator_types = ["parallel", "parallel"]} |
| 207 | /// ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>) |
| 208 | /// outs(%init : tensor<?x?xf32>) { |
| 209 | /// ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): |
| 210 | /// %4 = arith.addf %arg3, %arg4 : f32 |
| 211 | /// linalg.yield %4 : f32 |
| 212 | /// } -> tensor<?x?xf32> |
| 213 | /// %1 = linalg.pack %0 |
| 214 | /// inner_dims_pos = [0, 1] |
| 215 | /// inner_tiles = [8, 2] |
| 216 | /// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32> |
| 217 | /// |
/// Taking the first input operand as an example, the inner tile size of d0 is
/// 8. Thus, the below operation and `affine_map<(d0, d1, d2, d3) -> (d0, d2)>`
/// will be returned.
| 221 | /// |
| 222 | /// %pack = linalg.pack %arg0 |
| 223 | /// inner_dims_pos = [0] |
| 224 | /// inner_tiles = [8] |
| 225 | /// into %init : tensor<?xf32> -> tensor<?x8xf32> |
| 226 | static std::tuple<Value, AffineMap> |
| 227 | getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo, |
| 228 | GenericOp genericOp, OpOperand *opOperand) { |
| 229 | int64_t numOrigLoops = genericOp.getNumLoops(); |
| 230 | int64_t numInnerLoops = packInfo.getNumTiledLoops(); |
| 231 | int64_t numLoops = numOrigLoops + numInnerLoops; |
| 232 | AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand); |
| 233 | llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim; |
| 234 | SmallVector<AffineExpr> exprs(origIndexingMap.getResults()); |
| 235 | |
| 236 | // If the OpOperand is a scalar or a zero-rank tensor, no need to pack. |
| 237 | if (genericOp.isScalar(opOperand) || exprs.empty()) |
| 238 | return std::make_tuple(opOperand->get(), |
| 239 | AffineMap::get(numLoops, 0, exprs, b.getContext())); |
| 240 | |
| 241 | // Step 1. Construct the information of packing data dimensions; append inner |
| 242 | // dimensions to the indexing maps for the operand. |
| 243 | for (auto [index, expr] : llvm::enumerate(exprs)) { |
| 244 | if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) { |
| 245 | int64_t dimPos = dimExpr.getPosition(); |
| 246 | domainDimToOperandDim[dimPos] = index; |
| 247 | continue; |
| 248 | } |
| 249 | } |
| 250 | SmallVector<int64_t> innerDimsPos; |
| 251 | SmallVector<OpFoldResult> innerTileSizes; |
| 252 | for (auto dimPos : packInfo.tiledDimsPos) { |
| 253 | if (!domainDimToOperandDim.count(dimPos)) |
| 254 | continue; |
| 255 | int64_t index = domainDimToOperandDim[dimPos]; |
| 256 | innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]); |
| 257 | innerDimsPos.push_back(index); |
| 258 | exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos])); |
| 259 | } |
| 260 | |
| 261 | // Step 2. Handle outer dim permutations. |
| 262 | SmallVector<int64_t> outerDimsPerm; |
| 263 | if (!packInfo.outerDimsOnDomainPerm.empty()) { |
| 264 | outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs); |
| 265 | |
| 266 | // Step 2.1: Fold transpose into the linalg.generic. |
| 267 | SmallVector<int64_t> inversedOuterPerm = |
| 268 | invertPermutationVector(packInfo.outerDimsOnDomainPerm); |
| 269 | for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) { |
| 270 | if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) { |
| 271 | int64_t dimPos = dimExpr.getPosition(); |
| 272 | exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]); |
| 273 | continue; |
| 274 | } |
| 275 | assert(isa<AffineConstantExpr>(exprs[i]) && |
| 276 | "Attempted to permute non-constant and non-affine dim expression" ); |
| 277 | } |
| 278 | // Step 2.2: Undo the transposition on `exprs` and propagate the |
| 279 | // transposition on the pack using outerDimsPerm. |
| 280 | if (!outerDimsPerm.empty()) { |
| 281 | SmallVector<AffineExpr> auxVec = exprs; |
| 282 | for (const auto &en : enumerate(outerDimsPerm)) |
| 283 | auxVec[en.index()] = exprs[en.value()]; |
| 284 | exprs = auxVec; |
| 285 | } |
| 286 | } |
| 287 | auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext()); |
| 288 | |
  // The operand does not have dimensions that relate to the pack op.
| 290 | if (innerDimsPos.empty() && outerDimsPerm.empty()) |
| 291 | return std::make_tuple(opOperand->get(), indexingMap); |
| 292 | |
| 293 | auto empty = linalg::PackOp::createDestinationTensor( |
| 294 | b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm); |
| 295 | auto packedOperand = b.create<linalg::PackOp>( |
| 296 | loc, opOperand->get(), empty, innerDimsPos, innerTileSizes, |
| 297 | /*padding=*/std::nullopt, outerDimsPerm); |
| 298 | return std::make_tuple(packedOperand, indexingMap); |
| 299 | } |
| 300 | |
/// Helper subroutine that packs a genericOp and returns the new op. It creates
/// a new generic op with the packed operands and the packed output according
/// to packInfo when we attempt to push down an unpack or bubble up a pack
/// around it. Implicitly, this only works when a packInfo can be obtained,
/// which makes sure that we are only using this function on parallel permuted
/// dimensions.
| 307 | static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp, |
| 308 | Value dest, AffineMap packedOutIndexingMap, |
| 309 | const PackInfo &packInfo, |
| 310 | bool isFoldableUnpackPack) { |
| 311 | Location loc = genericOp.getLoc(); |
| 312 | SmallVector<Value> inputOperands; |
| 313 | SmallVector<Value> inputOperandsFromUnpackedSource; |
| 314 | SmallVector<AffineMap> indexingMaps; |
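  // A pack/unpack pair cancels out only when their outer dims permutations,
  // inner dims positions, and tile sizes all match.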
| 315 | auto hasEquivalentTiles = [](PackOp packOp, UnPackOp unPackOp) { |
| 316 | return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm() && |
| 317 | packOp.getInnerDimsPos() == unPackOp.getInnerDimsPos() && |
| 318 | llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles()); |
| 319 | }; |
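  // Pack each input operand. When an operand is produced by an unpack that
  // matches the pack being created, also record the unpacked source so the
  // pair can be folded away later if requested.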
| 320 | for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) { |
| 321 | auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand( |
| 322 | rewriter, loc, packInfo, genericOp, inputOperand); |
| 323 | auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>(); |
| 324 | auto packOp = packedOperand.getDefiningOp<linalg::PackOp>(); |
| 325 | if (packOp && unpackOp && hasEquivalentTiles(packOp, unpackOp)) { |
| 326 | inputOperandsFromUnpackedSource.push_back(unpackOp.getSource()); |
| 327 | } else { |
| 328 | inputOperandsFromUnpackedSource.push_back(packedOperand); |
| 329 | } |
| 330 | inputOperands.push_back(packedOperand); |
| 331 | indexingMaps.push_back(packedIndexingMap); |
| 332 | } |
| 333 | |
  // If the unpack->pack sequences can be folded, use the sources of the unpack
  // ops in any unpack->pack chains on the generic op operands instead.
| 336 | if (isFoldableUnpackPack) { |
| 337 | inputOperands = inputOperandsFromUnpackedSource; |
| 338 | if (auto destPack = dest.getDefiningOp<linalg::PackOp>()) { |
| 339 | auto destUnPack = destPack.getSource().getDefiningOp<linalg::UnPackOp>(); |
| 340 | if (destUnPack && hasEquivalentTiles(destPack, destUnPack)) { |
| 341 | dest = destUnPack.getSource(); |
| 342 | } |
| 343 | } |
| 344 | } |
| 345 | |
| 346 | int64_t numInnerLoops = packInfo.getNumTiledLoops(); |
| 347 | SmallVector<utils::IteratorType> iterTypes = |
| 348 | genericOp.getIteratorTypesArray(); |
| 349 | iterTypes.append(numInnerLoops, utils::IteratorType::parallel); |
| 350 | |
| 351 | indexingMaps.push_back(packedOutIndexingMap); |
| 352 | |
| 353 | auto newGenericOp = rewriter.create<linalg::GenericOp>( |
| 354 | loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes, |
| 355 | /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp)); |
| 356 | rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(), |
| 357 | newGenericOp.getRegion().begin()); |
| 358 | return newGenericOp; |
| 359 | } |
| 360 | |
/// Bubbles up a linalg.pack op through a producer generic op. This swaps
/// pack(generic) into generic(pack). The new generic op works on the packed
| 363 | /// domain; pack ops are created for input and output operands. E.g., |
| 364 | /// |
| 365 | /// #map0 = affine_map<(d0, d1) -> (d0, d1)> |
| 366 | /// %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32> |
| 367 | /// %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32> |
| 368 | /// %2 = tensor.empty(%0, %1) : tensor<?x?xf32> |
| 369 | /// %3 = linalg.generic {indexing_maps = [#map0, #map0], |
| 370 | /// iterator_types = ["parallel", "parallel"]} |
| 371 | /// ins(%arg0 : tensor<?x?xf32>) |
| 372 | /// outs(%2 : tensor<?x?xf32>) { |
| 373 | /// ^bb0(%arg3: f32, %arg4: f32): |
| 374 | /// %4 = arith.addf %arg3, %arg3 : f32 |
| 375 | /// linalg.yield %4 : f32 |
| 376 | /// } -> tensor<?x?xf32> |
| 377 | /// %4 = linalg.pack %3 |
| 378 | /// inner_dims_pos = [0, 1] |
| 379 | /// inner_tiles = [8, 2] |
| 380 | /// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32> |
| 381 | /// |
| 382 | /// will be converted to |
| 383 | /// |
| 384 | /// #map = affine_map<()[s0] -> (s0 ceildiv 8)> |
| 385 | /// #map1 = affine_map<()[s0] -> (s0 ceildiv 2)> |
| 386 | /// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> |
| 387 | /// %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32> |
| 388 | /// %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32> |
| 389 | /// %0 = affine.apply #map()[%dim] |
| 390 | /// %1 = affine.apply #map1()[%dim_0] |
| 391 | /// %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32> |
| 392 | /// %pack = linalg.pack %arg0 |
| 393 | /// inner_dims_pos = [0, 1] |
| 394 | /// inner_tiles = [8, 2] |
| 395 | /// into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32> |
| 396 | /// %3 = linalg.generic {indexing_maps = [#map2, #map2], |
| 397 | /// iterator_types = ["parallel", "parallel", "parallel", "parallel"]} |
| 398 | /// ins(%pack : tensor<?x?x8x2xf32>) |
| 399 | /// outs(%arg1 : tensor<?x?x8x2xf32>) { |
| 400 | /// ^bb0(%in: f32, %out: f32): |
| 401 | /// %4 = arith.addf %in, %in : f32 |
| 402 | /// linalg.yield %4 : f32 |
| 403 | /// } -> tensor<?x?x8x2xf32> |
| 404 | static FailureOr<GenericOp> |
| 405 | bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp, |
| 406 | const ControlPropagationFn &controlFn) { |
| 407 | auto genericOp = packOp.getSource().getDefiningOp<GenericOp>(); |
| 408 | if (!genericOp) |
| 409 | return failure(); |
| 410 | |
| 411 | // User controlled propagation function. |
| 412 | if (!controlFn(&packOp.getSourceMutable())) |
| 413 | return failure(); |
| 414 | |
| 415 | // TODO: Enable propagation in the presence of linalg.index and |
| 416 | // tensor.extract, likely as a separate pattern as the pack information and |
| 417 | // propagation decision needs to be inferred from the region of the generic. |
| 418 | if (hasGatherSemantics(genericOp)) |
| 419 | return failure(); |
| 420 | |
| 421 | // TODO: Relax the restriction. We are able to bubble up the pack op through |
| 422 | // multi-result generic op. It just needs more work. |
| 423 | if (genericOp.getNumResults() != 1) |
| 424 | return failure(); |
| 425 | |
| 426 | // Bail-out if the result of the generic has multiple uses, as bubbling up |
| 427 | // creates recomputation if the generic has multiple users. |
| 428 | // TODO: Enable the case where every use is an identical pack op as no |
| 429 | // recomputation is needed in that case. |
| 430 | if (!genericOp->getResult(0).hasOneUse()) |
| 431 | return failure(); |
| 432 | |
| 433 | // TODO: Add an option for allowing padding values. It could introduce |
| 434 | // undefined behavior if we unconditionally propagate pack op through all |
| 435 | // the ops. E.g., if the padding value is zero and there are division ops in |
| 436 | // a generic op. Some values of padding area could be NaN (0/0). |
| 437 | if (packOp.getPaddingValue()) |
| 438 | return failure(); |
| 439 | |
| 440 | OpOperand *opOperand = genericOp.getDpsInitOperand(0); |
| 441 | auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp); |
| 442 | if (failed(packInfo)) |
| 443 | return failure(); |
| 444 | |
  // We want to move the pack, not the generic.
| 446 | OpBuilder::InsertionGuard guard(rewriter); |
| 447 | rewriter.setInsertionPoint(genericOp); |
| 448 | |
| 449 | // We need to handle two cases: |
| 450 | // 1) The linalg.pack destination is a tensor.empty. If this is the case, we |
| 451 | // create a new tensor.empty to avoid breaking dominance, as we are moving the |
| 452 | // linalg.pack above the linalg.generic. |
| 453 | // 2) The destination is not a tensor.empty. In this case we can replace only |
| 454 | // if the destination of the linalg.pack dominates the linalg.generic. |
| 455 | Value packOpDest = packOp.getDest(); |
| 456 | if (!packOpDest.hasOneUse()) |
| 457 | return failure(); |
| 458 | if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) { |
| 459 | packOpDest = rewriter.create<tensor::EmptyOp>( |
| 460 | genericOp->getLoc(), emptyOp.getMixedSizes(), |
| 461 | emptyOp.getType().getElementType()); |
| 462 | } else { |
| 463 | DominanceInfo dom(genericOp); |
| 464 | if (!dom.properlyDominates(packOpDest, genericOp)) |
| 465 | return failure(); |
| 466 | } |
| 467 | |
| 468 | // Rebuild the indexing map for the corresponding init operand. |
| 469 | auto [packedOutOperand, packedOutIndexingMap] = |
| 470 | getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo, |
| 471 | genericOp, opOperand); |
| 472 | |
| 473 | // If the dps init operand of the generic is a tensor.empty forward the pack |
| 474 | // op destination. |
| 475 | Value dest = packedOutOperand; |
| 476 | if (auto initTensor = genericOp.getDpsInitOperand(0) |
| 477 | ->get() |
| 478 | .getDefiningOp<tensor::EmptyOp>()) { |
| 479 | dest = packOpDest; |
| 480 | } |
  // pack(unpack) isn't naively foldable because the unpack op can come from an
  // arbitrary domain, so we need to keep both.
| 483 | return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, |
| 484 | *packInfo, /*isFoldableUnpackPack=*/false); |
| 485 | } |
| 486 | |
/// Wrapper pattern that applies the bubbleUpPackOpThroughGenericOp method.
| 488 | struct BubbleUpPackOpThroughGenericOpPattern |
| 489 | : public OpRewritePattern<linalg::PackOp> { |
| 490 | public: |
| 491 | BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context, |
| 492 | ControlPropagationFn fun) |
| 493 | : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {} |
| 494 | |
| 495 | LogicalResult matchAndRewrite(linalg::PackOp packOp, |
| 496 | PatternRewriter &rewriter) const override { |
| 497 | auto genericOp = |
| 498 | bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn); |
| 499 | if (failed(genericOp)) |
| 500 | return failure(); |
| 501 | rewriter.replaceOp(packOp, genericOp->getResults()); |
| 502 | return success(); |
| 503 | } |
| 504 | |
| 505 | private: |
| 506 | ControlPropagationFn controlFn; |
| 507 | }; |
| 508 | |
/// Propagate a linalg.pack operation up through a tensor.pad. The idea is to
/// append as many zero-padding entries to `high` and `low` as there are point
/// loops introduced by the pack.
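/// A rough sketch of the rewrite (shapes are illustrative, the pad body is
/// elided):
///
///   %padded = tensor.pad %src low[0, 0] high[0, 2] { ... }
///       : tensor<16x6xf32> to tensor<16x8xf32>
///   %pack = linalg.pack %padded
///       inner_dims_pos = [0] inner_tiles = [8]
///       into %dest : tensor<16x8xf32> -> tensor<2x8x8xf32>
///
/// becomes
///
///   %pack = linalg.pack %src
///       inner_dims_pos = [0] inner_tiles = [8]
///       into %empty : tensor<16x6xf32> -> tensor<2x6x8xf32>
///   %padded = tensor.pad %pack low[0, 0, 0] high[0, 2, 0] { ... }
///       : tensor<2x6x8xf32> to tensor<2x8x8xf32>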
| 512 | class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> { |
| 513 | public: |
| 514 | BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun) |
| 515 | : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {} |
| 516 | |
| 517 | LogicalResult matchAndRewrite(linalg::PackOp packOp, |
| 518 | PatternRewriter &rewriter) const override { |
| 519 | auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>(); |
| 520 | if (!padOp) |
| 521 | return failure(); |
| 522 | |
| 523 | // User controlled propagation function. |
| 524 | if (!controlFn(&packOp.getSourceMutable())) |
| 525 | return failure(); |
| 526 | |
| 527 | // TODO: Enable padding when the padding values are the same. |
| 528 | if (packOp.getPaddingValue()) |
| 529 | return failure(); |
| 530 | |
| 531 | // Fail for non-constant padding values. The body of the pad could |
| 532 | // depend on the padding indices and/or properties of the padded |
| 533 | // tensor so for now we fail. |
| 534 | // TODO: Support non-constant padding values. |
| 535 | Value paddingVal = padOp.getConstantPaddingValue(); |
| 536 | if (!paddingVal) |
| 537 | return failure(); |
| 538 | |
| 539 | if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>()) |
| 540 | return failure(); |
| 541 | |
| 542 | ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos(); |
| 543 | |
    // Bail out if one of the padded dimensions is a tiled one.
| 545 | llvm::SmallBitVector paddedDims = padOp.getPaddedDims(); |
| 546 | llvm::SmallBitVector innerDims(paddedDims.size()); |
| 547 | for (int64_t dim : innerDimsPos) |
| 548 | innerDims.flip(dim); |
    if (paddedDims.anyCommon(innerDims))
| 550 | return failure(); |
| 551 | |
| 552 | Location loc = padOp->getLoc(); |
| 553 | OpBuilder::InsertionGuard guard(rewriter); |
| 554 | rewriter.setInsertionPoint(padOp); |
| 555 | |
| 556 | ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm(); |
| 557 | SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles(); |
| 558 | auto empty = linalg::PackOp::createDestinationTensor( |
| 559 | rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos, |
| 560 | outerDimsPerm); |
| 561 | auto sourcePack = rewriter.create<linalg::PackOp>( |
| 562 | loc, padOp.getSource(), empty, innerDimsPos, mixedTiles, |
| 563 | /*padding=*/std::nullopt, outerDimsPerm); |
| 564 | |
| 565 | // If we have `outer_dims_perms` we need to adjust the padded dimensions. |
| 566 | SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad(); |
| 567 | SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad(); |
| 568 | if (!outerDimsPerm.empty()) { |
| 569 | applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm); |
| 570 | applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm); |
| 571 | } |
| 572 | // The tiled dimensions were verified to be unpadded above, so here we |
| 573 | // just append 0 for the inner tile dimensions. |
| 574 | size_t pointLoopsSize = innerDimsPos.size(); |
| 575 | lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); |
| 576 | highPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); |
| 577 | |
| 578 | auto newPadOp = rewriter.create<tensor::PadOp>( |
| 579 | loc, /*result=*/Type(), sourcePack, lowPad, highPad, paddingVal, |
| 580 | padOp.getNofold()); |
| 581 | |
| 582 | // If the pad has more than one user, create an unpack on the new pad to |
| 583 | // replace the other uses. |
| 584 | if (!padOp->hasOneUse()) { |
| 585 | auto unpackEmpty = linalg::UnPackOp::createDestinationTensor( |
| 586 | rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm); |
| 587 | Value unpackedPad = rewriter.create<linalg::UnPackOp>( |
| 588 | loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles, outerDimsPerm); |
| 589 | rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack); |
| 590 | } |
| 591 | |
| 592 | // Replace the pack with the new pad. |
| 593 | rewriter.replaceOp(packOp, newPadOp.getResult()); |
| 594 | |
| 595 | return success(); |
| 596 | } |
| 597 | |
| 598 | private: |
| 599 | ControlPropagationFn controlFn; |
| 600 | }; |
| 601 | |
| 602 | /// Project dimsPos to the inner-most non-unit dim pos with reassocIndices. |
| 603 | /// |
| 604 | /// For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and |
/// targetShape [16, 16, 32, 1], it returns [1, 2]: for pos 0, the inner-most
/// non-unit projected dim in [0, 1] is 1; for pos 2, the inner-most non-unit
/// projected dim in [2, 3] is 2.
| 608 | /// |
| 609 | /// If all candidates in a reassociation are unit dims, it chooses the |
| 610 | /// inner-most dim pos. |
| 611 | static SmallVector<int64_t> |
| 612 | projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos, |
| 613 | ArrayRef<ReassociationIndices> reassocIndices, |
| 614 | ArrayRef<int64_t> targetShape) { |
| 615 | SmallVector<int64_t> projectedDimsPos; |
| 616 | for (auto pos : dimsPos) { |
| 617 | // In the case all dims are unit, this will return the inner-most one. |
| 618 | int64_t projectedPos = reassocIndices[pos].back(); |
| 619 | for (auto i : llvm::reverse(reassocIndices[pos])) { |
| 620 | int64_t dim = targetShape[i]; |
| 621 | if (dim > 1 || ShapedType::isDynamic(dim)) { |
| 622 | projectedPos = i; |
| 623 | break; |
| 624 | } |
| 625 | } |
| 626 | projectedDimsPos.push_back(projectedPos); |
| 627 | } |
| 628 | return projectedDimsPos; |
| 629 | } |
| 630 | |
| 631 | /// Check if all dims in dimsPos are divisible by the corresponding tile sizes. |
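/// For example, dimsPos [0, 2] with shape [8, 5, 16] and tileSizes [4, 8] is
/// divisible, while tileSizes [3, 8] is not (8 % 3 != 0). Dynamic dims are
/// conservatively treated as not divisible.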
| 632 | static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos, |
| 633 | ArrayRef<int64_t> shape, |
| 634 | ArrayRef<int64_t> tileSizes) { |
| 635 | for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) { |
| 636 | int64_t dim = shape[pos]; |
| 637 | if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0) |
| 638 | return false; |
| 639 | } |
| 640 | return true; |
| 641 | } |
| 642 | |
/// Permute the reassociation indices and reindex them in sequence order.
| 644 | /// Returns the next dim pos in the sequence. |
| 645 | /// |
| 646 | /// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it |
| 647 | /// applies the permutation to get [[2], [0, 1]] and reindexes the indices into |
| 648 | /// [[0], [1, 2]]. |
| 649 | static int64_t applyPermutationAndReindexReassoc( |
| 650 | SmallVector<ReassociationIndices> &reassocIndices, |
| 651 | ArrayRef<int64_t> permutation) { |
| 652 | if (!permutation.empty()) |
| 653 | applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation); |
| 654 | int64_t nextPos = 0; |
| 655 | for (ReassociationIndices &indices : reassocIndices) { |
| 656 | for (auto &index : indices) { |
| 657 | index = nextPos; |
| 658 | nextPos += 1; |
| 659 | } |
| 660 | } |
| 661 | return nextPos; |
| 662 | } |
| 663 | |
| 664 | /// Bubble up pack op through collapse shape op when the packed dims can be |
| 665 | /// projected to the dims before collapsing. This is possible when the inner |
| 666 | /// tile sizes can divide the projected dims. |
| 667 | /// |
| 668 | /// For example: |
| 669 | /// |
| 670 | /// %collapsed = tensor.collapse_shape %in [[0, 1], 2] |
| 671 | /// : tensor<?x16x4xf32> into tensor<?x4xf32> |
| 672 | /// %pack = linalg.pack %collapsed outer_dims_perm = [0, 1] |
| 673 | /// inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty |
| 674 | /// : tensor<?x4xf32> -> tensor<?x4x8x1xf32> |
| 675 | /// |
| 676 | /// can be transformed into: |
| 677 | /// |
| 678 | /// %pack = linalg.pack %in outer_dims_perm = [1, 2] |
| 679 | /// inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty |
| 680 | /// : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32> |
| 681 | /// %collapsed = tensor.collapse_shape %pack [[0, 1], 2, 3, 4] |
///       : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
| 683 | static LogicalResult |
| 684 | bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp, |
| 685 | linalg::PackOp packOp, |
| 686 | PatternRewriter &rewriter) { |
| 687 | SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles(); |
| 688 | ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos(); |
| 689 | ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm(); |
| 690 | |
| 691 | ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape(); |
| 692 | SmallVector<ReassociationIndices> reassocIndices = |
| 693 | collapseOp.getReassociationIndices(); |
| 694 | // Project inner tile pos to the dim pos before collapsing. For example, if |
| 695 | // dims [x, y] is collapsed into [z], packing on dim z can be projected back |
| 696 | // to pack on dim y. |
| 697 | // |
| 698 | // Project to inner-most non-unit dims to increase the chance that they can be |
| 699 | // divided by the inner tile sizes. This is correct because for [..., x, 1], |
| 700 | // packing on dim 1 is equivalent to packing on dim x. |
| 701 | SmallVector<int64_t> projectedInnerDimsPos = |
| 702 | projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape); |
| 703 | |
| 704 | if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape, |
| 705 | innerTileSizes)) { |
| 706 | return failure(); |
| 707 | } |
| 708 | // Expand the outer dims permutation with the associated source dims for the |
| 709 | // new permutation after bubbling. This is because moving a collapsed dim is |
| 710 | // equivalent to moving the associated source dims together. |
| 711 | SmallVector<int64_t> newOuterDimsPerm; |
| 712 | for (auto outerPos : outerDimsPerm) |
| 713 | llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]); |
| 714 | |
| 715 | auto emptyOp = linalg::PackOp::createDestinationTensor( |
| 716 | rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(), |
| 717 | projectedInnerDimsPos, newOuterDimsPerm); |
| 718 | auto newPackOp = rewriter.create<linalg::PackOp>( |
| 719 | packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos, |
| 720 | packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm); |
| 721 | |
| 722 | SmallVector<ReassociationIndices> newReassocIndices = reassocIndices; |
| 723 | // First apply the permutation on the reassociations of the outer dims. |
| 724 | // For example given the permutation [1, 0], the reassociations [[0, 1], [2]] |
| 725 | // -> [[0], [1, 2]] |
| 726 | int64_t nextPos = |
| 727 | applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm); |
| 728 | // Then add direct mapping for the inner tile dims. |
| 729 | for (size_t i = 0; i < innerDimsPos.size(); ++i) { |
| 730 | newReassocIndices.push_back({nextPos}); |
| 731 | nextPos += 1; |
| 732 | } |
| 733 | |
| 734 | auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>( |
| 735 | collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices); |
| 736 | rewriter.replaceOp(packOp, newCollapseOp); |
| 737 | |
| 738 | return success(); |
| 739 | } |
| 740 | |
| 741 | /// Project dimsPos to their collapsed positions in the reassocIndices. |
| 742 | /// |
| 743 | /// For example, given dimsPos [0, 1, 2, 4], and matching reassocIndices |
| 744 | /// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]. Because for pos 0, |
| 745 | /// the reassoc dim [0] is 0. For pos 1 and 2, the reassoc dim in pos |
| 746 | /// [1, 2] is 1. And for pos 4, the reassoc dim [4] is 3. |
| 747 | static SmallVector<int64_t> |
| 748 | projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos, |
| 749 | ArrayRef<ReassociationIndices> reassocIndices) { |
| 750 | SmallVector<int64_t> projectedPos; |
| 751 | |
| 752 | // Map each dimension to the position of corresponding reassociation index. |
| 753 | for (auto pos : dimsPos) { |
| 754 | for (auto [idx, indices] : llvm::enumerate(reassocIndices)) { |
| 755 | // If the dimension is present in the current indices group, the group |
| 756 | // position within the reassociation map is the desired projected |
| 757 | // dimension position. |
| 758 | if (llvm::is_contained(indices, pos)) { |
| 759 | projectedPos.push_back(idx); |
| 760 | break; |
| 761 | } |
| 762 | } |
| 763 | } |
  assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");
| 765 | |
| 766 | return projectedPos; |
| 767 | } |
| 768 | |
| 769 | /// Bubble up pack op through expand shape op. |
| 770 | /// |
| 771 | /// For example: |
| 772 | /// |
| 773 | /// %expand = tensor.expand_shape %in [[0], [1, 2]] |
| 774 | /// : tensor<?x64xf32> into tensor<?x4x16xf32> |
| 775 | /// %pack = linalg.pack %expand outer_dims_perm = [0, 1] |
| 776 | /// inner_dims_pos = [2] inner_tiles = [8] into %empty |
| 777 | /// : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32> |
| 778 | /// |
| 779 | /// can be transformed into: |
| 780 | /// |
| 781 | /// %pack = linalg.pack %in outer_dims_perm = [1, 2] |
| 782 | /// inner_dims_pos = [1] inner_tiles = [8] into %empty |
| 783 | /// : tensor<?x64xf32> -> tensor<?x8x8xf32> |
| 784 | /// %expand = tensor.expand_shape %pack [[0], [1, 2], [3]] |
| 785 | /// : tensor<?x8x8xf32> into tensor<?x4x2x8xf32> |
| 786 | static LogicalResult |
| 787 | bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp, |
| 788 | linalg::PackOp packOp, |
| 789 | PatternRewriter &rewriter) { |
| 790 | // Outer dimensions permutation is not supported currently. |
| 791 | // TODO: Handle outer_dims_perm variants. |
| 792 | ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm(); |
  if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
    return rewriter.notifyMatchFailure(packOp,
                                       "non-identity outer dims perm NYI");
| 796 | } |
| 797 | |
| 798 | // Validate dimensions' relations between shape expansion and packing. |
| 799 | SmallVector<ReassociationIndices, 4> reassoc = |
| 800 | expandOp.getReassociationIndices(); |
| 801 | ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos(); |
| 802 | llvm::SetVector<int64_t> packDimsPos(llvm::from_range, packInnerDims); |
| 803 | |
| 804 | for (auto [idx, indices] : llvm::enumerate(reassoc)) { |
| 805 | // For each expand_shape reassociation, figure out which dimensions get |
| 806 | // packed if any. |
| 807 | llvm::SetVector<int64_t> expandDimPos(llvm::from_range, indices); |
| 808 | llvm::SetVector<int64_t> packedDims = |
| 809 | llvm::set_intersection(packDimsPos, expandDimPos); |
| 810 | |
    // The expanded dimension is not packed, so it does not affect moving the
    // pack before the shape expansion; simply continue.
| 813 | if (packedDims.empty()) |
| 814 | continue; |
    // Shape expansion cannot be propagated when multiple expanded dimensions
    // are packed; in this case, operation reordering would affect the final
    // element positions and/or shapes can no longer be projected.
| 818 | if (packedDims.size() != 1) |
| 819 | return rewriter.notifyMatchFailure( |
| 820 | packOp, "only one of the expanded dimensions can be packed" ); |
| 821 | // Only the inner-most expanded dimension should be packed. Otherwise, |
| 822 | // elements order will be affected after operation reordering. |
| 823 | if (packedDims.front() != indices.back()) |
| 824 | return rewriter.notifyMatchFailure( |
| 825 | packOp, "can only pack the inner-most expanded dimension" ); |
| 826 | } |
| 827 | |
| 828 | // Project pack.inner_dims_pos to positions before shape expansion. |
| 829 | SmallVector<int64_t> projectedInnerDimsPos = |
| 830 | projectDimsPosIntoReassocPos(packInnerDims, reassoc); |
| 831 | |
| 832 | // Project the shape expansion to new packed shape. |
  // The pack.outer_dims_perm is restricted to identity, so the permutation can
  // be omitted for simplicity.
| 835 | // TODO: Account for outer dimensions permutation. |
| 836 | // |
| 837 | // If reassociation is not possible, then reordering cannot happen. |
| 838 | // This can be caused by pack padding affecting previously expanded |
| 839 | // dimensions or packing extending dimensions. |
| 840 | RankedTensorType newPackType = linalg::PackOp::inferPackedType( |
| 841 | expandOp.getSrcType(), packOp.getStaticInnerTiles(), |
| 842 | projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{}); |
| 843 | auto reassocExpand = |
| 844 | getReassociationIndicesForReshape(newPackType, packOp.getDestType()); |
| 845 | if (!reassocExpand) |
| 846 | return rewriter.notifyMatchFailure( |
| 847 | packOp, "could not reassociate dims after bubbling up" ); |
| 848 | |
| 849 | Value destTensor = linalg::PackOp::createDestinationTensor( |
| 850 | rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(), |
| 851 | projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{}); |
| 852 | Value packedVal = rewriter.create<linalg::PackOp>( |
| 853 | packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos, |
| 854 | packOp.getMixedTiles(), packOp.getPaddingValue(), |
| 855 | /*outerDimsPerm=*/SmallVector<int64_t>{}); |
| 856 | |
| 857 | Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>( |
| 858 | packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand); |
| 859 | rewriter.replaceOp(packOp, newExpandOp); |
| 860 | |
| 861 | return success(); |
| 862 | } |
| 863 | |
| 864 | class BubbleUpPackOpThroughReshapeOp final |
| 865 | : public OpRewritePattern<linalg::PackOp> { |
| 866 | public: |
| 867 | BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun) |
| 868 | : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {} |
| 869 | |
| 870 | LogicalResult matchAndRewrite(linalg::PackOp packOp, |
| 871 | PatternRewriter &rewriter) const override { |
| 872 | Operation *srcOp = packOp.getSource().getDefiningOp(); |
    // Currently, we only support the case where the pack op is the only user.
    if (!srcOp || !(srcOp->getNumResults() == 1) ||
        !srcOp->getResult(0).hasOneUse()) {
| 876 | return failure(); |
| 877 | } |
| 878 | // Currently only support static inner tile sizes. |
| 879 | if (llvm::any_of(packOp.getStaticTiles(), ShapedType::isDynamic)) |
| 880 | return failure(); |
| 881 | |
| 882 | // User controlled propagation function. |
| 883 | if (!controlFn(&packOp.getSourceMutable())) |
| 884 | return failure(); |
| 885 | |
| 886 | return TypeSwitch<Operation *, LogicalResult>(srcOp) |
| 887 | .Case([&](tensor::CollapseShapeOp op) { |
| 888 | return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter); |
| 889 | }) |
| 890 | .Case([&](tensor::ExpandShapeOp op) { |
| 891 | return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter); |
| 892 | }) |
| 893 | .Default([](Operation *) { return failure(); }); |
| 894 | } |
| 895 | |
| 896 | private: |
| 897 | ControlPropagationFn controlFn; |
| 898 | }; |
| 899 | |
| 900 | /// Push down unpack op through expand shape op when the packed dims can be |
| 901 | /// projected to the dims after expanding. This is possible when the inner tile |
| 902 | /// sizes can divide the projected dims. |
| 903 | /// |
| 904 | /// For example: |
| 905 | /// |
| 906 | /// %unpack = linalg.unpack %in outer_dims_perm = [0, 1] |
| 907 | /// inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty |
| 908 | /// : tensor<?x32x8x8xf32> -> tensor<?x256xf32> |
| 909 | /// %expanded = tensor.expand_shape %unpack [[0, 1], [2]] |
| 910 | /// : tensor<?x256xf32> into tensor<?x256x256xf32> |
| 911 | /// |
| 912 | /// can be transformed into: |
| 913 | /// |
///   %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
| 915 | /// : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32> |
| 916 | /// %unpack = linalg.unpack %expanded outer_dims_perm = [0, 1, 2] |
| 917 | /// inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty |
| 918 | /// : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32> |
| 919 | static LogicalResult pushDownUnPackOpThroughExpandShape( |
| 920 | linalg::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp, |
| 921 | PatternRewriter &rewriter, ControlPropagationFn controlFn) { |
| 922 | // User controlled propagation function. |
| 923 | if (!controlFn(&expandOp.getSrcMutable())) |
| 924 | return failure(); |
| 925 | |
| 926 | SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles(); |
| 927 | ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos(); |
| 928 | ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm(); |
| 929 | |
| 930 | auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType()); |
| 931 | if (!expandTy) |
| 932 | return failure(); |
| 933 | ArrayRef<int64_t> dstShape = expandTy.getShape(); |
| 934 | SmallVector<ReassociationIndices> reassocIndices = |
| 935 | expandOp.getReassociationIndices(); |
  // Project inner tile pos to the dim pos after expanding. For example, if dim
  // z is expanded into dims [x, y], unpacking on dim z can be projected to
  // unpacking on dim y.
| 939 | // |
| 940 | // Project to inner-most non-unit dims to increase the chance that they can be |
| 941 | // divided by the inner tile sizes. This is correct because for [..., x, 1], |
| 942 | // unpacking on dim 1 is equivalent to unpacking on dim x. |
| 943 | SmallVector<int64_t> projectedInnerDimsPos = |
| 944 | projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape); |
| 945 | |
| 946 | if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape, |
| 947 | innerTileSizes)) { |
| 948 | return failure(); |
| 949 | } |
| 950 | // Expand the outer dims permutation with the associated expanded dims for the |
| 951 | // new permutation after pushing. This is because moving a source dim is |
| 952 | // equivalent to moving the associated expanded dims together. |
| 953 | SmallVector<int64_t> newOuterDimsPerm; |
| 954 | for (auto outerPos : outerDimsPerm) |
| 955 | llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]); |
| 956 | |
| 957 | SmallVector<ReassociationIndices> newReassocIndices = reassocIndices; |
| 958 | // First apply the permutation on the reassociations of the outer dims. |
| 959 | // For example given the permutation [1, 0], the reassociations [[0, 1], [2]] |
| 960 | // -> [[0], [1, 2]] |
| 961 | int64_t nextPos = |
| 962 | applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm); |
| 963 | // Then add direct mapping for the inner tile dims. |
| 964 | for (size_t i = 0; i < innerDimsPos.size(); ++i) { |
| 965 | newReassocIndices.push_back({nextPos}); |
| 966 | nextPos += 1; |
| 967 | } |
| 968 | |
| 969 | RankedTensorType newExpandType = linalg::PackOp::inferPackedType( |
| 970 | expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm); |
| 971 | auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>( |
| 972 | expandOp.getLoc(), newExpandType, unPackOp.getSource(), |
| 973 | newReassocIndices); |
| 974 | |
| 975 | auto emptyOp = linalg::UnPackOp::createDestinationTensor( |
| 976 | rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(), |
| 977 | projectedInnerDimsPos, newOuterDimsPerm); |
| 978 | auto newUnPackOp = rewriter.create<linalg::UnPackOp>( |
| 979 | unPackOp.getLoc(), newExpandOp.getResult(), emptyOp, |
| 980 | projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm); |
| 981 | rewriter.replaceOp(expandOp, newUnPackOp); |
| 982 | |
| 983 | return success(); |
| 984 | } |
| 985 | |
| 986 | class PushDownUnPackOpThroughReshapeOp final |
| 987 | : public OpRewritePattern<linalg::UnPackOp> { |
| 988 | public: |
| 989 | PushDownUnPackOpThroughReshapeOp(MLIRContext *context, |
| 990 | ControlPropagationFn fun) |
| 991 | : OpRewritePattern<linalg::UnPackOp>(context), controlFn(std::move(fun)) { |
| 992 | } |
| 993 | |
| 994 | LogicalResult matchAndRewrite(linalg::UnPackOp unPackOp, |
| 995 | PatternRewriter &rewriter) const override { |
| 996 | Value result = unPackOp.getResult(); |
    // Currently, we only support an unpack op with a single user.
| 998 | if (!result.hasOneUse()) { |
| 999 | return failure(); |
| 1000 | } |
| 1001 | // Currently only support static inner tile sizes. |
| 1002 | if (llvm::any_of(unPackOp.getStaticTiles(), ShapedType::isDynamic)) |
| 1003 | return failure(); |
| 1004 | |
| 1005 | Operation *consumerOp = *result.user_begin(); |
| 1006 | return TypeSwitch<Operation *, LogicalResult>(consumerOp) |
| 1007 | .Case([&](tensor::ExpandShapeOp op) { |
| 1008 | return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter, |
| 1009 | controlFn); |
| 1010 | }) |
| 1011 | .Default([](Operation *) { return failure(); }); |
| 1012 | } |
| 1013 | |
| 1014 | private: |
| 1015 | ControlPropagationFn controlFn; |
| 1016 | }; |
| 1017 | |
| 1018 | // TODO: Relax this restriction. We should unpack a generic op also |
| 1019 | // in the presence of multiple unpack ops as producers. |
| 1020 | /// Return the unpacked operand, if present, for the current generic op. |
| 1021 | static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) { |
| 1022 | OpOperand *unPackedOperand = nullptr; |
| 1023 | for (OpOperand &operand : genericOp->getOpOperands()) { |
| 1024 | auto unPackOp = operand.get().getDefiningOp<linalg::UnPackOp>(); |
| 1025 | if (!unPackOp) |
| 1026 | continue; |
| 1027 | if (unPackedOperand) |
| 1028 | return failure(); |
| 1029 | unPackedOperand = &operand; |
| 1030 | } |
| 1031 | if (!unPackedOperand) |
| 1032 | return failure(); |
| 1033 | return unPackedOperand; |
| 1034 | } |
| 1035 | |
| 1036 | /// Push down a linalg.unpack op through a generic op. |
| 1037 | /// The new generic op works on packed domain; pack ops are created for input |
| 1038 | /// and output operands. A linalg.unpack op is inserted right after the packed |
| 1039 | /// generic. E.g. |
| 1040 | /// |
| 1041 | /// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> |
| 1042 | /// |
| 1043 | /// %arg0 = tensor<12x2x56x56x32xf32> // packed arg. |
| 1044 | /// |
| 1045 | /// %0 = tensor.empty() : tensor<12x56x56x64xf32> |
| 1046 | /// %1 = linalg.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] |
| 1047 | /// inner_dims_pos = [3] inner_tiles = [32] into %0 |
| 1048 | /// %2 = linalg.generic {indexing_maps = [#map], |
| 1049 | /// iterator_types = ["parallel", "parallel", "parallel", "parallel"]} |
| 1050 | /// outs(%1 : tensor<12x56x56x64xf32>) { |
| 1051 | /// ^bb0(%out : f32): |
| 1052 | /// linalg.yield %out : f32 |
| 1053 | /// } -> tensor<12x56x56x64xf32> |
| 1054 | /// |
| 1055 | /// will be converted to |
| 1056 | /// |
| 1057 | /// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> |
| 1058 | /// |
| 1059 | /// %0 = tensor.empty() : tensor<12x56x56x64xf32> |
| 1060 | /// %1 = linalg.generic {indexing_maps = [#map], |
| 1061 | /// iterator_types = ["parallel", "parallel", "parallel", |
| 1062 | /// "parallel", "parallel"]} |
| 1063 | /// outs(%arg0 : tensor<12x2x56x56x32xf32>) { |
| 1064 | /// ^bb0(%out : f32): |
| 1065 | /// linalg.yield %out : f32 |
| 1066 | /// } -> tensor<12x2x56x56x32xf32> |
| 1067 | /// %2 = linalg.unpack %1 outer_dims_perm = [0, 3, 1, 2] |
| 1068 | /// inner_dims_pos = [3] inner_tiles = [32] into %0 |
| 1069 | /// |
| 1070 | static FailureOr<std::tuple<GenericOp, Value>> |
| 1071 | pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp, |
| 1072 | ControlPropagationFn controlFn) { |
| 1073 | if (genericOp.getNumResults() != 1) |
| 1074 | return failure(); |
| 1075 | |
| 1076 | if (hasGatherSemantics(genericOp)) |
| 1077 | return failure(); |
| 1078 | |
| 1079 | // Collect the unPacked operand, if present. |
| 1080 | auto maybeUnPackedOperand = getUnPackedOperand(genericOp); |
| 1081 | if (failed(maybeUnPackedOperand)) |
| 1082 | return failure(); |
| 1083 | OpOperand *unPackedOperand = *(maybeUnPackedOperand); |
| 1084 | |
| 1085 | // Extract packing information. |
| 1086 | linalg::UnPackOp producerUnPackOp = |
| 1087 | unPackedOperand->get().getDefiningOp<linalg::UnPackOp>(); |
  assert(producerUnPackOp && "expect a valid UnPackOp");
| 1089 | |
| 1090 | if (!controlFn(unPackedOperand)) |
| 1091 | return failure(); |
| 1092 | |
| 1093 | auto packInfo = |
| 1094 | getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp); |
| 1095 | if (failed(packInfo)) |
| 1096 | return failure(); |
| 1097 | |
| 1098 | // Rebuild the indexing map for the corresponding init operand. |
| 1099 | auto [packedOutOperand, packedOutIndexingMap] = |
| 1100 | getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo, |
| 1101 | genericOp, genericOp.getDpsInitOperand(0)); |
| 1102 | auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>(); |
| 1103 | |
| 1104 | // If the dps init operand of the generic is a tensor.empty, do not pack it |
| 1105 | // and forward the new tensor.empty as a destination. |
| 1106 | Value dest = packedOutOperand; |
| 1107 | if (auto initTensor = genericOp.getDpsInitOperand(0) |
| 1108 | ->get() |
| 1109 | .getDefiningOp<tensor::EmptyOp>()) { |
| 1110 | if (destPack) |
| 1111 | dest = destPack.getDest(); |
| 1112 | } |
| 1113 | |
| 1114 | // Pack the genericOp. |
  // pack(unpack) is foldable in this case, because when pushing down the
  // unpack we populate an additional pack op after it by default. This
  // guarantees that the pair is foldable.
| 1118 | GenericOp newGenericOp = |
| 1119 | packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo, |
| 1120 | /*isFoldableUnpackPack=*/true); |
| 1121 | Value newResult = |
| 1122 | newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)); |
| 1123 | |
| 1124 | // If the output is unaffected, no need to unpack. |
| 1125 | if (!destPack) |
| 1126 | return std::make_tuple(newGenericOp, newResult); |
| 1127 | |
| 1128 | auto mixedTiles = destPack.getMixedTiles(); |
| 1129 | auto innerDimsPos = destPack.getInnerDimsPos(); |
| 1130 | auto outerDimsPerm = destPack.getOuterDimsPerm(); |
| 1131 | |
| 1132 | // Insert an unPackOp right after the packed generic. |
| 1133 | Value unPackOpRes = |
| 1134 | rewriter |
| 1135 | .create<linalg::UnPackOp>(genericOp.getLoc(), newResult, |
| 1136 | destPack.getSource(), innerDimsPos, |
| 1137 | mixedTiles, outerDimsPerm) |
| 1138 | .getResult(); |
| 1139 | |
| 1140 | return std::make_tuple(newGenericOp, unPackOpRes); |
| 1141 | } |
| 1142 | |
// Wrapper pattern that applies the pushDownUnPackOpThroughGenericOp method.
| 1144 | struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> { |
| 1145 | public: |
| 1146 | PushDownUnPackOpThroughGenericOp(MLIRContext *context, |
| 1147 | ControlPropagationFn fun) |
| 1148 | : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {} |
| 1149 | |
| 1150 | LogicalResult matchAndRewrite(GenericOp genericOp, |
| 1151 | PatternRewriter &rewriter) const override { |
| 1152 | auto genericAndRepl = |
| 1153 | pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn); |
| 1154 | if (failed(genericAndRepl)) |
| 1155 | return failure(); |
| 1156 | rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl)); |
| 1157 | return success(); |
| 1158 | } |
| 1159 | |
| 1160 | private: |
| 1161 | ControlPropagationFn controlFn; |
| 1162 | }; |
| 1163 | |
/// Propagate a linalg.unpack operation down through a tensor.pad. The idea is
/// to append as many zero-padding entries to `high` and `low` as there are
/// point loops in the packed layout.
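/// A rough sketch of the rewrite (shapes are illustrative, the pad body is
/// elided):
///
///   %unpack = linalg.unpack %src
///       inner_dims_pos = [0] inner_tiles = [8]
///       into %empty : tensor<2x6x8xf32> -> tensor<16x6xf32>
///   %padded = tensor.pad %unpack low[0, 0] high[0, 2] { ... }
///       : tensor<16x6xf32> to tensor<16x8xf32>
///
/// becomes
///
///   %padded = tensor.pad %src low[0, 0, 0] high[0, 2, 0] { ... }
///       : tensor<2x6x8xf32> to tensor<2x8x8xf32>
///   %unpack = linalg.unpack %padded
///       inner_dims_pos = [0] inner_tiles = [8]
///       into %dest : tensor<2x8x8xf32> -> tensor<16x8xf32>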
| 1167 | struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> { |
| 1168 | PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun) |
| 1169 | : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {} |
| 1170 | |
| 1171 | LogicalResult matchAndRewrite(tensor::PadOp padOp, |
| 1172 | PatternRewriter &rewriter) const override { |
| 1173 | linalg::UnPackOp unpackOp = |
| 1174 | padOp.getSource().getDefiningOp<linalg::UnPackOp>(); |
| 1175 | if (!unpackOp) |
| 1176 | return failure(); |
| 1177 | |
| 1178 | if (!controlFn(&padOp.getSourceMutable())) |
| 1179 | return failure(); |
| 1180 | |
| 1181 | Location loc = padOp.getLoc(); |
    // Bail out if one of the padded dimensions is a tiled one.
| 1183 | llvm::SmallBitVector paddedDims = padOp.getPaddedDims(); |
| 1184 | ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos(); |
| 1185 | llvm::SmallBitVector innerDims(paddedDims.size()); |
| 1186 | for (int64_t dim : innerDimsPos) |
| 1187 | innerDims.flip(dim); |
    if (paddedDims.anyCommon(innerDims))
| 1189 | return failure(); |
| 1190 | |
| 1191 | Value paddingVal = padOp.getConstantPaddingValue(); |
| 1192 | if (!paddingVal) |
| 1193 | return failure(); |
| 1194 | |
| 1195 | // If we have `outer_dims_perms` we need to adjust the padded dimensions. |
| 1196 | ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm(); |
| 1197 | SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad(); |
| 1198 | SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad(); |
| 1199 | if (!outerDimsPerm.empty()) { |
| 1200 | applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm); |
| 1201 | applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm); |
| 1202 | } |
| 1203 | // Add zero padding for the point loops. |
| 1204 | size_t pointLoopsSize = innerDimsPos.size(); |
| 1205 | lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); |
| 1206 | highPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); |
| 1207 | |
| 1208 | auto newPadOp = rewriter.create<tensor::PadOp>( |
| 1209 | loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad, |
| 1210 | paddingVal, padOp.getNofold()); |
| 1211 | |
| 1212 | // Inject the linalg.unpack right after the packed padOp. |
| 1213 | Value outputUnPack = rewriter.create<tensor::EmptyOp>( |
| 1214 | loc, padOp.getResultType().getShape(), |
| 1215 | padOp.getResultType().getElementType()); |
| 1216 | |
| 1217 | Value replacement = rewriter.create<linalg::UnPackOp>( |
| 1218 | loc, newPadOp.getResult(), outputUnPack, innerDimsPos, |
| 1219 | unpackOp.getMixedTiles(), outerDimsPerm); |
| 1220 | rewriter.replaceOp(padOp, replacement); |
| 1221 | return success(); |
| 1222 | } |
| 1223 | |
| 1224 | private: |
| 1225 | ControlPropagationFn controlFn; |
| 1226 | }; |
| 1227 | |
| 1228 | } // namespace |
| 1229 | |
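/// Adds all the pack/unpack propagation patterns defined above to `patterns`.
/// The `controlPackUnPackPropagation` callback lets callers veto propagation
/// on a per-operand basis.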
| 1230 | void mlir::linalg::populateDataLayoutPropagationPatterns( |
| 1231 | RewritePatternSet &patterns, |
| 1232 | const ControlPropagationFn &controlPackUnPackPropagation) { |
| 1233 | patterns |
| 1234 | .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp, |
| 1235 | BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp, |
| 1236 | PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>( |
          patterns.getContext(), controlPackUnPackPropagation);
| 1238 | } |
| 1239 | |