| 1 | //===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" |
| 10 | |
| 11 | #include "mlir/Analysis/SliceAnalysis.h" |
| 12 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 13 | #include "mlir/Dialect/Affine/Utils.h" |
| 14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 15 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
| 16 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 17 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
| 18 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 19 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 20 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
| 21 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
| 22 | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
| 23 | #include "mlir/Interfaces/TilingInterface.h" |
| 24 | #include "mlir/Interfaces/ValueBoundsOpInterface.h" |
| 25 | #include "llvm/Support/Debug.h" |
| 26 | #include <optional> |
| 27 | |
| 28 | #define DEBUG_TYPE "linalg-tiling-interface-impl" |
| 29 | |
| 30 | using namespace mlir; |
| 31 | using namespace mlir::linalg; |
| 32 | |
| 33 | //===----------------------------------------------------------------------===// |
| 34 | // Utility methods for implementation of Tiling Interface for Linalg ops |
| 35 | //===----------------------------------------------------------------------===// |
| 36 | |
| 37 | /// Return the SSA values that represent the data point accessed using a given |
| 38 | /// `indexingMap` for a given point in the iteration space represented by `ivs`. |
| 39 | static SmallVector<Value> getIndicesForAccess(OpBuilder &b, Location loc, |
| 40 | AffineMap indexingMap, |
| 41 | ValueRange ivs) { |
| 42 | SmallVector<Value> indices; |
| 43 | indices.reserve(N: indexingMap.getNumResults()); |
| 44 | for (auto result : indexingMap.getResults()) { |
| 45 | AffineMap m = AffineMap::get(dimCount: indexingMap.getNumDims(), |
| 46 | symbolCount: indexingMap.getNumSymbols(), result); |
| 47 | Value v = b.create<affine::AffineApplyOp>(location: loc, args&: m, args&: ivs); |
| 48 | indices.push_back(Elt: v); |
| 49 | } |
| 50 | return indices; |
| 51 | } |
| 52 | |
| 53 | /// Method to inline the payload of a `linalgOp` given the iteration space |
| 54 | /// point and values for the arguments of the payload. |
| 55 | static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp, |
| 56 | ValueRange ivs, ValueRange argValues) { |
| 57 | Block *body = linalgOp.getBlock(); |
| 58 | IRMapping map; |
| 59 | map.map(from: body->getArguments(), to&: argValues); |
| 60 | for (auto &op : body->without_terminator()) { |
| 61 | if (auto indexOp = dyn_cast<IndexOp>(Val: &op)) { |
| 62 | map.map(from: indexOp.getResult(), to: ivs[indexOp.getDim()]); |
| 63 | continue; |
| 64 | } |
| 65 | b.clone(op, mapper&: map); |
| 66 | } |
| 67 | |
| 68 | Operation *terminator = body->getTerminator(); |
| 69 | Location loc = terminator->getLoc(); |
| 70 | for (const auto &operand : llvm::enumerate(First: terminator->getOperands())) { |
| 71 | Value toStore = map.lookupOrDefault(from: operand.value()); |
| 72 | OpOperand *storeInto = linalgOp.getDpsInitOperand(i: operand.index()); |
| 73 | auto indices = getIndicesForAccess( |
| 74 | b, loc, indexingMap: linalgOp.getMatchingIndexingMap(opOperand: storeInto), ivs); |
| 75 | b.create<memref::StoreOp>( |
| 76 | location: loc, args&: toStore, args: linalgOp.getDpsInitOperand(i: operand.index())->get(), |
| 77 | args&: indices); |
| 78 | } |
| 79 | return success(); |
| 80 | } |
| 81 | |
| 82 | //===----------------------------------------------------------------------===// |
| 83 | // External Model for implementing `TilingInterface` for `LinalgOp`s. |
| 84 | //===----------------------------------------------------------------------===// |
| 85 | |
| 86 | namespace { |
/// External model implementation of TilingInterface for LinalgOps. An external
/// model implementation is used for now till the use of `TilingInterface` is
/// on-par with the current Linalg tiling + fusion patterns. Once it is, it may
/// be possible to move this into the op definition (though there are
/// advantages to leaving it as an external model).
template <typename LinalgOpTy>
struct LinalgOpTilingInterface
    : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
                                            LinalgOpTy> {
  /// Return the loop iterator type.
  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
    return concreteOp.getIteratorTypesArray();
  }

  /// Return the iteration domain range. Every loop starts at 0 with unit
  /// stride; the extent is obtained by composing the flattened operand shapes
  /// with the op's shapes-to-loops map.
  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(op);
    Location loc = op->getLoc();
    LinalgOp linalgOp = cast<LinalgOp>(Val: op);
    SmallVector<OpFoldResult> allShapesSizes =
        linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();

    return llvm::to_vector(
        llvm::map_range(map.getResults(), [&](AffineExpr loopExpr) {
          OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
              b, loc, expr: loopExpr, operands: allShapesSizes);
          return Range{.offset: b.getIndexAttr(value: 0), .size: ofr, .stride: b.getIndexAttr(value: 1)};
        }));
  }

  /// Instantiate the tiled implementation of the operation.
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
    // specified could lead to out of bounds accesses.
    Location loc = op->getLoc();
    LinalgOp linalgOp = cast<LinalgOp>(Val: op);
    SmallVector<Value> valuesToTile = linalgOp->getOperands();
    // Create tiled views (tensor.extract_slice / memref.subview) of all
    // operands for the requested tile.
    SmallVector<Value> tiledOperands = makeTiledShapes(
        builder&: b, loc, linalgOp, valuesToTile, ivs: offsets, tileSizes: sizes, sizeBounds: {}, omitPartialTileCheck: true);
    // Keep only the operands that were actually sliced; operands passed
    // through unchanged (e.g. scalars) have no slice defining op.
    SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
        llvm::make_filter_range(
            tiledOperands,
            [](Value v) -> bool {
              return isa_and_nonnull<tensor::ExtractSliceOp, memref::SubViewOp>(
                  Val: v.getDefiningOp());
            }),
        [](Value v) -> Operation * { return v.getDefiningOp(); });

    SmallVector<Type> resultTensorTypes =
        getTensorOutputTypes(op: linalgOp, operands: tiledOperands);

    Operation *tiledOp = clone(b, op: linalgOp, newResultTypes: resultTensorTypes, newOperands: tiledOperands);
    // `linalg.index` results inside the tile must be shifted by the tile
    // offsets to stay consistent with the untiled iteration space.
    offsetIndices(b, linalgOp: cast<LinalgOp>(Val: tiledOp), offests: offsets);

    return TilingResult{
        .tiledOps: {tiledOp}, .tiledValues: SmallVector<Value>(tiledOp->getResults()), .generatedSlices: generatedSlices};
  }

  /// Utility to fetch the offsets and sizes when applied as per the indexing
  /// map of the linalg op. This helps in fusing the linalg op as a consumer of
  /// a given slice op. Fails if two operand tiles imply conflicting
  /// offsets/sizes for the same iteration-space dimension; dimensions not
  /// covered by any operand tile default to the full iteration domain.
  static LogicalResult
  getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b,
                         ArrayRef<AffineMap> indexingMaps,
                         ArrayRef<SmallVector<OpFoldResult>> allOffsets,
                         ArrayRef<SmallVector<OpFoldResult>> allSizes,
                         SmallVectorImpl<OpFoldResult> &mappedOffsetsVec,
                         SmallVectorImpl<OpFoldResult> &mappedSizesVec) {
    DenseMap<unsigned, OpFoldResult> mappedOffsets, mappedSizes;

    // For each (indexing map, offsets, sizes) triple, record the tile
    // offset/size implied for every iteration-space dimension that is
    // accessed through a plain AffineDimExpr. Non-dim expressions are
    // skipped.
    for (auto [indexingMap, offsets, sizes] :
         llvm::zip_equal(t&: indexingMaps, u&: allOffsets, args&: allSizes)) {
      for (auto [resultExpr, offset, size] :
           llvm::zip_equal(t: indexingMap.getResults(), u: offsets, args: sizes)) {
        auto dimExpr = dyn_cast<AffineDimExpr>(Val: resultExpr);
        if (!dimExpr)
          continue;
        unsigned position = dimExpr.getPosition();
        auto it = mappedOffsets.find(Val: position);
        if (it != mappedOffsets.end()) {
          // Dimension seen before: the new tile must agree with the recorded
          // one, otherwise the mapping is inconsistent.
          OpFoldResult seenOffset = it->second;
          OpFoldResult seenSize = mappedSizes.lookup(Val: position);
          if (seenOffset != offset || seenSize != size) {
            LLVM_DEBUG({
              llvm::dbgs() << "inconsistent iteration space mapping from "
                              "offsets/sizes of operands/results";
            });
            return failure();
          }
        } else {
          mappedOffsets[position] = offset;
          mappedSizes[position] = size;
        }
      }
    }

    // Aggregate from the given operand offsets and sizes, or default to
    // iteration space values.
    SmallVector<Range> iterationDomain =
        cast<TilingInterface>(Val: linalgOp.getOperation()).getIterationDomain(b);
    mappedOffsetsVec.resize(N: iterationDomain.size());
    mappedSizesVec.resize(N: iterationDomain.size());
    for (auto [index, domain] : llvm::enumerate(First&: iterationDomain)) {
      auto it = mappedOffsets.find(Val: index);
      if (it != mappedOffsets.end()) {
        mappedOffsetsVec[index] = it->second;
        mappedSizesVec[index] = mappedSizes.lookup(Val: index);
        continue;
      }
      // Dimension not constrained by any operand tile: use the full extent.
      mappedOffsetsVec[index] = domain.offset;
      mappedSizesVec[index] = domain.size;
    }
    return success();
  }

  /// Method to return the position of the iteration domain tile computed from
  /// the given operand tiles.
  LogicalResult getIterationDomainTileFromOperandTiles(
      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
      ArrayRef<SmallVector<OpFoldResult>> allSizes,
      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
    auto linalgOp = cast<LinalgOp>(Val: op);

    // NOTE(review): these two locals are currently unused.
    std::optional<SmallVector<OpFoldResult>> iterationSpaceOffsets,
        iterationSpaceSizes;
    // Gather the indexing map of each tiled operand, then invert the mapping
    // to recover iteration-space offsets/sizes.
    SmallVector<AffineMap> indexingMaps =
        llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) {
          OpOperand &opOperand = linalgOp->getOpOperand(idx: operandNumber);
          return linalgOp.getMatchingIndexingMap(opOperand: &opOperand);
        });
    if (failed(Result: getMappedOffsetAndSize(linalgOp, b, indexingMaps, allOffsets,
                                       allSizes, mappedOffsetsVec&: iterDomainOffsets,
                                       mappedSizesVec&: iterDomainSizes))) {
      return failure();
    }
    return success();
  }

  /// Return the details of the output tile generated by the tiled
  /// implementation.
  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    Location loc = op->getLoc();
    LinalgOp linalgOp = cast<LinalgOp>(Val: op);

    AffineExpr d0;
    bindDims(ctx: b.getContext(), exprs&: d0);
    // Fold `size - 1` for each tile size; passed to computeSliceParameters as
    // the sub-shape sizes.
    SmallVector<OpFoldResult> subShapeSizes =
        llvm::to_vector(llvm::map_range(sizes, [&](OpFoldResult ofr) {
          return affine::makeComposedFoldedAffineApply(b, loc, expr: d0 - 1, operands: ofr);
        }));

    OpOperand *outOperand = linalgOp.getDpsInitOperand(i: resultNumber);
    SliceParameters sliceParams = computeSliceParameters(
        builder&: b, loc, valueToTile: outOperand->get(), tileSizes: sizes,
        map: linalgOp.getMatchingIndexingMap(opOperand: outOperand), lbs: offsets,
        /*ubs*/ {}, subShapeSizes, omitPartialTileCheck: true);
    resultOffsets = sliceParams.offsets;
    resultSizes = sliceParams.sizes;
    return success();
  }

  /// Map a tile of the result `resultNumber` back to a tile of the iteration
  /// domain.
  LogicalResult getIterationDomainTileFromResultTile(
      Operation *op, OpBuilder &b, unsigned resultNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
    auto linalgOp = cast<LinalgOp>(Val: op);

    // Check that the indexing map used for the output is a projected
    // permutation. This could be relaxed with a more general approach that can
    // map the offsets and sizes from the result to iteration space tiles
    // (filling in full extent for dimensions not used to access the result).
    AffineMap indexingMap =
        linalgOp.getIndexingMapMatchingResult(result: op->getResult(idx: resultNumber));
    if (!indexingMap.isProjectedPermutation()) {
      return op->emitOpError(
          message: "unhandled tiled implementation generation when result is not "
          "accessed using a permuted projection");
    }

    SmallVector<OpFoldResult> allOffsets = llvm::to_vector(Range&: offsets);
    SmallVector<OpFoldResult> allSizes = llvm::to_vector(Range&: sizes);
    auto status =
        getMappedOffsetAndSize(linalgOp, b, indexingMaps: indexingMap, allOffsets: {allOffsets},
                               allSizes: {allSizes}, mappedOffsetsVec&: iterDomainOffsets, mappedSizesVec&: iterDomainSizes);
    (void)status;
    // A single projected-permutation map cannot produce a conflicting
    // mapping, so failure here would indicate a programming error.
    assert(succeeded(status) && "unexpected error in offset calculation");
    return success();
  }

  /// Generate the tile of the result by mapping the result tile to an
  /// iteration-domain tile and tiling the whole op for it.
  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
    if (failed(getIterationDomainTileFromResultTile(
            op, b, resultNumber, offsets, sizes, iterDomainOffsets&: mappedOffsets, iterDomainSizes&: mappedSizes))) {
      return failure();
    }
    auto tilingInterfaceOp = cast<TilingInterface>(Val: op);
    FailureOr<TilingResult> tilingResult =
        tilingInterfaceOp.getTiledImplementation(b, offsets: mappedOffsets, sizes: mappedSizes);

    if (failed(Result: tilingResult))
      return failure();

    if (tilingResult->tiledOps.size() != 1)
      return op->emitOpError(message: "failed to generate tiled implementation");

    // Only the requested result's tiled value is reported back to the caller.
    return TilingResult{
        .tiledOps: tilingResult->tiledOps,
        .tiledValues: SmallVector<Value>{tilingResult->tiledValues[resultNumber]},
        .generatedSlices: tilingResult->generatedSlices};
  }

  /// Method to generate the tiled implementation of an operation from the tile
  /// of the operand.
  FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
      ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
    if (failed(getIterationDomainTileFromOperandTiles(
            op, b, operandNumbers, allOffsets, allSizes, iterDomainOffsets&: mappedOffsets,
            iterDomainSizes&: mappedSizes))) {
      return failure();
    }
    return getTiledImplementation(op, b, offsets: mappedOffsets, sizes: mappedSizes);
  }

  /// Generate a scalar implementation of the op at the iteration-space point
  /// `ivs`. Requires pure buffer semantics since results are written with
  /// `memref.store`.
  LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
                                             Location loc,
                                             ValueRange ivs) const {
    auto linalgOp = cast<LinalgOp>(Val: op);
    if (!linalgOp.hasPureBufferSemantics())
      return op->emitOpError(message: "expected operation to have buffer semantics");

    SmallVector<Value> indexedValues;
    indexedValues.reserve(N: linalgOp->getNumOperands());
    Location linalgOpLoc = op->getLoc();
    /// Load the data corresponding to the block arguments that
    /// represent input operands.
    for (OpOperand &operand : linalgOp->getOpOperands()) {
      // Operands never read by the payload get a null placeholder so that
      // indexedValues stays aligned with the block arguments.
      if (!linalgOp.payloadUsesValueFromOperand(opOperand: &operand)) {
        indexedValues.push_back(Elt: nullptr);
        continue;
      }
      // Scalar operands are used directly; no load is needed.
      if (linalgOp.isScalar(opOperand: &operand)) {
        indexedValues.push_back(Elt: operand.get());
        continue;
      }
      SmallVector<Value> indices = getIndicesForAccess(
          b&: builder, loc: linalgOpLoc, indexingMap: linalgOp.getMatchingIndexingMap(opOperand: &operand), ivs);
      Value load =
          builder.create<memref::LoadOp>(location: linalgOpLoc, args: operand.get(), args&: indices);
      indexedValues.push_back(Elt: load);
    }

    /// Inline the op payload and store the result.
    return inlinePayload(b&: builder, linalgOp, ivs, argValues: indexedValues);
  }
};
| 362 | |
| 363 | //===----------------------------------------------------------------------===// |
| 364 | // External Model for implementing `PartialReductionInterface` for `LinalgOp`s. |
| 365 | //===----------------------------------------------------------------------===// |
| 366 | |
| 367 | /// In a given set vector, get the position of a particular element. |
| 368 | std::optional<int> getPositionIn(const llvm::SetVector<unsigned> &reductionDims, |
| 369 | unsigned value) { |
| 370 | for (auto [index, reductionDim] : llvm::enumerate(First: reductionDims)) { |
| 371 | if (reductionDim == value) { |
| 372 | return index; |
| 373 | } |
| 374 | } |
| 375 | return std::nullopt; |
| 376 | } |
| 377 | |
| 378 | /// Return an AffineMaps to use for the `outs` operands of the linalg op |
| 379 | /// generated for partial results. The new AffineMap is the AffineMap of the |
| 380 | /// untiled op with reduction dimensions appended at end in order in which they |
| 381 | /// were specified during tiling. |
| 382 | static SmallVector<AffineMap> |
| 383 | getPartialResultAffineMaps(LinalgOp linalgOp, |
| 384 | const SetVector<unsigned> &reductionDims) { |
| 385 | auto partialReductionMaps = llvm::map_to_vector( |
| 386 | C: linalgOp.getDpsInitsMutable(), F: [&](OpOperand &opOperand) { |
| 387 | AffineMap map = linalgOp.getMatchingIndexingMap(opOperand: &opOperand); |
| 388 | for (auto redPos : reductionDims) { |
| 389 | map = |
| 390 | map.insertResult(expr: getAffineDimExpr(position: redPos, context: linalgOp.getContext()), |
| 391 | pos: map.getNumResults()); |
| 392 | } |
| 393 | return map; |
| 394 | }); |
| 395 | return partialReductionMaps; |
| 396 | } |
| 397 | |
/// Parameters describing the slice of an init operand used as the destination
/// of a partial reduction op.
struct InitSliceInfo {
  // Static shape of the slice result (derived from `sizes` via
  // decomposeMixedValues).
  SmallVector<int64_t> resultShape;
  // Mixed static/dynamic offsets of the slice.
  SmallVector<OpFoldResult> offsets;
  // Mixed static/dynamic sizes of the slice.
  SmallVector<OpFoldResult> sizes;
  // Mixed static/dynamic strides of the slice (unit strides in practice).
  SmallVector<OpFoldResult> strides;
};
| 404 | |
| 405 | /// Return the result shape, offsets, sizes and strides of the slice of the |
| 406 | /// `initValue` to use as the destination of the partial reduction op generated |
| 407 | /// with outer reduction strategy. |
| 408 | static InitSliceInfo getInitSliceInfoForOuterReduction( |
| 409 | MLIRContext *context, ArrayRef<OpFoldResult> offsets, |
| 410 | ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims, |
| 411 | ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) { |
| 412 | int64_t initRank = partialReductionMap.getNumResults(); |
| 413 | SmallVector<OpFoldResult> initOffsets, initSizes; |
| 414 | Attribute zero = IntegerAttr::get(type: IndexType::get(context), value: 0); |
| 415 | Attribute one = IntegerAttr::get(type: IndexType::get(context), value: 1); |
| 416 | SmallVector<OpFoldResult> initStrides(initRank, one); |
| 417 | for (AffineExpr dimExpr : partialReductionMap.getResults()) { |
| 418 | unsigned dim = cast<AffineDimExpr>(Val&: dimExpr).getPosition(); |
| 419 | if (reductionDims.contains(key: dim)) { |
| 420 | initOffsets.push_back(Elt: zero); |
| 421 | } else { |
| 422 | initOffsets.push_back(Elt: offsets[dim]); |
| 423 | } |
| 424 | initSizes.push_back(Elt: sizes[dim]); |
| 425 | } |
| 426 | SmallVector<int64_t> resultShape; |
| 427 | std::tie(args&: resultShape, args: std::ignore) = decomposeMixedValues(mixedValues: initSizes); |
| 428 | return {.resultShape: resultShape, .offsets: initOffsets, .sizes: initSizes, .strides: initStrides}; |
| 429 | } |
| 430 | |
| 431 | /// Return the result shape, offsets, sizes and strides of the slice of the |
| 432 | /// `initValue` to use as destination of the partial reduction op generated with |
| 433 | /// outer parallel strategy. |
| 434 | static InitSliceInfo getInitSliceInfoForOuterParallel( |
| 435 | MLIRContext *context, ArrayRef<OpFoldResult> offsets, |
| 436 | ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims, |
| 437 | ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) { |
| 438 | int64_t initRank = partialReductionMap.getNumResults(); |
| 439 | SmallVector<OpFoldResult> initOffsets, initSizes; |
| 440 | Attribute one = IntegerAttr::get(type: IndexType::get(context), value: 1); |
| 441 | SmallVector<OpFoldResult> initStrides(initRank, one); |
| 442 | SmallVector<OpFoldResult> resultShape; |
| 443 | for (AffineExpr dimExpr : partialReductionMap.getResults()) { |
| 444 | unsigned dim = cast<AffineDimExpr>(Val&: dimExpr).getPosition(); |
| 445 | if (std::optional<unsigned> dimPos = getPositionIn(reductionDims, value: dim)) { |
| 446 | initOffsets.push_back(Elt: splitReductionIvs[dimPos.value()]); |
| 447 | initSizes.push_back(Elt: one); |
| 448 | } else { |
| 449 | initOffsets.push_back(Elt: offsets[dim]); |
| 450 | initSizes.push_back(Elt: sizes[dim]); |
| 451 | resultShape.push_back(Elt: sizes[dim]); |
| 452 | } |
| 453 | } |
| 454 | SmallVector<int64_t> staticShapes; |
| 455 | std::tie(args&: staticShapes, args: std::ignore) = decomposeMixedValues(mixedValues: resultShape); |
| 456 | return {.resultShape: staticShapes, .offsets: initOffsets, .sizes: initSizes, .strides: initStrides}; |
| 457 | } |
| 458 | |
| 459 | /// Return the result shape, offsets, sizes and strides of the slice of the |
| 460 | /// `initValue` to use as destination of the partial reduction op. |
| 461 | static InitSliceInfo getInitSliceInfo(MLIRContext *context, |
| 462 | ReductionTilingStrategy strategy, |
| 463 | ArrayRef<OpFoldResult> offsets, |
| 464 | ArrayRef<OpFoldResult> sizes, |
| 465 | const SetVector<unsigned> &reductionDims, |
| 466 | ArrayRef<OpFoldResult> splitReductionIvs, |
| 467 | AffineMap partialReductionMap) { |
| 468 | if (strategy == ReductionTilingStrategy::PartialReductionOuterReduction) { |
| 469 | return getInitSliceInfoForOuterReduction(context, offsets, sizes, |
| 470 | reductionDims, splitReductionIvs, |
| 471 | partialReductionMap); |
| 472 | } |
| 473 | assert(strategy == ReductionTilingStrategy::PartialReductionOuterParallel && |
| 474 | "unexpected ReductionTilingStrategy" ); |
| 475 | return getInitSliceInfoForOuterParallel(context, offsets, sizes, |
| 476 | reductionDims, splitReductionIvs, |
| 477 | partialReductionMap); |
| 478 | } |
| 479 | |
| 480 | /// External model implementation of PartialReductionInterface for |
| 481 | /// LinalgOps. |
| 482 | template <typename LinalgOpTy> |
| 483 | struct LinalgOpPartialReductionInterface |
| 484 | : public PartialReductionOpInterface::ExternalModel< |
| 485 | LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> { |
| 486 | FailureOr<SmallVector<Value>> generateInitialTensorForPartialReduction( |
| 487 | Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes, |
| 488 | const SetVector<unsigned> &reductionDims) const { |
| 489 | auto linalgOp = cast<LinalgOp>(Val: op); |
| 490 | |
| 491 | OpBuilder::InsertionGuard guard(b); |
| 492 | if (linalgOp.hasPureBufferSemantics()) |
| 493 | return op->emitOpError(message: "expected operation to have tensor semantics" ); |
| 494 | |
| 495 | SmallVector<AffineMap> partialResultMaps = |
| 496 | getPartialResultAffineMaps(linalgOp, reductionDims); |
| 497 | |
| 498 | SmallVector<Value> inits; |
| 499 | for (auto [initIdx, result, partialMap] : |
| 500 | llvm::enumerate(First: linalgOp->getResults(), Rest&: partialResultMaps)) { |
| 501 | SmallVector<Operation *, 4> combinerOps; |
| 502 | if (!matchReduction(iterCarriedArgs: linalgOp.getRegionOutputArgs(), redPos: initIdx, |
| 503 | combinerOps) || |
| 504 | combinerOps.size() != 1) |
| 505 | return op->emitOpError(message: "Failed to anaysis the reduction operation." ); |
| 506 | |
| 507 | Operation *reductionOp = combinerOps[0]; |
| 508 | std::optional<TypedAttr> identity = arith::getNeutralElement(op: reductionOp); |
| 509 | if (!identity.has_value()) |
| 510 | return op->emitOpError( |
| 511 | message: "Failed to get an identity value for the reduction operation." ); |
| 512 | |
| 513 | // Append the new partial result dimensions. |
| 514 | SmallVector<OpFoldResult> partialResultShape; |
| 515 | for (AffineExpr dimExpr : partialMap.getResults()) { |
| 516 | auto dim = cast<AffineDimExpr>(Val&: dimExpr); |
| 517 | partialResultShape.push_back(Elt: sizes[dim.getPosition()]); |
| 518 | } |
| 519 | |
| 520 | Type elType = getElementTypeOrSelf(type: result.getType()); |
| 521 | Value emptyTensor = |
| 522 | b.create<tensor::EmptyOp>(location: loc, args&: partialResultShape, args&: elType); |
| 523 | Value constantOp = b.create<arith::ConstantOp>(location: loc, args&: *identity); |
| 524 | auto identityTensor = |
| 525 | b.create<linalg::FillOp>(location: loc, args&: constantOp, args&: emptyTensor); |
| 526 | inits.push_back(Elt: identityTensor.getResult(i: 0)); |
| 527 | } |
| 528 | |
| 529 | return inits; |
| 530 | } |
| 531 | |
  /// Generate the partial-reduction op for the given iteration-space tile:
  /// input operands are sliced as in regular tiling, init operands are sliced
  /// according to the reduction tiling strategy, and — for the
  /// outer-reduction strategy — the tiled reduction dimensions are rewritten
  /// as parallel dimensions of a new `linalg.generic`.
  FailureOr<TilingResult>
  tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
                         ReductionTilingStrategy tilingStrategy,
                         ValueRange init, ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
                         const SetVector<unsigned> &reductionDims,
                         ArrayRef<OpFoldResult> splitReductionIvs) const {
    OpBuilder::InsertionGuard guard(b);
    auto linalgOp = cast<LinalgOp>(Val: op);

    SmallVector<AffineMap> partialReductionMaps =
        getPartialResultAffineMaps(linalgOp, reductionDims);

    // Step 1. Extend init maps to have reduction dimension dims, since we
    // are converting them to parallel dimensions. For the outer-parallel
    // strategy the original init maps are kept unchanged.
    SmallVector<AffineMap> newInitMaps;
    if (tilingStrategy ==
        ReductionTilingStrategy::PartialReductionOuterReduction) {
      newInitMaps = llvm::to_vector(Range&: partialReductionMaps);
    } else {
      newInitMaps = llvm::map_to_vector(
          linalgOp.getDpsInitsMutable(), [&](OpOperand &opOperand) {
            return linalgOp.getMatchingIndexingMap(opOperand: &opOperand);
          });
    }

    // Step 2a: Extract a slice of the input operands.
    SmallVector<Value> tiledInputs = makeTiledShapes(
        builder&: b, loc, linalgOp, valuesToTile: linalgOp.getDpsInputs(), ivs: offsets, tileSizes: sizes, sizeBounds: {}, omitPartialTileCheck: true);
    // Record the inputs that were actually sliced (those with a defining op).
    SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
        llvm::make_filter_range(
            tiledInputs, [](Value v) -> bool { return v.getDefiningOp(); }),
        [](Value v) -> Operation * { return v.getDefiningOp(); });

    // Step 2b: Extract a slice of the init operands.
    SmallVector<Value, 1> tiledInits;
    for (auto [partialReductionMap, valueToTile] :
         llvm::zip_equal(t&: partialReductionMaps, u&: init)) {
      InitSliceInfo sliceInfo = getInitSliceInfo(
          context: b.getContext(), strategy: tilingStrategy, offsets, sizes, reductionDims,
          splitReductionIvs, partialReductionMap);
      // Preserve the encoding of the init tensor on the slice result type.
      auto valueToTileType = cast<RankedTensorType>(Val: valueToTile.getType());
      RankedTensorType sliceResultType = RankedTensorType::get(
          shape: sliceInfo.resultShape, elementType: valueToTileType.getElementType(),
          encoding: valueToTileType.getEncoding());
      auto sliceOp = b.create<tensor::ExtractSliceOp>(
          location: loc, args&: sliceResultType, args&: valueToTile, args&: sliceInfo.offsets, args&: sliceInfo.sizes,
          args&: sliceInfo.strides);
      tiledInits.push_back(Elt: sliceOp.getResult());
      generatedSlices.push_back(Elt: sliceOp);
    }

    // Update the indexing maps: replace each init operand's map with its new
    // map computed in Step 1.
    SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
    for (auto [initOperand, newInitMap] :
         llvm::zip_equal(t: linalgOp.getDpsInitsMutable(), u&: newInitMaps)) {
      int mapIdx = linalgOp.getIndexingMapIndex(opOperand: &initOperand);
      newMaps[mapIdx] = newInitMap;
    }

    // Step 3. Change the reduction dim iterator types.
    SmallVector<utils::IteratorType> newIteratorTypes =
        linalgOp.getIteratorTypesArray();
    if (tilingStrategy ==
        ReductionTilingStrategy::PartialReductionOuterReduction) {
      for (int dim : reductionDims)
        newIteratorTypes[dim] = utils::IteratorType::parallel;
    }

    // Step 4. Create the new generic op.
    Operation *partialReductionOp;
    auto resultTypes = ValueRange(tiledInits).getTypes();
    if (tilingStrategy ==
        ReductionTilingStrategy::PartialReductionOuterReduction) {
      // Build a fresh generic op with the updated maps/iterator types and
      // clone the original payload region into it.
      auto genericOp = b.create<GenericOp>(
          location: loc, args&: resultTypes, args&: tiledInputs, args&: tiledInits, args&: newMaps, args&: newIteratorTypes);
      IRMapping mapping;
      op->getRegion(index: 0).cloneInto(dest: &genericOp.getRegion(),
                               destPos: genericOp.getRegion().begin(), mapper&: mapping);
      partialReductionOp = genericOp.getOperation();
    } else {
      // Outer-parallel strategy: clone the original op with the tiled
      // operands; maps and iterator types are unchanged.
      SmallVector<Value> operands = std::move(tiledInputs);
      llvm::append_range(C&: operands, R&: tiledInits);
      partialReductionOp = mlir::clone(b, op, newResultTypes: resultTypes, newOperands: operands);
    }
    return TilingResult{
        {partialReductionOp},
        llvm::map_to_vector(partialReductionOp->getResults(),
                            [](OpResult r) -> Value { return r; }),
        generatedSlices};
  }
| 623 | |
  /// Merge the partial reduction results into the final values: for each
  /// result, create a `linalg.reduce` over the appended partial-result
  /// dimensions whose body reuses the original op's combiner operation.
  FailureOr<MergeResult>
  mergeReductions(Operation *op, OpBuilder &b, Location loc,
                  ValueRange partialReduce,
                  const SetVector<unsigned> &reductionDims) const {
    auto linalgOp = cast<LinalgOp>(Val: op);
    SmallVector<AffineMap> partialReductionMaps =
        getPartialResultAffineMaps(linalgOp, reductionDims);

    // Permute the reduction dims as permuted by the partial result map.
    SmallVector<Operation *> mergeOperations;
    SmallVector<Value> replacements;
    for (auto [idx, init, partialResult, partialMap] : llvm::enumerate(
             First: linalgOp.getDpsInits(), Rest&: partialReduce, Rest&: partialReductionMaps)) {
      unsigned initIdx = idx;
      // linalg.reduce's iteration space is the tiled result's iteration space
      // (and not the tiled operation's iteration space). To account for this,
      // permute the reduction dimensions based on the partial result map of the
      // tiled result.
      SmallVector<int64_t> partialReductionDims;
      for (auto [resultNum, dimExpr] :
           llvm::enumerate(First: partialMap.getResults())) {
        unsigned dim = cast<AffineDimExpr>(Val: dimExpr).getPosition();
        if (llvm::is_contained(Range: reductionDims, Element: dim)) {
          partialReductionDims.push_back(Elt: resultNum);
        }
      }

      auto reduction = b.create<linalg::ReduceOp>(
          loc, partialResult, init, partialReductionDims,
          [&linalgOp, &initIdx](OpBuilder &b, Location loc, ValueRange inputs) {
            // Get the combiner op.
            SmallVector<Operation *, 4> combinerOps;
            matchReduction(iterCarriedArgs: linalgOp.getRegionOutputArgs(), redPos: initIdx,
                           combinerOps);
            Operation *clonedReductionOp = b.clone(op&: *combinerOps[0]);
            // Combine the input at idx and output at numInits + idx.
            clonedReductionOp->setOperand(idx: 0, value: inputs[0]);
            clonedReductionOp->setOperand(idx: 1, value: inputs[1]);
            b.create<linalg::YieldOp>(location: loc, args: clonedReductionOp->getResult(idx: 0));
          });

      mergeOperations.push_back(Elt: reduction);
      replacements.push_back(Elt: reduction->getResult(0));
    }

    return MergeResult{.mergeOps: mergeOperations, .replacements: replacements};
  }
| 671 | |
| 672 | LogicalResult getPartialResultTilePosition( |
| 673 | Operation *op, OpBuilder &b, unsigned resultNumber, |
| 674 | ReductionTilingStrategy tilingStrategy, ArrayRef<OpFoldResult> offsets, |
| 675 | ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims, |
| 676 | ArrayRef<OpFoldResult> splitReductionIvs, |
| 677 | SmallVector<OpFoldResult> &resultOffsets, |
| 678 | SmallVector<OpFoldResult> &resultSizes) const { |
| 679 | auto linalgOp = cast<LinalgOp>(Val: op); |
| 680 | SmallVector<AffineMap> partialReductionMaps = |
| 681 | getPartialResultAffineMaps(linalgOp, reductionDims); |
| 682 | InitSliceInfo sliceInfo = getInitSliceInfo( |
| 683 | context: b.getContext(), strategy: tilingStrategy, offsets, sizes, reductionDims, |
| 684 | splitReductionIvs, partialReductionMap: partialReductionMaps[resultNumber]); |
| 685 | std::swap(LHS&: resultOffsets, RHS&: sliceInfo.offsets); |
| 686 | std::swap(LHS&: resultSizes, RHS&: sliceInfo.sizes); |
| 687 | |
| 688 | return success(); |
| 689 | } |
| 690 | }; |
| 691 | |
| 692 | template <typename OpTy> |
| 693 | static SmallVector<Range> getPackUnPackIterationDomain(OpTy op, |
| 694 | OpBuilder &builder) { |
| 695 | static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value, |
| 696 | "applies to only pack or unpack operations" ); |
| 697 | OpBuilder::InsertionGuard g(builder); |
| 698 | int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank() |
| 699 | : op.getDestRank(); |
| 700 | OpFoldResult zero = builder.getIndexAttr(value: 0); |
| 701 | OpFoldResult one = builder.getIndexAttr(value: 1); |
| 702 | ReifiedRankedShapedTypeDims resultShape; |
| 703 | (void)reifyResultShapes(builder, op, resultShape); |
| 704 | SmallVector<Range> loopBounds(rank); |
| 705 | for (auto dim : llvm::seq<int64_t>(Begin: 0, End: rank)) { |
| 706 | loopBounds[dim].offset = zero; |
| 707 | loopBounds[dim].stride = one; |
| 708 | loopBounds[dim].size = resultShape[0][dim]; |
| 709 | } |
| 710 | return loopBounds; |
| 711 | } |
| 712 | |
| 713 | static void applyPermToRange(SmallVector<OpFoldResult> &offsets, |
| 714 | SmallVector<OpFoldResult> &sizes, |
| 715 | ArrayRef<int64_t> permutation) { |
| 716 | if (permutation.empty()) |
| 717 | return; |
| 718 | applyPermutationToVector<OpFoldResult>(inVec&: offsets, permutation); |
| 719 | applyPermutationToVector<OpFoldResult>(inVec&: sizes, permutation); |
| 720 | } |
| 721 | |
| 722 | struct PackOpTiling |
| 723 | : public TilingInterface::ExternalModel<PackOpTiling, linalg::PackOp> { |
| 724 | |
| 725 | SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { |
| 726 | // Note that here we only consider untiled dimensions and outer tiled data |
| 727 | // dimensions, the inner tiled data dimensions are materialized when |
| 728 | // building the body of the operation. |
| 729 | auto packOp = cast<PackOp>(Val: op); |
| 730 | SmallVector<utils::IteratorType> iteratorTypes( |
| 731 | packOp.getSourceRank(), utils::IteratorType::parallel); |
| 732 | return iteratorTypes; |
| 733 | } |
| 734 | |
| 735 | SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const { |
| 736 | return getPackUnPackIterationDomain<PackOp>(op: cast<PackOp>(Val: op), builder&: b); |
| 737 | } |
| 738 | |
| 739 | FailureOr<TilingResult> |
| 740 | getTiledImplementation(Operation *op, OpBuilder &b, |
| 741 | ArrayRef<OpFoldResult> offsets, |
| 742 | ArrayRef<OpFoldResult> sizes) const { |
| 743 | auto packOp = cast<PackOp>(Val: op); |
| 744 | Location loc = packOp.getLoc(); |
| 745 | |
| 746 | // The tiling is applied on interchanged dimensions. We have to undo the |
| 747 | // interchange to map sizes and offsets to the original input. |
| 748 | int64_t inputRank = packOp.getSourceRank(); |
| 749 | SmallVector<OpFoldResult> origOffsets(offsets); |
| 750 | SmallVector<OpFoldResult> origSizes(sizes); |
| 751 | applyPermToRange(offsets&: origOffsets, sizes&: origSizes, |
| 752 | permutation: invertPermutationVector(permutation: packOp.getOuterDimsPerm())); |
| 753 | |
| 754 | DenseMap<int64_t, OpFoldResult> dimAndTileMapping = |
| 755 | packOp.getDimAndTileMapping(); |
| 756 | SmallVector<OpFoldResult> srcDimValues = |
| 757 | tensor::getMixedSizes(builder&: b, loc, value: packOp.getSource()); |
| 758 | SmallVector<OpFoldResult> inputIndices, inputSizes; |
| 759 | for (auto dim : llvm::seq<int64_t>(Begin: 0, End: inputRank)) { |
| 760 | using AV = affine::AffineValueExpr; |
| 761 | affine::AffineBuilder ab(b, loc); |
| 762 | AffineExpr dim0, dim1, sym; |
| 763 | bindDims(ctx: b.getContext(), exprs&: dim0, exprs&: dim1); |
| 764 | bindSymbols(ctx: b.getContext(), exprs&: sym); |
| 765 | if (dimAndTileMapping.count(Val: dim)) { |
| 766 | // If the data dimension is tiled, the i-th index is the product of |
| 767 | // offset_i and tile_i, and the i-th size is the product of sizes_i and |
| 768 | // tile_i. |
| 769 | auto avOffset = AV(dim0).bind(v: origOffsets[dim]); |
| 770 | auto avSize = AV(dim0).bind(v: origSizes[dim]); |
| 771 | auto avTileSize = AV(sym).bind(v: dimAndTileMapping[dim]); |
| 772 | inputIndices.push_back(Elt: ab.mul(lhs: avOffset, rhs: avTileSize)); |
| 773 | inputSizes.push_back(Elt: ab.mul(lhs: avSize, rhs: avTileSize)); |
| 774 | } else { |
| 775 | inputIndices.push_back(Elt: origOffsets[dim]); |
| 776 | inputSizes.push_back(Elt: origSizes[dim]); |
| 777 | } |
| 778 | |
| 779 | // Limit the size of the input operand for incomplete tiles. |
| 780 | if (packOp.getPaddingValue()) { |
| 781 | OpFoldResult dimSize = srcDimValues[dim]; |
| 782 | auto avDimSize = AV(dim0).bind(v: dimSize); |
| 783 | auto avInputIdx = AV(dim1).bind(v: inputIndices.back()); |
| 784 | inputSizes.back() = |
| 785 | ab.min(vals: {inputSizes.back(), ab.sub(lhs: avDimSize, rhs: avInputIdx)}); |
| 786 | } |
| 787 | } |
| 788 | |
| 789 | auto oneAttr = b.getI64IntegerAttr(value: 1); |
| 790 | SmallVector<OpFoldResult> strides(inputRank, oneAttr); |
| 791 | |
| 792 | SmallVector<Value> tiledOperands; |
| 793 | auto sourceSlice = b.create<tensor::ExtractSliceOp>( |
| 794 | location: loc, args: packOp.getSource(), args&: inputIndices, args&: inputSizes, args&: strides); |
| 795 | tiledOperands.push_back(Elt: sourceSlice); |
| 796 | |
| 797 | SmallVector<OpFoldResult> outputOffsets, outputSizes; |
| 798 | if (failed(Result: getResultTilePosition(op, b, resultNumber: 0, offsets, sizes, resultOffsets&: outputOffsets, |
| 799 | resultSizes&: outputSizes))) |
| 800 | return {}; |
| 801 | |
| 802 | strides.append(NumInputs: packOp.getDestRank() - inputRank, Elt: oneAttr); |
| 803 | auto outSlice = b.create<tensor::ExtractSliceOp>( |
| 804 | location: loc, args: packOp.getDest(), args&: outputOffsets, args&: outputSizes, args&: strides); |
| 805 | tiledOperands.push_back(Elt: outSlice); |
| 806 | |
| 807 | if (auto val = packOp.getPaddingValue()) |
| 808 | tiledOperands.push_back(Elt: val); |
| 809 | for (auto tile : packOp.getInnerTiles()) |
| 810 | tiledOperands.push_back(Elt: tile); |
| 811 | |
| 812 | Operation *tiledPackOp = b.create<PackOp>( |
| 813 | location: loc, args: TypeRange{outSlice.getType()}, args&: tiledOperands, args: op->getAttrs()); |
| 814 | |
| 815 | return TilingResult{ |
| 816 | .tiledOps: {tiledPackOp}, |
| 817 | .tiledValues: SmallVector<Value>(tiledPackOp->getResults()), |
| 818 | .generatedSlices: llvm::to_vector(Range: ArrayRef<Operation *>{sourceSlice, outSlice})}; |
| 819 | } |
| 820 | |
| 821 | LogicalResult |
| 822 | getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, |
| 823 | ArrayRef<OpFoldResult> offsets, |
| 824 | ArrayRef<OpFoldResult> sizes, |
| 825 | SmallVector<OpFoldResult> &resultOffsets, |
| 826 | SmallVector<OpFoldResult> &resultSizes) const { |
| 827 | // The iteration domain is over outer dimensions of packed layout. In this |
| 828 | // context, the outer dimensions of `resultOffsets` are `offsets`. The |
| 829 | // inner dimensions of `resultOffsets` are zeros because tiling is not |
| 830 | // applied to them. |
| 831 | auto packOp = cast<PackOp>(Val: op); |
| 832 | int64_t inputRank = packOp.getSourceRank(); |
| 833 | int64_t outputRank = packOp.getDestRank(); |
| 834 | auto zeroAttr = b.getI64IntegerAttr(value: 0); |
| 835 | resultOffsets.assign(in_start: offsets.begin(), in_end: offsets.end()); |
| 836 | resultOffsets.append(NumInputs: outputRank - inputRank, Elt: zeroAttr); |
| 837 | |
| 838 | ReifiedRankedShapedTypeDims outputShape; |
| 839 | (void)reifyResultShapes(b, op: packOp, reifiedReturnShapes&: outputShape); |
| 840 | resultSizes.assign(in_start: sizes.begin(), in_end: sizes.end()); |
| 841 | for (auto dataTileDim : llvm::seq<unsigned>(Begin: inputRank, End: outputRank)) |
| 842 | resultSizes.push_back(Elt: outputShape[0][dataTileDim]); |
| 843 | |
| 844 | return success(); |
| 845 | } |
| 846 | |
| 847 | FailureOr<TilingResult> |
| 848 | generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, |
| 849 | ArrayRef<OpFoldResult> offsets, |
| 850 | ArrayRef<OpFoldResult> sizes) const { |
| 851 | auto packOp = cast<PackOp>(Val: op); |
| 852 | int64_t numTiles = packOp.getInnerDimsPos().size(); |
| 853 | |
| 854 | // tensor.pack op is fusible (as a producer) only if full inner tiles are |
| 855 | // iterated or inner dims are not tiled. Otherwise, it will generate a |
| 856 | // sequence of non-trivial ops (for partial tiles). |
| 857 | for (auto offset : offsets.take_back(N: numTiles)) |
| 858 | if (!isZeroInteger(v: offset)) |
| 859 | return failure(); |
| 860 | |
| 861 | for (auto iter : |
| 862 | llvm::zip_equal(t: packOp.getMixedTiles(), u: sizes.take_back(N: numTiles))) |
| 863 | if (!isEqualConstantIntOrValue(ofr1: std::get<0>(t&: iter), ofr2: std::get<1>(t&: iter))) |
| 864 | return failure(); |
| 865 | |
| 866 | FailureOr<TilingResult> tilingResult = getTiledImplementation( |
| 867 | op, b, offsets: offsets.drop_back(N: numTiles), sizes: sizes.drop_back(N: numTiles)); |
| 868 | if (failed(Result: tilingResult)) |
| 869 | return failure(); |
| 870 | return tilingResult.value(); |
| 871 | } |
| 872 | |
| 873 | /// Method to return the position of iteration domain tile computed by the |
| 874 | /// tiled operation. In current `tensor.pack` context, the `resultOffsets` and |
| 875 | /// `resultSizes` only cover outer dimensions. |
| 876 | LogicalResult getIterationDomainTileFromOperandTiles( |
| 877 | Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers, |
| 878 | ArrayRef<SmallVector<OpFoldResult>> allOffsets, |
| 879 | ArrayRef<SmallVector<OpFoldResult>> allSizes, |
| 880 | SmallVectorImpl<OpFoldResult> &resultOffsets, |
| 881 | SmallVectorImpl<OpFoldResult> &resultSizes) const { |
| 882 | if (operandNumbers.size() != 1 || operandNumbers[0] != 0) { |
| 883 | LLVM_DEBUG( |
| 884 | { llvm::dbgs() << "unsupported operands for consumer fusion" ; }); |
| 885 | return failure(); |
| 886 | } |
| 887 | |
| 888 | ArrayRef<OpFoldResult> offsets(allOffsets[0]); |
| 889 | ArrayRef<OpFoldResult> sizes(allSizes[0]); |
| 890 | |
| 891 | auto packOp = cast<PackOp>(Val: op); |
| 892 | // It is not trivial to infer dest tile from source tile if `packOp` has |
| 893 | // padding semantic. |
| 894 | if (packOp.getPaddingValue()) |
| 895 | return failure(); |
| 896 | |
| 897 | Location loc = packOp.getLoc(); |
| 898 | |
| 899 | SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes; |
| 900 | DenseMap<int64_t, OpFoldResult> dimAndTileMapping = |
| 901 | packOp.getDimAndTileMapping(); |
| 902 | for (auto dim : llvm::seq<int64_t>(Size: packOp.getSourceRank())) { |
| 903 | if (dimAndTileMapping.count(Val: dim)) { |
| 904 | FailureOr<int64_t> cstSize = |
| 905 | ValueBoundsConstraintSet::computeConstantBound( |
| 906 | type: presburger::BoundType::UB, var: sizes[dim], |
| 907 | /*stopCondition=*/nullptr, /*closedUB=*/true); |
| 908 | std::optional<int64_t> cstInnerSize = |
| 909 | getConstantIntValue(ofr: dimAndTileMapping[dim]); |
| 910 | // Currently fusing `packOp` as consumer only expects perfect tiling |
| 911 | // scenario because even if without padding semantic, the `packOp` may |
| 912 | // also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>, |
| 913 | // where the `tileSize` from operand of `packOp` is 5, which is not |
| 914 | // exactly divided by `innerTile`(=6) of `packOp`. As the result: |
| 915 | // 1. the first slice is extracted from (0) to (4) and inserted into |
| 916 | // (0,0)~(0,4) at first row. |
| 917 | // 2. the second slice is extracted from (5) to (9) and SHOULD BE |
| 918 | // respectively inserted into two rows with different length, including |
| 919 | // first row: (0,5) and second row (1,0)~(1,3). It is hard to coordinate |
| 920 | // them, thus adding below constraint to bypass them temporarily. In |
| 921 | // another word, we can only support tiling with consumer if the tile |
| 922 | // size for the producer is a multiple of the inner tile size for the |
| 923 | // packed dimensions at this moment. |
| 924 | if (failed(Result: cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) { |
| 925 | return failure(); |
| 926 | } |
| 927 | |
| 928 | using AV = affine::AffineValueExpr; |
| 929 | affine::AffineBuilder ab(b, loc); |
| 930 | AffineExpr dim0, sym; |
| 931 | bindDims(ctx: b.getContext(), exprs&: dim0); |
| 932 | bindSymbols(ctx: b.getContext(), exprs&: sym); |
| 933 | auto avOffset = AV(dim0).bind(v: offsets[dim]); |
| 934 | auto avSize = AV(dim0).bind(v: sizes[dim]); |
| 935 | auto avTileSize = AV(sym).bind(v: dimAndTileMapping[dim]); |
| 936 | outerDimOffsets.push_back(Elt: ab.floor(lhs: avOffset, rhs: avTileSize)); |
| 937 | outerDimSizes.push_back(Elt: ab.ceil(lhs: avSize, rhs: avTileSize)); |
| 938 | } else { |
| 939 | outerDimOffsets.push_back(Elt: offsets[dim]); |
| 940 | outerDimSizes.push_back(Elt: sizes[dim]); |
| 941 | } |
| 942 | } |
| 943 | applyPermToRange(offsets&: outerDimOffsets, sizes&: outerDimSizes, permutation: packOp.getOuterDimsPerm()); |
| 944 | resultOffsets = outerDimOffsets; |
| 945 | resultSizes = outerDimSizes; |
| 946 | return success(); |
| 947 | } |
| 948 | |
| 949 | /// Method to return the tiled implementation of tensor.pack as a consumer. |
| 950 | FailureOr<TilingResult> getTiledImplementationFromOperandTiles( |
| 951 | Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers, |
| 952 | ArrayRef<SmallVector<OpFoldResult>> allOffsets, |
| 953 | ArrayRef<SmallVector<OpFoldResult>> allSizes) const { |
| 954 | if (operandNumbers.size() != 1 || operandNumbers[0] != 0) { |
| 955 | LLVM_DEBUG( |
| 956 | { llvm ::dbgs() << "unhandled operands for consumer fusion" ; }); |
| 957 | return failure(); |
| 958 | } |
| 959 | |
| 960 | ArrayRef<OpFoldResult> offsets(allOffsets[0]); |
| 961 | ArrayRef<OpFoldResult> sizes(allSizes[0]); |
| 962 | |
| 963 | auto packOp = cast<PackOp>(Val: op); |
| 964 | Location loc = packOp.getLoc(); |
| 965 | |
| 966 | int64_t inputRank = packOp.getSourceRank(); |
| 967 | auto oneAttr = b.getI64IntegerAttr(value: 1); |
| 968 | SmallVector<OpFoldResult> strides(inputRank, oneAttr); |
| 969 | |
| 970 | SmallVector<Value> tiledOperands; |
| 971 | auto sourceSlice = b.create<tensor::ExtractSliceOp>( |
| 972 | location: loc, args: packOp.getSource(), args&: offsets, args&: sizes, args&: strides); |
| 973 | tiledOperands.push_back(Elt: sourceSlice); |
| 974 | |
| 975 | SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes; |
| 976 | if (failed(Result: getIterationDomainTileFromOperandTiles( |
| 977 | op, b, operandNumbers, allOffsets, allSizes, resultOffsets&: outerDimOffsets, |
| 978 | resultSizes&: outerDimSizes))) |
| 979 | return failure(); |
| 980 | |
| 981 | SmallVector<OpFoldResult> outputOffsets, outputSizes; |
| 982 | if (failed(Result: getResultTilePosition(op, b, resultNumber: 0, offsets: outerDimOffsets, sizes: outerDimSizes, |
| 983 | resultOffsets&: outputOffsets, resultSizes&: outputSizes))) |
| 984 | return failure(); |
| 985 | |
| 986 | strides.append(NumInputs: packOp.getDestRank() - inputRank, Elt: oneAttr); |
| 987 | auto outSlice = b.create<tensor::ExtractSliceOp>( |
| 988 | location: loc, args: packOp.getDest(), args&: outputOffsets, args&: outputSizes, args&: strides); |
| 989 | tiledOperands.push_back(Elt: outSlice); |
| 990 | |
| 991 | assert(!packOp.getPaddingValue() && "Expect no padding semantic" ); |
| 992 | for (auto tile : packOp.getInnerTiles()) |
| 993 | tiledOperands.push_back(Elt: tile); |
| 994 | |
| 995 | Operation *tiledPackOp = b.create<PackOp>( |
| 996 | location: loc, args: TypeRange{outSlice.getType()}, args&: tiledOperands, args: op->getAttrs()); |
| 997 | |
| 998 | return TilingResult{ |
| 999 | .tiledOps: {tiledPackOp}, |
| 1000 | .tiledValues: SmallVector<Value>(tiledPackOp->getResults()), |
| 1001 | .generatedSlices: llvm::to_vector(Range: ArrayRef<Operation *>{sourceSlice, outSlice})}; |
| 1002 | } |
| 1003 | }; |
| 1004 | |
/// Per-dimension bookkeeping for tiling an unpack op; produced by
/// `getUnpackTileDimInfo` and consumed by UnPackOpTiling's
/// `getTiledImplementation`.
struct UnpackTileDimInfo {
  // True when the tile size is a multiple of the inner tile size, i.e. the
  // source slice covers only complete inner tiles along this dimension.
  bool isAlignedToInnerTileSize;
  // Offset of the slice to extract from the (packed) source.
  OpFoldResult sourceOffset;
  // Size of the slice to extract from the (packed) source.
  OpFoldResult sourceSize;
  // Offset into the expanded destination where the requested tile starts
  // (nonzero only in the unaligned case).
  OpFoldResult resultOffset;
  // Size of the (possibly over-allocated) destination buffer for this dim.
  OpFoldResult destExpandedSize;
};
| 1012 | |
| 1013 | /// Returns the needed information for tiling unpack op on `tileDim` with given |
| 1014 | /// `tileOffset` and `tileSize`. For more details, see the comment of the |
| 1015 | /// `getTiledImplementation`. |
| 1016 | static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp, |
| 1017 | int64_t tileDim, |
| 1018 | OpFoldResult tileOffset, |
| 1019 | OpFoldResult tileSize) { |
| 1020 | UnpackTileDimInfo info; |
| 1021 | Attribute zeroAttr = b.getIndexAttr(value: 0); |
| 1022 | Attribute oneAttr = b.getIndexAttr(value: 1); |
| 1023 | DenseMap<int64_t, OpFoldResult> dimAndTileMapping = |
| 1024 | unpackOp.getDimAndTileMapping(); |
| 1025 | // The dimension is not one of packed data dimension. |
| 1026 | if (!dimAndTileMapping.count(Val: tileDim)) { |
| 1027 | info.isAlignedToInnerTileSize = true; |
| 1028 | info.sourceOffset = tileOffset; |
| 1029 | info.sourceSize = tileSize; |
| 1030 | info.resultOffset = zeroAttr; |
| 1031 | info.destExpandedSize = tileSize; |
| 1032 | return info; |
| 1033 | } |
| 1034 | |
| 1035 | Location loc = unpackOp.getLoc(); |
| 1036 | using AV = affine::AffineValueExpr; |
| 1037 | affine::AffineBuilder ab(b, loc); |
| 1038 | AffineExpr dim0, dim1, sym0; |
| 1039 | bindDims(ctx: b.getContext(), exprs&: dim0, exprs&: dim1); |
| 1040 | bindSymbols(ctx: b.getContext(), exprs&: sym0); |
| 1041 | |
| 1042 | OpFoldResult innerTileSize = dimAndTileMapping[tileDim]; |
| 1043 | |
| 1044 | info.isAlignedToInnerTileSize = false; |
| 1045 | FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound( |
| 1046 | type: presburger::BoundType::UB, var: tileSize, |
| 1047 | /*stopCondition=*/nullptr, /*closedUB=*/true); |
| 1048 | std::optional<int64_t> cstInnerSize = getConstantIntValue(ofr: innerTileSize); |
| 1049 | if (!failed(Result: cstSize) && cstInnerSize) { |
| 1050 | if (*cstSize % *cstInnerSize == 0) |
| 1051 | info.isAlignedToInnerTileSize = true; |
| 1052 | |
| 1053 | // If the tiling size equals to the inner tiling size, the outer dims are |
| 1054 | // always 1. |
| 1055 | if (*cstInnerSize == *cstSize) { |
| 1056 | auto lhs = AV(dim0).bind(v: tileOffset); |
| 1057 | auto rhs = AV(dim1).bind(v: innerTileSize); |
| 1058 | info.sourceOffset = ab.floor(lhs, rhs); |
| 1059 | info.sourceSize = oneAttr; |
| 1060 | info.resultOffset = zeroAttr; |
| 1061 | info.destExpandedSize = tileSize; |
| 1062 | return info; |
| 1063 | } |
| 1064 | } |
| 1065 | |
| 1066 | if (info.isAlignedToInnerTileSize) { |
| 1067 | info.sourceOffset = |
| 1068 | ab.floor(lhs: AV(dim0).bind(v: tileOffset), rhs: AV(dim1).bind(v: innerTileSize)); |
| 1069 | info.resultOffset = zeroAttr; |
| 1070 | info.destExpandedSize = tileSize; |
| 1071 | |
| 1072 | // The ceilDiv is needed here because there could be incomplete tile even |
| 1073 | // it is perfect tiling cases. E.g., |
| 1074 | // %0 = unpack tensor<33x2xf32> into tensor<64xf32> |
| 1075 | // If the tiling size is 32, there will be 3 tiles. Two of them have |
| 1076 | // size=32; one of them have size=2. The size is represented using |
| 1077 | // affine_min op; we need ceilDiv. |
| 1078 | info.sourceSize = |
| 1079 | ab.ceil(lhs: AV(dim0).bind(v: tileSize), rhs: AV(dim1).bind(v: innerTileSize)); |
| 1080 | return info; |
| 1081 | } |
| 1082 | |
| 1083 | affine::DivModValue firstCoord = affine::getDivMod( |
| 1084 | b, loc, lhs: getValueOrCreateConstantIndexOp(b, loc, ofr: tileOffset), |
| 1085 | rhs: getValueOrCreateConstantIndexOp(b, loc, ofr: innerTileSize)); |
| 1086 | OpFoldResult tileExclusiveBound = |
| 1087 | ab.add(lhs: AV(dim0).bind(v: tileOffset), rhs: AV(dim1).bind(v: tileSize)); |
| 1088 | affine::DivModValue lastCoord = affine::getDivMod( |
| 1089 | b, loc, |
| 1090 | lhs: getValueOrCreateConstantIndexOp( |
| 1091 | b, loc, |
| 1092 | ofr: ab.sub(lhs: AV(dim0).bind(v: tileExclusiveBound), rhs: AV(dim1).bind(v: oneAttr))), |
| 1093 | rhs: getValueOrCreateConstantIndexOp(b, loc, ofr: innerTileSize)); |
| 1094 | |
| 1095 | OpFoldResult lengthMinusOne = ab.sub(lhs: AV(dim0).bind(v: lastCoord.quotient), |
| 1096 | rhs: AV(dim1).bind(v: firstCoord.quotient)); |
| 1097 | info.sourceSize = |
| 1098 | ab.add(lhs: AV(dim0).bind(v: lengthMinusOne), rhs: AV(dim1).bind(v: oneAttr)); |
| 1099 | info.sourceOffset = firstCoord.quotient; |
| 1100 | info.resultOffset = firstCoord.remainder; |
| 1101 | // Do not create an Affine ops for expanded size because the affine op is too |
| 1102 | // complicated which would trigger an issue in affine ops simplification. |
| 1103 | info.destExpandedSize = b.createOrFold<arith::MulIOp>( |
| 1104 | location: loc, args: getValueOrCreateConstantIndexOp(b, loc, ofr: info.sourceSize), |
| 1105 | args: getValueOrCreateConstantIndexOp(b, loc, ofr: innerTileSize)); |
| 1106 | return info; |
| 1107 | } |
| 1108 | |
| 1109 | struct UnPackOpTiling |
| 1110 | : public TilingInterface::ExternalModel<UnPackOpTiling, linalg::UnPackOp> { |
| 1111 | |
| 1112 | SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { |
| 1113 | auto unpackOp = cast<UnPackOp>(Val: op); |
| 1114 | SmallVector<utils::IteratorType> iteratorTypes( |
| 1115 | unpackOp.getDestRank(), utils::IteratorType::parallel); |
| 1116 | return iteratorTypes; |
| 1117 | } |
| 1118 | |
| 1119 | SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const { |
| 1120 | return getPackUnPackIterationDomain<UnPackOp>(op: cast<UnPackOp>(Val: op), builder&: b); |
| 1121 | } |
| 1122 | |
| 1123 | /// There are two cases in tiling unpack ops. If the tiling size is aligned to |
| 1124 | /// the inner tile size, the corresponding tiles of source are all complete. |
| 1125 | /// Otherwise, there are in-complete tiles. We will need to expand the slice |
| 1126 | /// of source for getting complete tiles. The tiled unpack op unpacks more |
| 1127 | /// data from source, so We'll need an extract_slice op to shift and truncate |
| 1128 | /// the output. |
| 1129 | /// Take Nn_to_N as an example. Say that N=32, n=8, and tiling_size=15. The |
| 1130 | /// coordinates of second tile (i.e., result[15..31]) are |
| 1131 | /// [(1, 7), (2, 0,), (2, 1) ... (3, 6), (3, 7)]. The first row and the last |
| 1132 | /// row are incomplete tiles. To represent the unpack op, we have to complete |
| 1133 | /// the rows. I.e., the input coordinates would start with (1, 0); end with |
| 1134 | /// (3, 7). In this context, the tiled unpack produces a (3 * n) elements |
| 1135 | /// because there are 3 rows in total. Follow by a tensor.extract_slice op, we |
| 1136 | /// can get the actual result. |
| 1137 | FailureOr<TilingResult> |
| 1138 | getTiledImplementation(Operation *op, OpBuilder &b, |
| 1139 | ArrayRef<OpFoldResult> offsets, |
| 1140 | ArrayRef<OpFoldResult> sizes) const { |
| 1141 | auto unpackOp = cast<UnPackOp>(Val: op); |
| 1142 | int64_t srcRank = unpackOp.getSourceRank(); |
| 1143 | int64_t destRank = unpackOp.getDestRank(); |
| 1144 | int64_t numInnerTiles = srcRank - destRank; |
| 1145 | Location loc = unpackOp.getLoc(); |
| 1146 | |
| 1147 | // The perfect tiling case indicates that the tiling sizes are multiple of |
| 1148 | // inner_tile_size. In this context, no extra data is needed when |
| 1149 | // representing the tiled unpack op. |
| 1150 | bool isPerfectTilingCase = true; |
| 1151 | Attribute oneAttr = b.getIndexAttr(value: 1); |
| 1152 | SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr); |
| 1153 | SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes; |
| 1154 | SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest; |
| 1155 | for (auto dim : llvm::seq<int64_t>(Begin: 0, End: destRank)) { |
| 1156 | UnpackTileDimInfo info = |
| 1157 | getUnpackTileDimInfo(b, unpackOp, tileDim: dim, tileOffset: offsets[dim], tileSize: sizes[dim]); |
| 1158 | if (!info.isAlignedToInnerTileSize) |
| 1159 | isPerfectTilingCase = false; |
| 1160 | sliceSrcIndices.push_back(Elt: info.sourceOffset); |
| 1161 | sliceSrcSizes.push_back(Elt: info.sourceSize); |
| 1162 | destExpandedSizes.push_back(Elt: info.destExpandedSize); |
| 1163 | resultOffsetsFromDest.push_back(Elt: info.resultOffset); |
| 1164 | } |
| 1165 | |
| 1166 | // The tiling is applied on destination dimensions. We have to apply the |
| 1167 | // interchange on source dimensions if outer_dims_perm is set. |
| 1168 | applyPermToRange(offsets&: sliceSrcIndices, sizes&: sliceSrcSizes, |
| 1169 | permutation: unpackOp.getOuterDimsPerm()); |
| 1170 | Attribute zeroAttr = b.getIndexAttr(value: 0); |
| 1171 | sliceSrcIndices.append(NumInputs: numInnerTiles, Elt: zeroAttr); |
| 1172 | sliceSrcSizes.append(RHS: unpackOp.getMixedTiles()); |
| 1173 | sliceSrcStrides.append(NumInputs: numInnerTiles, Elt: oneAttr); |
| 1174 | SmallVector<Operation *> generatedSlices; |
| 1175 | tensor::ExtractSliceOp sliceSource = b.create<tensor::ExtractSliceOp>( |
| 1176 | location: loc, args: unpackOp.getSource(), args&: sliceSrcIndices, args&: sliceSrcSizes, |
| 1177 | args&: sliceSrcStrides); |
| 1178 | generatedSlices.push_back(Elt: sliceSource); |
| 1179 | |
| 1180 | SmallVector<OpFoldResult> destStrides(destRank, oneAttr); |
| 1181 | Value sliceDest; |
| 1182 | if (isPerfectTilingCase) { |
| 1183 | auto destSliceOp = b.create<tensor::ExtractSliceOp>( |
| 1184 | location: loc, args: unpackOp.getDest(), args&: offsets, args&: sizes, args&: destStrides); |
| 1185 | sliceDest = destSliceOp; |
| 1186 | generatedSlices.push_back(Elt: destSliceOp); |
| 1187 | } else { |
| 1188 | sliceDest = b.create<tensor::EmptyOp>( |
| 1189 | location: loc, args&: destExpandedSizes, args: unpackOp.getDestType().getElementType()); |
| 1190 | } |
| 1191 | |
| 1192 | SmallVector<Value> tiledOperands = {sliceSource.getResult(), sliceDest}; |
| 1193 | for (auto tile : unpackOp.getInnerTiles()) |
| 1194 | tiledOperands.push_back(Elt: tile); |
| 1195 | |
| 1196 | Operation *tiledUnpackOp = b.create<UnPackOp>( |
| 1197 | location: loc, args: TypeRange{sliceDest.getType()}, args&: tiledOperands, args: op->getAttrs()); |
| 1198 | |
| 1199 | if (isPerfectTilingCase) |
| 1200 | return TilingResult{.tiledOps: {tiledUnpackOp}, |
| 1201 | .tiledValues: SmallVector<Value>(tiledUnpackOp->getResults()), |
| 1202 | .generatedSlices: generatedSlices}; |
| 1203 | |
| 1204 | auto = b.create<tensor::ExtractSliceOp>( |
| 1205 | location: loc, args: tiledUnpackOp->getResult(idx: 0), args&: resultOffsetsFromDest, args&: sizes, |
| 1206 | args&: destStrides); |
| 1207 | return TilingResult{ |
| 1208 | .tiledOps: {tiledUnpackOp}, .tiledValues: {extractSlice.getResult()}, .generatedSlices: generatedSlices}; |
| 1209 | } |
| 1210 | |
| 1211 | LogicalResult |
| 1212 | getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, |
| 1213 | ArrayRef<OpFoldResult> offsets, |
| 1214 | ArrayRef<OpFoldResult> sizes, |
| 1215 | SmallVector<OpFoldResult> &resultOffsets, |
| 1216 | SmallVector<OpFoldResult> &resultSizes) const { |
| 1217 | resultOffsets = llvm::to_vector(Range&: offsets); |
| 1218 | resultSizes = llvm::to_vector(Range&: sizes); |
| 1219 | return success(); |
| 1220 | } |
| 1221 | |
| 1222 | FailureOr<TilingResult> |
| 1223 | generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, |
| 1224 | ArrayRef<OpFoldResult> offsets, |
| 1225 | ArrayRef<OpFoldResult> sizes) const { |
| 1226 | FailureOr<TilingResult> tilingResult = |
| 1227 | getTiledImplementation(op, b, offsets, sizes); |
| 1228 | if (failed(Result: tilingResult)) |
| 1229 | return failure(); |
| 1230 | return tilingResult.value(); |
| 1231 | } |
| 1232 | |
| 1233 | /// Method to return the position of iteration domain tile computed by the |
| 1234 | /// tiled operation. |
| 1235 | LogicalResult getIterationDomainTileFromOperandTiles( |
| 1236 | Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers, |
| 1237 | ArrayRef<SmallVector<OpFoldResult>> allOffsets, |
| 1238 | ArrayRef<SmallVector<OpFoldResult>> allSizes, |
| 1239 | SmallVectorImpl<OpFoldResult> &resultOffsets, |
| 1240 | SmallVectorImpl<OpFoldResult> &resultSizes) const { |
| 1241 | if (operandNumbers.size() != 1) { |
| 1242 | LLVM_DEBUG({ llvm::dbgs() << "unable to handle multiple operands" ; }); |
| 1243 | return failure(); |
| 1244 | } |
| 1245 | auto unPackOp = cast<UnPackOp>(Val: op); |
| 1246 | unsigned operandNumber = operandNumbers[0]; |
| 1247 | ArrayRef<OpFoldResult> offsets(allOffsets[0]); |
| 1248 | ArrayRef<OpFoldResult> sizes(allSizes[0]); |
| 1249 | |
| 1250 | // If the operand tile is the dest, then no adjustment is needed. |
| 1251 | if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) { |
| 1252 | resultOffsets = llvm::to_vector(Range&: offsets); |
| 1253 | resultSizes = llvm::to_vector(Range&: sizes); |
| 1254 | return success(); |
| 1255 | } |
| 1256 | Location loc = unPackOp.getLoc(); |
| 1257 | |
| 1258 | int64_t numTiles = unPackOp.getInnerDimsPos().size(); |
| 1259 | auto destOffsets = offsets.drop_back(N: numTiles); |
| 1260 | auto destSizes = sizes.drop_back(N: numTiles); |
| 1261 | // The tiling is applied on interchanged dimensions. We have to undo the |
| 1262 | // interchange to map sizes and offsets to the original input. |
| 1263 | int64_t outputRank = unPackOp.getDestRank(); |
| 1264 | ReifiedRankedShapedTypeDims reifiedReturnShapes; |
| 1265 | if (failed(Result: reifyResultShapes(b, op: unPackOp, reifiedReturnShapes))) |
| 1266 | return failure(); |
| 1267 | SmallVector<OpFoldResult> outputMixedSizes = reifiedReturnShapes.front(); |
| 1268 | SmallVector<OpFoldResult> origOffsets(destOffsets); |
| 1269 | SmallVector<OpFoldResult> origSizes(destSizes); |
| 1270 | applyPermToRange(offsets&: origOffsets, sizes&: origSizes, |
| 1271 | permutation: invertPermutationVector(permutation: unPackOp.getOuterDimsPerm())); |
| 1272 | |
| 1273 | DenseMap<int64_t, OpFoldResult> dimAndTileMapping = |
| 1274 | unPackOp.getDimAndTileMapping(); |
| 1275 | |
| 1276 | for (auto dim : llvm::seq<int64_t>(Begin: 0, End: outputRank)) { |
| 1277 | using AV = affine::AffineValueExpr; |
| 1278 | affine::AffineBuilder ab(b, loc); |
| 1279 | AffineExpr dim0, dim1, sym0; |
| 1280 | bindDims(ctx: b.getContext(), exprs&: dim0, exprs&: dim1); |
| 1281 | bindSymbols(ctx: b.getContext(), exprs&: sym0); |
| 1282 | if (dimAndTileMapping.count(Val: dim)) { |
| 1283 | // If the data dimension is tiled, the i-th index is the product of |
| 1284 | // offset_i and tile_i, and the i-th size is the product of sizes_i and |
| 1285 | // tile_i. The sizes must be clamped to the sizes of the unpack result. |
| 1286 | auto avOffset = AV(dim0).bind(v: origOffsets[dim]); |
| 1287 | auto avSize = AV(dim0).bind(v: origSizes[dim]); |
| 1288 | auto avTileSize = AV(sym0).bind(v: dimAndTileMapping[dim]); |
| 1289 | auto avResultSize = AV(dim0).bind(v: outputMixedSizes[dim]); |
| 1290 | resultOffsets.push_back(Elt: ab.mul(lhs: avOffset, rhs: avTileSize)); |
| 1291 | auto avResultOffset = AV(dim1).bind(v: resultOffsets.back()); |
| 1292 | resultSizes.push_back(Elt: ab.min(vals: {ab.mul(lhs: avSize, rhs: avTileSize), |
| 1293 | ab.sub(lhs: avResultSize, rhs: avResultOffset)})); |
| 1294 | } else { |
| 1295 | resultOffsets.push_back(Elt: origOffsets[dim]); |
| 1296 | resultSizes.push_back(Elt: origSizes[dim]); |
| 1297 | } |
| 1298 | } |
| 1299 | return success(); |
| 1300 | } |
| 1301 | |
| 1302 | /// Method to return the tiled implementation of tensor.unpack as a consumer. |
| 1303 | FailureOr<TilingResult> getTiledImplementationFromOperandTiles( |
| 1304 | Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers, |
| 1305 | ArrayRef<SmallVector<OpFoldResult>> allOffsets, |
| 1306 | ArrayRef<SmallVector<OpFoldResult>> allSizes) const { |
| 1307 | if (operandNumbers.size() != 1 || operandNumbers[0] != 0) { |
| 1308 | LLVM_DEBUG({ llvm::dbgs() << "unhandled operands for consumer fusion" ; }); |
| 1309 | return failure(); |
| 1310 | } |
| 1311 | auto unPackOp = cast<UnPackOp>(Val: op); |
| 1312 | ArrayRef<OpFoldResult> offsets(allOffsets[0]); |
| 1313 | ArrayRef<OpFoldResult> sizes(allSizes[0]); |
| 1314 | |
| 1315 | // tensor.unpack op is fusible (as a consumer) only if inner dims are not |
| 1316 | // tiled. |
| 1317 | int64_t numTiles = unPackOp.getInnerDimsPos().size(); |
| 1318 | for (auto iter : |
| 1319 | llvm::zip_equal(t: unPackOp.getMixedTiles(), u: sizes.take_back(N: numTiles))) { |
| 1320 | if (!isEqualConstantIntOrValue(ofr1: std::get<0>(t&: iter), ofr2: std::get<1>(t&: iter))) |
| 1321 | return failure(); |
| 1322 | } |
| 1323 | |
| 1324 | Location loc = unPackOp.getLoc(); |
| 1325 | |
| 1326 | // Fetch offset/size for creating the slice of the dest operand of |
| 1327 | // unpack op. |
| 1328 | SmallVector<OpFoldResult> outputOffsets, outputSizes; |
| 1329 | if (failed(Result: getIterationDomainTileFromOperandTiles( |
| 1330 | op, b, operandNumbers, allOffsets, allSizes, resultOffsets&: outputOffsets, |
| 1331 | resultSizes&: outputSizes))) |
| 1332 | return failure(); |
| 1333 | |
| 1334 | auto oneAttr = b.getI64IntegerAttr(value: 1); |
| 1335 | int64_t outputRank = unPackOp.getDestRank(); |
| 1336 | SmallVector<OpFoldResult> strides(outputRank, oneAttr); |
| 1337 | |
| 1338 | SmallVector<Value> tiledOperands; |
| 1339 | // Create slice of the dest operand. |
| 1340 | auto = b.create<tensor::ExtractSliceOp>( |
| 1341 | location: loc, args: unPackOp.getDest(), args&: outputOffsets, args&: outputSizes, args&: strides); |
| 1342 | tiledOperands.push_back(Elt: extractDestSlice); |
| 1343 | |
| 1344 | strides.append(NumInputs: unPackOp.getSourceRank() - outputRank, Elt: oneAttr); |
| 1345 | // Create slice of the source operand. |
| 1346 | auto = b.create<tensor::ExtractSliceOp>( |
| 1347 | location: loc, args: unPackOp.getSource(), args&: offsets, args&: sizes, args&: strides); |
| 1348 | tiledOperands.insert(I: tiledOperands.begin(), Elt: extractSourceSlice); |
| 1349 | for (auto tile : unPackOp.getInnerTiles()) |
| 1350 | tiledOperands.push_back(Elt: tile); |
| 1351 | |
| 1352 | // Create tiled unpack op. |
| 1353 | Operation *tiledUnPackOp = |
| 1354 | b.create<UnPackOp>(location: loc, args: TypeRange{extractDestSlice.getType()}, |
| 1355 | args&: tiledOperands, args: op->getAttrs()); |
| 1356 | |
| 1357 | return TilingResult{.tiledOps: {tiledUnPackOp}, |
| 1358 | .tiledValues: SmallVector<Value>(tiledUnPackOp->getResults()), |
| 1359 | .generatedSlices: llvm::to_vector(Range: ArrayRef<Operation *>{ |
| 1360 | extractSourceSlice, extractDestSlice})}; |
| 1361 | } |
| 1362 | }; |
| 1363 | |
| 1364 | } // namespace |
| 1365 | |
| 1366 | template <typename OpType> |
| 1367 | static void registerOne(MLIRContext *ctx) { |
| 1368 | OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx); |
| 1369 | OpType::template attachInterface<LinalgOpPartialReductionInterface<OpType>>( |
| 1370 | *ctx); |
| 1371 | } |
| 1372 | |
| 1373 | /// Variadic helper function. |
| 1374 | template <typename... OpTypes> |
| 1375 | static void registerAll(MLIRContext *ctx) { |
| 1376 | (registerOne<OpTypes>(ctx), ...); |
| 1377 | } |
| 1378 | |
| 1379 | #define GET_OP_LIST |
| 1380 | |
| 1381 | void mlir::linalg::registerTilingInterfaceExternalModels( |
| 1382 | DialectRegistry ®istry) { |
| 1383 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, linalg::LinalgDialect *dialect) { |
| 1384 | registerOne<linalg::GenericOp>(ctx); |
| 1385 | linalg::PackOp::attachInterface<PackOpTiling>(context&: *ctx); |
| 1386 | linalg::UnPackOp::attachInterface<UnPackOpTiling>(context&: *ctx); |
| 1387 | registerAll< |
| 1388 | #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" |
| 1389 | >(ctx); |
| 1390 | }); |
| 1391 | } |
| 1392 | |
| 1393 | void mlir::linalg::registerTilingInterfaceExternalModelsForPackUnPackOps( |
| 1394 | DialectRegistry ®istry) { |
| 1395 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, LinalgDialect *dialect) { |
| 1396 | linalg::PackOp::attachInterface<PackOpTiling>(context&: *ctx); |
| 1397 | linalg::UnPackOp::attachInterface<UnPackOpTiling>(context&: *ctx); |
| 1398 | }); |
| 1399 | } |
| 1400 | |