| 1 | //===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" |
| 10 | |
| 11 | #include "mlir/Analysis/SliceAnalysis.h" |
| 12 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 13 | #include "mlir/Dialect/Affine/Utils.h" |
| 14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 15 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
| 16 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 17 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
| 18 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 19 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 20 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
| 21 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
| 22 | #include "mlir/Interfaces/TilingInterface.h" |
| 23 | #include "mlir/Interfaces/ValueBoundsOpInterface.h" |
| 24 | #include <optional> |
| 25 | |
| 26 | using namespace mlir; |
| 27 | using namespace mlir::linalg; |
| 28 | |
| 29 | //===----------------------------------------------------------------------===// |
| 30 | // Utility methods for implementation of Tiling Interface for Linalg ops |
| 31 | //===----------------------------------------------------------------------===// |
| 32 | |
| 33 | /// Return the SSA values that represent the data point accessed using a given |
| 34 | /// `indexingMap` for a given point in the iteration space represented by `ivs`. |
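/// For example (illustrative), with `indexingMap = (d0, d1) -> (d1, d0)` and
/// `ivs = [%i, %j]`, this produces two `affine.apply` results equivalent to
/// the index list `[%j, %i]`.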
| 35 | static SmallVector<Value> getIndicesForAccess(OpBuilder &b, Location loc, |
| 36 | AffineMap indexingMap, |
| 37 | ValueRange ivs) { |
| 38 | SmallVector<Value> indices; |
  indices.reserve(indexingMap.getNumResults());
| 40 | for (auto result : indexingMap.getResults()) { |
    AffineMap m = AffineMap::get(indexingMap.getNumDims(),
                                 indexingMap.getNumSymbols(), result);
    Value v = b.create<affine::AffineApplyOp>(loc, m, ivs);
    indices.push_back(v);
| 45 | } |
| 46 | return indices; |
| 47 | } |
| 48 | |
| 49 | /// Method to inline the payload of a `linalgOp` given the iteration space |
| 50 | /// point and values for the arguments of the payload. |
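/// For instance (illustrative), for a `linalg.generic` whose body is
/// `%0 = arith.addf %in, %out; linalg.yield %0`, this clones the `arith.addf`
/// with the mapped arguments and turns the yield into a `memref.store` into
/// the corresponding init operand, at indices given by its indexing map.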
| 51 | static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp, |
| 52 | ValueRange ivs, ValueRange argValues) { |
| 53 | Block *body = linalgOp.getBlock(); |
| 54 | IRMapping map; |
  map.map(body->getArguments(), argValues);
| 56 | for (auto &op : body->without_terminator()) { |
| 57 | if (auto indexOp = dyn_cast<IndexOp>(&op)) { |
| 58 | map.map(indexOp.getResult(), ivs[indexOp.getDim()]); |
| 59 | continue; |
| 60 | } |
| 61 | b.clone(op, map); |
| 62 | } |
| 63 | |
| 64 | Operation *terminator = body->getTerminator(); |
| 65 | Location loc = terminator->getLoc(); |
| 66 | for (const auto &operand : llvm::enumerate(terminator->getOperands())) { |
| 67 | Value toStore = map.lookupOrDefault(operand.value()); |
| 68 | OpOperand *storeInto = linalgOp.getDpsInitOperand(operand.index()); |
| 69 | auto indices = getIndicesForAccess( |
| 70 | b, loc, linalgOp.getMatchingIndexingMap(storeInto), ivs); |
| 71 | b.create<memref::StoreOp>( |
| 72 | loc, toStore, linalgOp.getDpsInitOperand(operand.index())->get(), |
| 73 | indices); |
| 74 | } |
| 75 | return success(); |
| 76 | } |
| 77 | |
| 78 | //===----------------------------------------------------------------------===// |
| 79 | // External Model for implementing `TilingInterface` for `LinalgOp`s. |
| 80 | //===----------------------------------------------------------------------===// |
| 81 | |
| 82 | namespace { |
| 83 | /// External model implementation of TilingInterface for LinalgOps. An external |
/// model implementation is used for now till the use of `TilingInterface` is
/// on par with the current Linalg tiling + fusion patterns. Once it is, it may
/// be possible to move this into the op definition (though there are
/// advantages to leaving it as an external model).
| 88 | template <typename LinalgOpTy> |
| 89 | struct LinalgOpTilingInterface |
| 90 | : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>, |
| 91 | LinalgOpTy> { |
| 92 | /// Return the loop iterator type. |
| 93 | SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { |
| 94 | LinalgOpTy concreteOp = cast<LinalgOpTy>(op); |
| 95 | return concreteOp.getIteratorTypesArray(); |
| 96 | } |
| 97 | |
| 98 | /// Return the iteration domain range. |
| 99 | SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const { |
| 100 | OpBuilder::InsertionGuard g(b); |
| 101 | b.setInsertionPoint(op); |
| 102 | Location loc = op->getLoc(); |
| 103 | LinalgOp linalgOp = cast<LinalgOp>(op); |
| 104 | SmallVector<OpFoldResult> allShapesSizes = |
| 105 | linalgOp.createFlatListOfOperandDims(b, loc); |
| 106 | AffineMap map = linalgOp.getShapesToLoopsMap(); |
| 107 | |
| 108 | return llvm::to_vector( |
| 109 | llvm::map_range(map.getResults(), [&](AffineExpr loopExpr) { |
| 110 | OpFoldResult ofr = affine::makeComposedFoldedAffineApply( |
              b, loc, loopExpr, allShapesSizes);
          return Range{b.getIndexAttr(0), ofr, b.getIndexAttr(1)};
| 113 | })); |
| 114 | } |
| 115 | |
| 116 | /// Instantiate the tiled implementation of the operation. |
| 117 | FailureOr<TilingResult> |
| 118 | getTiledImplementation(Operation *op, OpBuilder &b, |
| 119 | ArrayRef<OpFoldResult> offsets, |
| 120 | ArrayRef<OpFoldResult> sizes) const { |
| 121 | // Leave the `sizeBounds` value empty. That is only needed when the `sizes` |
| 122 | // specified could lead to out of bounds accesses. |
| 123 | Location loc = op->getLoc(); |
| 124 | LinalgOp linalgOp = cast<LinalgOp>(op); |
| 125 | SmallVector<Value> valuesToTile = linalgOp->getOperands(); |
| 126 | SmallVector<Value> tiledOperands = makeTiledShapes( |
| 127 | b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true); |
| 128 | SmallVector<Operation *> generatedSlices = llvm::map_to_vector( |
| 129 | llvm::make_filter_range( |
| 130 | tiledOperands, |
| 131 | [](Value v) -> bool { |
| 132 | return isa_and_nonnull<tensor::ExtractSliceOp, memref::SubViewOp>( |
| 133 | v.getDefiningOp()); |
| 134 | }), |
| 135 | [](Value v) -> Operation * { return v.getDefiningOp(); }); |
| 136 | |
| 137 | SmallVector<Type> resultTensorTypes = |
| 138 | getTensorOutputTypes(linalgOp, tiledOperands); |
| 139 | |
| 140 | Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands); |
| 141 | offsetIndices(b, cast<LinalgOp>(tiledOp), offsets); |
| 142 | |
| 143 | return TilingResult{ |
        {tiledOp}, SmallVector<Value>(tiledOp->getResults()), generatedSlices};
| 145 | } |
| 146 | |
| 147 | /// Utility to fetch the offsets and sizes when applied as per the indexing |
| 148 | /// map of the linalg op. This helps in fusing the linalg op as a consumer of |
| 149 | /// a given slice op. |
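  /// For example (illustrative), with `indexingMap = (d0, d1, d2) -> (d2, d0)`
  /// and slice `offsets = [%a, %b]`, `sizes = [%c, %d]`, this produces
  /// `mappedOffsets = [%b, full, %a]` and `mappedSizes = [%d, full, %c]`,
  /// where `full` is taken from the op's iteration domain for the unused
  /// dimension d1.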
| 150 | void |
| 151 | getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap, |
| 152 | ArrayRef<OpFoldResult> offsets, |
| 153 | ArrayRef<OpFoldResult> sizes, |
| 154 | SmallVectorImpl<OpFoldResult> &mappedOffsets, |
| 155 | SmallVectorImpl<OpFoldResult> &mappedSizes) const { |
| 156 | unsigned numLoops = linalgOp.getNumLoops(); |
| 157 | auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation()); |
    mappedOffsets.resize(numLoops);
    mappedSizes.resize(numLoops);
| 160 | if (!indexingMap.isPermutation()) { |
| 161 | SmallVector<Range> iterationDomain = |
| 162 | tilingInterfaceOp.getIterationDomain(b); |
| 163 | for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) { |
| 164 | mappedOffsets[index] = value.offset; |
| 165 | mappedSizes[index] = value.size; |
| 166 | } |
| 167 | } |
| 168 | for (const auto &&[index, value] : |
         llvm::enumerate(indexingMap.getResults())) {
      unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
| 171 | mappedOffsets[dimPosition] = offsets[index]; |
| 172 | mappedSizes[dimPosition] = sizes[index]; |
| 173 | } |
| 174 | } |
| 175 | |
  /// Method to return the tile of the iteration domain that corresponds to a
  /// given tile of an operand.
| 178 | LogicalResult getIterationDomainTileFromOperandTile( |
| 179 | Operation *op, OpBuilder &b, unsigned operandNumber, |
| 180 | ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, |
| 181 | SmallVectorImpl<OpFoldResult> &iterDomainOffsets, |
| 182 | SmallVectorImpl<OpFoldResult> &iterDomainSizes) const { |
| 183 | auto linalgOp = cast<LinalgOp>(op); |
| 184 | |
| 185 | // Check that the indexing map used for the operand is a projected |
| 186 | // permutation. This could be relaxed with a more general approach that can |
| 187 | // map the offsets and sizes from the operand to iteration space tiles |
| 188 | // (filling in full extent for dimensions not used to access the result). |
| 189 | AffineMap indexingMap = |
        linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
| 191 | if (!indexingMap.isProjectedPermutation()) { |
| 192 | return op->emitError() |
| 193 | << "unhandled get iter domain position when operand is not " |
| 194 | "accessed using a permuted projection" ; |
| 195 | } |
| 196 | |
    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
                           iterDomainOffsets, iterDomainSizes);
| 199 | return success(); |
| 200 | } |
| 201 | |
| 202 | /// Return the details of the output tile generated by the tiled |
| 203 | /// implementation. |
| 204 | LogicalResult |
| 205 | getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, |
| 206 | ArrayRef<OpFoldResult> offsets, |
| 207 | ArrayRef<OpFoldResult> sizes, |
| 208 | SmallVector<OpFoldResult> &resultOffsets, |
| 209 | SmallVector<OpFoldResult> &resultSizes) const { |
| 210 | Location loc = op->getLoc(); |
| 211 | LinalgOp linalgOp = cast<LinalgOp>(op); |
| 212 | |
| 213 | AffineExpr d0; |
    bindDims(b.getContext(), d0);
| 215 | SmallVector<OpFoldResult> subShapeSizes = |
| 216 | llvm::to_vector(llvm::map_range(sizes, [&](OpFoldResult ofr) { |
          return affine::makeComposedFoldedAffineApply(b, loc, d0 - 1, ofr);
| 218 | })); |
| 219 | |
| 220 | OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber); |
| 221 | SliceParameters sliceParams = computeSliceParameters( |
| 222 | b, loc, outOperand->get(), sizes, |
| 223 | linalgOp.getMatchingIndexingMap(outOperand), offsets, |
| 224 | /*ubs*/ {}, subShapeSizes, true); |
| 225 | resultOffsets = sliceParams.offsets; |
| 226 | resultSizes = sliceParams.sizes; |
| 227 | return success(); |
| 228 | } |
| 229 | |
| 230 | LogicalResult getIterationDomainTileFromResultTile( |
| 231 | Operation *op, OpBuilder &b, unsigned resultNumber, |
| 232 | ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, |
| 233 | SmallVectorImpl<OpFoldResult> &iterDomainOffsets, |
| 234 | SmallVectorImpl<OpFoldResult> &iterDomainSizes) const { |
| 235 | auto linalgOp = cast<LinalgOp>(op); |
| 236 | |
| 237 | // Check that the indexing map used for the output is a projected |
| 238 | // permutation. This could be relaxed with a more general approach that can |
| 239 | // map the offsets and sizes from the result to iteration space tiles |
| 240 | // (filling in full extent for dimensions not used to access the result). |
| 241 | AffineMap indexingMap = |
        linalgOp.getIndexingMapMatchingResult(op->getResult(resultNumber));
| 243 | if (!indexingMap.isProjectedPermutation()) { |
      return op->emitOpError(
          "unhandled tiled implementation generation when result is not "
          "accessed using a permuted projection");
| 247 | } |
| 248 | |
    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
                           iterDomainOffsets, iterDomainSizes);
| 251 | return success(); |
| 252 | } |
| 253 | |
| 254 | FailureOr<TilingResult> |
| 255 | generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, |
| 256 | ArrayRef<OpFoldResult> offsets, |
| 257 | ArrayRef<OpFoldResult> sizes) const { |
| 258 | SmallVector<OpFoldResult> mappedOffsets, mappedSizes; |
| 259 | if (failed(getIterationDomainTileFromResultTile( |
            op, b, resultNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
| 261 | return failure(); |
| 262 | } |
| 263 | auto tilingInterfaceOp = cast<TilingInterface>(op); |
| 264 | FailureOr<TilingResult> tilingResult = |
| 265 | tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes); |
| 266 | |
    if (failed(tilingResult))
| 268 | return failure(); |
| 269 | |
| 270 | if (tilingResult->tiledOps.size() != 1) |
      return op->emitOpError("failed to generate tiled implementation");
| 272 | |
| 273 | return TilingResult{ |
        tilingResult->tiledOps,
        SmallVector<Value>{tilingResult->tiledValues[resultNumber]},
        tilingResult->generatedSlices};
| 277 | } |
| 278 | |
| 279 | /// Method to generate the tiled implementation of an operation from the tile |
| 280 | /// of the operand. |
| 281 | FailureOr<TilingResult> getTiledImplementationFromOperandTile( |
| 282 | Operation *op, OpBuilder &b, unsigned operandNumber, |
| 283 | ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const { |
| 284 | SmallVector<OpFoldResult> mappedOffsets, mappedSizes; |
| 285 | if (failed(getIterationDomainTileFromOperandTile( |
            op, b, operandNumber, offsets, sizes, mappedOffsets,
            mappedSizes))) {
| 288 | return failure(); |
| 289 | } |
    return getTiledImplementation(op, b, mappedOffsets, mappedSizes);
| 291 | } |
| 292 | |
| 293 | LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder, |
| 294 | Location loc, |
| 295 | ValueRange ivs) const { |
| 296 | auto linalgOp = cast<LinalgOp>(op); |
| 297 | if (!linalgOp.hasPureBufferSemantics()) |
      return op->emitOpError("expected operation to have buffer semantics");
| 299 | |
| 300 | SmallVector<Value> indexedValues; |
    indexedValues.reserve(linalgOp->getNumOperands());
| 302 | Location linalgOpLoc = op->getLoc(); |
| 303 | /// Load the data corresponding to the block arguments that |
| 304 | /// represent input operands. |
| 305 | for (OpOperand &operand : linalgOp->getOpOperands()) { |
| 306 | if (!linalgOp.payloadUsesValueFromOperand(&operand)) { |
| 307 | indexedValues.push_back(nullptr); |
| 308 | continue; |
| 309 | } |
| 310 | if (linalgOp.isScalar(&operand)) { |
| 311 | indexedValues.push_back(operand.get()); |
| 312 | continue; |
| 313 | } |
| 314 | SmallVector<Value> indices = getIndicesForAccess( |
| 315 | builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(&operand), ivs); |
| 316 | Value load = |
| 317 | builder.create<memref::LoadOp>(linalgOpLoc, operand.get(), indices); |
| 318 | indexedValues.push_back(load); |
| 319 | } |
| 320 | |
| 321 | /// Inline the op payload and store the result. |
| 322 | return inlinePayload(builder, linalgOp, ivs, indexedValues); |
| 323 | } |
| 324 | }; |
| 325 | |
| 326 | //===----------------------------------------------------------------------===// |
| 327 | // External Model for implementing `PartialReductionInterface` for `LinalgOp`s. |
| 328 | //===----------------------------------------------------------------------===// |
| 329 | |
| 330 | /// Return an AffineMap for a partial result for the given result number, |
| 331 | /// assuming the partial tiling strategy is outer-reduction loop + |
| 332 | /// inner-parallel tile. The returned AffineMap can be used as the replacement |
| 333 | /// AffineMap for the inner-parallel tile linalg op for the given result number. |
| 334 | /// |
| 335 | /// The new AffineMap is the old AffineMap with reduction dimensions appended |
/// at the end.
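/// For example (illustrative), for a matmul-like op whose init map for result
/// 0 is `(d0, d1, d2) -> (d0, d1)` and `reductionDims = {2}`, the returned
/// partial-result map is `(d0, d1, d2) -> (d0, d1, d2)`.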
| 337 | static AffineMap getPartialResultAffineMap(LinalgOp linalgOp, |
| 338 | ArrayRef<int> reductionDims, |
| 339 | unsigned resultNumber) { |
| 340 | AffineMap map = |
| 341 | linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber)); |
| 342 | for (int redPos : reductionDims) { |
    map = map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
                           map.getNumResults());
| 345 | } |
| 346 | return map; |
| 347 | } |
| 348 | |
| 349 | /// External model implementation of PartialReductionInterface for |
| 350 | /// LinalgOps. |
| 351 | template <typename LinalgOpTy> |
| 352 | struct LinalgOpPartialReductionInterface |
| 353 | : public PartialReductionOpInterface::ExternalModel< |
| 354 | LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> { |
| 355 | FailureOr<SmallVector<Value>> generateInitialTensorForPartialReduction( |
| 356 | Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes, |
| 357 | ArrayRef<int> reductionDims) const { |
| 358 | auto linalgOp = cast<LinalgOp>(op); |
| 359 | OpBuilder::InsertionGuard guard(b); |
| 360 | |
| 361 | if (linalgOp.hasPureBufferSemantics()) |
      return op->emitOpError("expected operation to have tensor semantics");
| 363 | |
| 364 | // LinalgOp implements TilingInterface. |
| 365 | auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation()); |
| 366 | SmallVector<OpFoldResult> shape = |
| 367 | llvm::map_to_vector(tilingInterfaceOp.getIterationDomain(b), |
| 368 | [](Range x) { return x.size; }); |
| 369 | |
| 370 | SmallVector<OpFoldResult> tiledShape; |
| 371 | for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) { |
| 372 | if (isZeroInteger(tileSize)) { |
| 373 | tiledShape.push_back(dimSize); |
| 374 | } else { |
| 375 | tiledShape.push_back(tileSize); |
| 376 | } |
| 377 | } |
| 378 | |
| 379 | SmallVector<Value> inits; |
| 380 | for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e; |
| 381 | ++initIdx) { |
| 382 | SmallVector<Operation *, 4> combinerOps; |
| 383 | if (!matchReduction(linalgOp.getRegionOutputArgs(), initIdx, |
| 384 | combinerOps) || |
| 385 | combinerOps.size() != 1) |
| 386 | return op->emitOpError(message: "Failed to anaysis the reduction operation." ); |
| 387 | |
| 388 | Operation *reductionOp = combinerOps[0]; |
| 389 | std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp); |
| 390 | if (!identity.has_value()) |
        return op->emitOpError(
            "Failed to get an identity value for the reduction operation.");
| 393 | |
| 394 | // Append the new partial result dimensions. |
| 395 | AffineMap partialMap = |
| 396 | getPartialResultAffineMap(linalgOp, reductionDims, initIdx); |
| 397 | SmallVector<OpFoldResult> partialResultShape; |
| 398 | for (AffineExpr dimExpr : partialMap.getResults()) { |
| 399 | auto dim = cast<AffineDimExpr>(dimExpr); |
| 400 | partialResultShape.push_back(tiledShape[dim.getPosition()]); |
| 401 | } |
| 402 | |
| 403 | Type elType = |
| 404 | getElementTypeOrSelf(linalgOp->getResult(initIdx).getType()); |
| 405 | Value emptyTensor = |
| 406 | b.create<tensor::EmptyOp>(loc, partialResultShape, elType); |
| 407 | Value constantOp = b.create<arith::ConstantOp>(loc, *identity); |
| 408 | auto identityTensor = |
| 409 | b.create<linalg::FillOp>(loc, constantOp, emptyTensor); |
      inits.push_back(identityTensor.getResult(0));
| 411 | } |
| 412 | |
| 413 | return inits; |
| 414 | } |
| 415 | |
| 416 | FailureOr<TilingResult> |
| 417 | tileToPartialReduction(Operation *op, OpBuilder &b, Location loc, |
| 418 | ValueRange init, ArrayRef<OpFoldResult> offsets, |
| 419 | ArrayRef<OpFoldResult> sizes, |
| 420 | ArrayRef<int> reductionDims) const { |
| 421 | OpBuilder::InsertionGuard guard(b); |
| 422 | auto linalgOp = cast<LinalgOp>(op); |
| 423 | |
| 424 | // Step 1. Extend init maps to have reduction dimension dims, since we |
| 425 | // are converting them to parallel dimensions. |
| 426 | SmallVector<AffineMap> newInitMaps; |
    newInitMaps.reserve(linalgOp.getNumDpsInits());
| 428 | for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) { |
| 429 | // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace |
| 430 | // this with a for range loop when we have it. |
| 431 | AffineMap newMap = |
| 432 | getPartialResultAffineMap(linalgOp, reductionDims, idx); |
| 433 | newInitMaps.push_back(newMap); |
| 434 | } |
| 435 | |
| 436 | // Step 2a: Extract a slice of the input operands. |
| 437 | SmallVector<Value> tiledInputs = makeTiledShapes( |
| 438 | b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true); |
| 439 | SmallVector<Operation *> generatedSlices = llvm::map_to_vector( |
| 440 | llvm::make_filter_range( |
| 441 | tiledInputs, [](Value v) -> bool { return v.getDefiningOp(); }), |
| 442 | [](Value v) -> Operation * { return v.getDefiningOp(); }); |
| 443 | |
| 444 | // Step 2b: Extract a slice of the init operands. |
| 445 | SmallVector<Value, 1> tiledInits; |
    for (auto [valueMap, valueToTile] : llvm::zip_equal(newInitMaps, init)) {
| 447 | int64_t initRank = valueMap.getNumResults(); |
| 448 | SmallVector<OpFoldResult> initOffset(initRank, b.getIndexAttr(0)); |
| 449 | SmallVector<OpFoldResult> initStride(initRank, b.getIndexAttr(1)); |
| 450 | SmallVector<OpFoldResult> initSizes; |
| 451 | for (AffineExpr dimExpr : valueMap.getResults()) { |
        auto dim = cast<AffineDimExpr>(dimExpr);
        initSizes.push_back(sizes[dim.getPosition()]);
| 454 | } |
| 455 | // TODO: Use SubsetExtractOpInterface here once available. |
      auto extractSlice = b.create<tensor::ExtractSliceOp>(
          loc, valueToTile, initOffset, initSizes, initStride);
      tiledInits.push_back(extractSlice);
      generatedSlices.push_back(extractSlice);
| 460 | } |
| 461 | |
| 462 | // Update the indexing maps. |
| 463 | SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray(); |
| 464 | // Change the init maps. |
| 465 | for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) { |
| 466 | // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace |
| 467 | // this with a for range loop when we have it. |
| 468 | OpOperand *initOperand = linalgOp.getDpsInitOperand(idx); |
| 469 | int64_t mapIdx = linalgOp.getIndexingMapIndex(initOperand); |
| 470 | newMaps[mapIdx] = newInitMaps[idx]; |
| 471 | } |
| 472 | |
| 473 | // Step 3. Change the reduction dim iterator types. |
| 474 | SmallVector<utils::IteratorType> newIteratorTypes = |
| 475 | linalgOp.getIteratorTypesArray(); |
| 476 | for (int dim : reductionDims) |
| 477 | newIteratorTypes[dim] = utils::IteratorType::parallel; |
| 478 | |
| 479 | // Step 4. Create the new generic op. |
| 480 | auto genericOp = |
| 481 | b.create<GenericOp>(loc, ValueRange(tiledInits).getTypes(), tiledInputs, |
| 482 | tiledInits, newMaps, newIteratorTypes); |
| 483 | IRMapping mapping; |
    op->getRegion(0).cloneInto(&genericOp.getRegion(),
| 485 | genericOp.getRegion().begin(), mapping); |
| 486 | return TilingResult{ |
| 487 | {genericOp.getOperation()}, |
| 488 | llvm::map_to_vector(genericOp->getResults(), |
| 489 | [](OpResult r) -> Value { return r; }), |
| 490 | generatedSlices}; |
| 491 | } |
| 492 | |
| 493 | FailureOr<MergeResult> mergeReductions(Operation *op, OpBuilder &b, |
| 494 | Location loc, ValueRange partialReduce, |
| 495 | ArrayRef<int> reductionDims) const { |
| 496 | auto linalgOp = cast<LinalgOp>(op); |
| 497 | |
| 498 | // Permute the reduction dims as permuted by the partial result map. |
| 499 | |
| 500 | int64_t numInits = linalgOp.getNumDpsInits(); |
| 501 | SmallVector<Operation *> mergeOperations; |
| 502 | SmallVector<Value> replacements; |
| 503 | for (int idx : llvm::seq(numInits)) { |
| 504 | // linalg.reduce's iteration space is the tiled result's iteration space |
| 505 | // (and not the tiled operation's iteration space). To account for this, |
| 506 | // permute the reduction dimensions based on the partial result map of the |
| 507 | // tiled result. |
| 508 | AffineMap partialMap = |
| 509 | getPartialResultAffineMap(linalgOp, reductionDims, idx); |
| 510 | SmallVector<int64_t> partialReductionDims; |
| 511 | for (auto [resultNum, dimExpr] : |
| 512 | llvm::enumerate(partialMap.getResults())) { |
| 513 | unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition(); |
| 514 | if (llvm::is_contained(reductionDims, dim)) { |
| 515 | partialReductionDims.push_back(resultNum); |
| 516 | } |
| 517 | } |
| 518 | |
| 519 | Value partialResult = partialReduce[idx]; |
| 520 | Value init = linalgOp.getDpsInits()[idx]; |
| 521 | |
| 522 | auto reduction = b.create<linalg::ReduceOp>( |
| 523 | loc, partialResult, init, partialReductionDims, |
| 524 | [&linalgOp, &idx](OpBuilder &b, Location loc, ValueRange inputs) { |
| 525 | // Get the combiner op. |
| 526 | SmallVector<Operation *, 4> combinerOps; |
| 527 | matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps); |
| 528 | Operation *clonedReductionOp = b.clone(*combinerOps[0]); |
| 529 | // Combine the input at idx and output at numInits + idx. |
| 530 | clonedReductionOp->setOperand(0, inputs[0]); |
| 531 | clonedReductionOp->setOperand(1, inputs[1]); |
| 532 | b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0)); |
| 533 | }); |
| 534 | |
| 535 | mergeOperations.push_back(reduction); |
| 536 | replacements.push_back(reduction->getResult(0)); |
| 537 | } |
| 538 | |
    return MergeResult{mergeOperations, replacements};
| 540 | } |
| 541 | |
| 542 | LogicalResult getPartialResultTilePosition( |
| 543 | Operation *op, OpBuilder &b, unsigned resultNumber, |
| 544 | ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, |
| 545 | SmallVector<OpFoldResult> &resultOffsets, |
| 546 | SmallVector<OpFoldResult> &resultSizes, |
| 547 | ArrayRef<int> reductionDims) const { |
| 548 | auto linalgOp = cast<LinalgOp>(op); |
| 549 | |
| 550 | AffineMap partialMap = |
| 551 | getPartialResultAffineMap(linalgOp, reductionDims, resultNumber); |
| 552 | for (AffineExpr dimExpr : partialMap.getResults()) { |
| 553 | unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition(); |
| 554 | resultSizes.push_back(sizes[dim]); |
| 555 | |
| 556 | if (llvm::is_contained(reductionDims, dim)) { |
        // Reduction dims are reduced, and are always output in the same
        // place. So use offset 0 for them.
| 559 | resultOffsets.push_back(b.getIndexAttr(0)); |
| 560 | } else { |
| 561 | resultOffsets.push_back(offsets[dim]); |
| 562 | } |
| 563 | } |
| 564 | |
| 565 | return success(); |
| 566 | } |
| 567 | }; |
| 568 | |
| 569 | template <typename OpTy> |
| 570 | static SmallVector<Range> getPackUnPackIterationDomain(OpTy op, |
| 571 | OpBuilder &builder) { |
| 572 | static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value, |
| 573 | "applies to only pack or unpack operations" ); |
| 574 | OpBuilder::InsertionGuard g(builder); |
| 575 | int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank() |
| 576 | : op.getDestRank(); |
| 577 | OpFoldResult zero = builder.getIndexAttr(0); |
| 578 | OpFoldResult one = builder.getIndexAttr(1); |
| 579 | ReifiedRankedShapedTypeDims resultShape; |
| 580 | (void)reifyResultShapes(builder, op, resultShape); |
| 581 | SmallVector<Range> loopBounds(rank); |
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
| 583 | loopBounds[dim].offset = zero; |
| 584 | loopBounds[dim].stride = one; |
| 585 | loopBounds[dim].size = resultShape[0][dim]; |
| 586 | } |
| 587 | return loopBounds; |
| 588 | } |
| 589 | |
| 590 | static void applyPermToRange(SmallVector<OpFoldResult> &offsets, |
| 591 | SmallVector<OpFoldResult> &sizes, |
| 592 | ArrayRef<int64_t> permutation) { |
| 593 | if (permutation.empty()) |
| 594 | return; |
  applyPermutationToVector<OpFoldResult>(offsets, permutation);
  applyPermutationToVector<OpFoldResult>(sizes, permutation);
| 597 | } |
| 598 | |
| 599 | struct PackOpTiling |
| 600 | : public TilingInterface::ExternalModel<PackOpTiling, linalg::PackOp> { |
| 601 | |
| 602 | SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { |
| 603 | // Note that here we only consider untiled dimensions and outer tiled data |
    // dimensions; the inner tiled data dimensions are materialized when
| 605 | // building the body of the operation. |
| 606 | auto packOp = cast<PackOp>(op); |
| 607 | SmallVector<utils::IteratorType> iteratorTypes( |
| 608 | packOp.getSourceRank(), utils::IteratorType::parallel); |
| 609 | return iteratorTypes; |
| 610 | } |
| 611 | |
| 612 | SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const { |
| 613 | return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b); |
| 614 | } |
| 615 | |
| 616 | FailureOr<TilingResult> |
| 617 | getTiledImplementation(Operation *op, OpBuilder &b, |
| 618 | ArrayRef<OpFoldResult> offsets, |
| 619 | ArrayRef<OpFoldResult> sizes) const { |
| 620 | auto packOp = cast<PackOp>(op); |
| 621 | Location loc = packOp.getLoc(); |
| 622 | |
| 623 | // The tiling is applied on interchanged dimensions. We have to undo the |
| 624 | // interchange to map sizes and offsets to the original input. |
| 625 | int64_t inputRank = packOp.getSourceRank(); |
| 626 | SmallVector<OpFoldResult> origOffsets(offsets); |
| 627 | SmallVector<OpFoldResult> origSizes(sizes); |
| 628 | applyPermToRange(origOffsets, origSizes, |
| 629 | invertPermutationVector(packOp.getOuterDimsPerm())); |
| 630 | |
| 631 | DenseMap<int64_t, OpFoldResult> dimAndTileMapping = |
| 632 | packOp.getDimAndTileMapping(); |
| 633 | SmallVector<OpFoldResult> srcDimValues = |
        tensor::getMixedSizes(b, loc, packOp.getSource());
| 635 | SmallVector<OpFoldResult> inputIndices, inputSizes; |
| 636 | for (auto dim : llvm::seq<int64_t>(0, inputRank)) { |
| 637 | using AV = affine::AffineValueExpr; |
| 638 | affine::AffineBuilder ab(b, loc); |
| 639 | AffineExpr dim0, dim1, sym; |
| 640 | bindDims(b.getContext(), dim0, dim1); |
| 641 | bindSymbols(b.getContext(), sym); |
| 642 | if (dimAndTileMapping.count(dim)) { |
| 643 | // If the data dimension is tiled, the i-th index is the product of |
| 644 | // offset_i and tile_i, and the i-th size is the product of sizes_i and |
| 645 | // tile_i. |
| 646 | auto avOffset = AV(dim0).bind(origOffsets[dim]); |
| 647 | auto avSize = AV(dim0).bind(origSizes[dim]); |
| 648 | auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]); |
| 649 | inputIndices.push_back(ab.mul(avOffset, avTileSize)); |
| 650 | inputSizes.push_back(ab.mul(avSize, avTileSize)); |
| 651 | } else { |
| 652 | inputIndices.push_back(origOffsets[dim]); |
| 653 | inputSizes.push_back(origSizes[dim]); |
| 654 | } |
| 655 | |
| 656 | // Limit the size of the input operand for incomplete tiles. |
| 657 | if (packOp.getPaddingValue()) { |
| 658 | OpFoldResult dimSize = srcDimValues[dim]; |
| 659 | auto avDimSize = AV(dim0).bind(dimSize); |
| 660 | auto avInputIdx = AV(dim1).bind(inputIndices.back()); |
| 661 | inputSizes.back() = |
| 662 | ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)}); |
| 663 | } |
| 664 | } |
| 665 | |
| 666 | auto oneAttr = b.getI64IntegerAttr(1); |
| 667 | SmallVector<OpFoldResult> strides(inputRank, oneAttr); |
| 668 | |
| 669 | SmallVector<Value> tiledOperands; |
| 670 | auto sourceSlice = b.create<tensor::ExtractSliceOp>( |
| 671 | loc, packOp.getSource(), inputIndices, inputSizes, strides); |
    tiledOperands.push_back(sourceSlice);
| 673 | |
| 674 | SmallVector<OpFoldResult> outputOffsets, outputSizes; |
    if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
                                     outputSizes)))
| 677 | return {}; |
| 678 | |
| 679 | strides.append(packOp.getDestRank() - inputRank, oneAttr); |
| 680 | auto outSlice = b.create<tensor::ExtractSliceOp>( |
| 681 | loc, packOp.getDest(), outputOffsets, outputSizes, strides); |
    tiledOperands.push_back(outSlice);
| 683 | |
| 684 | if (auto val = packOp.getPaddingValue()) |
      tiledOperands.push_back(val);
| 686 | for (auto tile : packOp.getInnerTiles()) |
| 687 | tiledOperands.push_back(tile); |
| 688 | |
| 689 | Operation *tiledPackOp = b.create<PackOp>( |
| 690 | loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs()); |
| 691 | |
| 692 | return TilingResult{ |
        {tiledPackOp},
        SmallVector<Value>(tiledPackOp->getResults()),
        llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
| 696 | } |
| 697 | |
| 698 | LogicalResult |
| 699 | getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, |
| 700 | ArrayRef<OpFoldResult> offsets, |
| 701 | ArrayRef<OpFoldResult> sizes, |
| 702 | SmallVector<OpFoldResult> &resultOffsets, |
| 703 | SmallVector<OpFoldResult> &resultSizes) const { |
| 704 | // The iteration domain is over outer dimensions of packed layout. In this |
| 705 | // context, the outer dimensions of `resultOffsets` are `offsets`. The |
| 706 | // inner dimensions of `resultOffsets` are zeros because tiling is not |
| 707 | // applied to them. |
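    // E.g. (illustrative), packing tensor<?x?xf32> into tensor<?x?x16x32xf32>
    // with offsets [%o0, %o1] and sizes [%s0, %s1] yields
    // resultOffsets = [%o0, %o1, 0, 0] and resultSizes = [%s0, %s1, 16, 32].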
| 708 | auto packOp = cast<PackOp>(op); |
| 709 | int64_t inputRank = packOp.getSourceRank(); |
| 710 | int64_t outputRank = packOp.getDestRank(); |
| 711 | auto zeroAttr = b.getI64IntegerAttr(0); |
    resultOffsets.assign(offsets.begin(), offsets.end());
| 713 | resultOffsets.append(outputRank - inputRank, zeroAttr); |
| 714 | |
| 715 | ReifiedRankedShapedTypeDims outputShape; |
| 716 | (void)reifyResultShapes(b, packOp, outputShape); |
    resultSizes.assign(sizes.begin(), sizes.end());
| 718 | for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank)) |
| 719 | resultSizes.push_back(outputShape[0][dataTileDim]); |
| 720 | |
| 721 | return success(); |
| 722 | } |
| 723 | |
| 724 | FailureOr<TilingResult> |
| 725 | generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, |
| 726 | ArrayRef<OpFoldResult> offsets, |
| 727 | ArrayRef<OpFoldResult> sizes) const { |
| 728 | auto packOp = cast<PackOp>(op); |
| 729 | int64_t numTiles = packOp.getInnerDimsPos().size(); |
| 730 | |
| 731 | // tensor.pack op is fusible (as a producer) only if full inner tiles are |
| 732 | // iterated or inner dims are not tiled. Otherwise, it will generate a |
| 733 | // sequence of non-trivial ops (for partial tiles). |
| 734 | for (auto offset : offsets.take_back(numTiles)) |
| 735 | if (!isZeroInteger(offset)) |
| 736 | return failure(); |
| 737 | |
| 738 | for (auto iter : |
| 739 | llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles))) |
| 740 | if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter))) |
| 741 | return failure(); |
| 742 | |
| 743 | FailureOr<TilingResult> tilingResult = getTiledImplementation( |
        op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
    if (failed(tilingResult))
| 746 | return failure(); |
| 747 | return tilingResult.value(); |
| 748 | } |
| 749 | |
  /// Method to return the position of the iteration domain tile computed by
  /// the tiled operation. In the current `tensor.pack` context, the
  /// `resultOffsets` and `resultSizes` only cover the outer dimensions.
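  /// For a dimension packed with inner tile size `t` (illustrative), a source
  /// tile with offset `%o` and size `%s` maps to an iteration-domain tile with
  /// offset `floordiv(%o, t)` and size `ceildiv(%s, t)`; untiled dimensions
  /// pass through unchanged.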
| 753 | LogicalResult getIterationDomainTileFromOperandTile( |
| 754 | Operation *op, OpBuilder &b, unsigned operandNumber, |
| 755 | ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, |
| 756 | SmallVectorImpl<OpFoldResult> &resultOffsets, |
| 757 | SmallVectorImpl<OpFoldResult> &resultSizes) const { |
| 758 | if (operandNumber != 0) |
| 759 | return failure(); |
| 760 | |
| 761 | auto packOp = cast<PackOp>(op); |
| 762 | // It is not trivial to infer dest tile from source tile if `packOp` has |
| 763 | // padding semantic. |
| 764 | if (packOp.getPaddingValue()) |
| 765 | return failure(); |
| 766 | |
| 767 | Location loc = packOp.getLoc(); |
| 768 | |
| 769 | SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes; |
| 770 | DenseMap<int64_t, OpFoldResult> dimAndTileMapping = |
| 771 | packOp.getDimAndTileMapping(); |
| 772 | for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) { |
| 773 | if (dimAndTileMapping.count(dim)) { |
| 774 | FailureOr<int64_t> cstSize = |
| 775 | ValueBoundsConstraintSet::computeConstantBound( |
| 776 | presburger::BoundType::UB, sizes[dim], |
| 777 | /*stopCondition=*/nullptr, /*closedUB=*/true); |
| 778 | std::optional<int64_t> cstInnerSize = |
| 779 | getConstantIntValue(dimAndTileMapping[dim]); |
        // Currently, fusing `packOp` as a consumer only expects a perfect
        // tiling scenario because, even without padding semantics, the
        // `packOp` may still yield incomplete tiles. E.g. tensor<30xf32> ->
        // tensor<5x6xf32>, where the `tileSize` from the operand of `packOp`
        // is 5, which is not exactly divisible by the `innerTile` (=6) of
        // `packOp`. As a result:
        // 1. the first slice is extracted from (0) to (4) and inserted into
        // (0,0)~(0,4) in the first row.
        // 2. the second slice is extracted from (5) to (9) and SHOULD BE
        // inserted into two rows with different lengths, namely first row
        // (0,5) and second row (1,0)~(1,3). It is hard to coordinate them,
        // thus the constraint below bypasses such cases temporarily. In other
        // words, we can only support tiling with a consumer if the tile size
        // for the producer is a multiple of the inner tile size for the packed
        // dimensions at this moment.
| 794 | if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) { |
| 795 | return failure(); |
| 796 | } |
| 797 | |
| 798 | using AV = affine::AffineValueExpr; |
| 799 | affine::AffineBuilder ab(b, loc); |
| 800 | AffineExpr dim0, sym; |
| 801 | bindDims(b.getContext(), dim0); |
| 802 | bindSymbols(b.getContext(), sym); |
| 803 | auto avOffset = AV(dim0).bind(offsets[dim]); |
| 804 | auto avSize = AV(dim0).bind(sizes[dim]); |
| 805 | auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]); |
| 806 | outerDimOffsets.push_back(ab.floor(avOffset, avTileSize)); |
| 807 | outerDimSizes.push_back(ab.ceil(avSize, avTileSize)); |
| 808 | } else { |
| 809 | outerDimOffsets.push_back(offsets[dim]); |
| 810 | outerDimSizes.push_back(sizes[dim]); |
| 811 | } |
| 812 | } |
| 813 | applyPermToRange(outerDimOffsets, outerDimSizes, packOp.getOuterDimsPerm()); |
| 814 | resultOffsets = outerDimOffsets; |
| 815 | resultSizes = outerDimSizes; |
| 816 | return success(); |
| 817 | } |
| 818 | |
| 819 | /// Method to return the tiled implementation of tensor.pack as a consumer. |
| 820 | FailureOr<TilingResult> getTiledImplementationFromOperandTile( |
| 821 | Operation *op, OpBuilder &b, unsigned operandNumber, |
| 822 | ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const { |
| 823 | if (operandNumber != 0) |
| 824 | return failure(); |
| 825 | |
| 826 | auto packOp = cast<PackOp>(op); |
| 827 | Location loc = packOp.getLoc(); |
| 828 | |
| 829 | int64_t inputRank = packOp.getSourceRank(); |
| 830 | auto oneAttr = b.getI64IntegerAttr(1); |
| 831 | SmallVector<OpFoldResult> strides(inputRank, oneAttr); |
| 832 | |
| 833 | SmallVector<Value> tiledOperands; |
| 834 | auto sourceSlice = b.create<tensor::ExtractSliceOp>( |
| 835 | loc, packOp.getSource(), offsets, sizes, strides); |
    tiledOperands.push_back(sourceSlice);
| 837 | |
| 838 | SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes; |
    if (failed(getIterationDomainTileFromOperandTile(
            op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
            outerDimSizes)))
| 842 | return failure(); |
| 843 | |
| 844 | SmallVector<OpFoldResult> outputOffsets, outputSizes; |
    if (failed(getResultTilePosition(op, b, 0, outerDimOffsets, outerDimSizes,
                                     outputOffsets, outputSizes)))
| 847 | return failure(); |
| 848 | |
| 849 | strides.append(packOp.getDestRank() - inputRank, oneAttr); |
| 850 | auto outSlice = b.create<tensor::ExtractSliceOp>( |
| 851 | loc, packOp.getDest(), outputOffsets, outputSizes, strides); |
    tiledOperands.push_back(outSlice);

    assert(!packOp.getPaddingValue() && "Expect no padding semantic");
| 855 | for (auto tile : packOp.getInnerTiles()) |
| 856 | tiledOperands.push_back(tile); |
| 857 | |
| 858 | Operation *tiledPackOp = b.create<PackOp>( |
| 859 | loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs()); |
| 860 | |
| 861 | return TilingResult{ |
        {tiledPackOp},
        SmallVector<Value>(tiledPackOp->getResults()),
        llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
| 865 | } |
| 866 | }; |
| 867 | |
| 868 | struct UnpackTileDimInfo { |
| 869 | bool isAlignedToInnerTileSize; |
| 870 | OpFoldResult sourceOffset; |
| 871 | OpFoldResult sourceSize; |
| 872 | OpFoldResult resultOffset; |
| 873 | OpFoldResult destExpandedSize; |
| 874 | }; |
| 875 | |
| 876 | /// Returns the needed information for tiling unpack op on `tileDim` with given |
| 877 | /// `tileOffset` and `tileSize`. For more details, see the comment of the |
| 878 | /// `getTiledImplementation`. |
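/// For an untiled dimension (illustrative), the source tile is the same as the
/// destination tile; for a packed dimension with inner tile size `t`, the
/// source offset/size are expressed in units of `t`, and `resultOffset`
/// records the in-tile shift needed for unaligned tiles.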
| 879 | static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp, |
| 880 | int64_t tileDim, |
| 881 | OpFoldResult tileOffset, |
| 882 | OpFoldResult tileSize) { |
| 883 | UnpackTileDimInfo info; |
| 884 | Attribute zeroAttr = b.getIndexAttr(0); |
| 885 | Attribute oneAttr = b.getIndexAttr(1); |
| 886 | DenseMap<int64_t, OpFoldResult> dimAndTileMapping = |
| 887 | unpackOp.getDimAndTileMapping(); |
  // The dimension is not one of the packed data dimensions.
  if (!dimAndTileMapping.count(tileDim)) {
| 890 | info.isAlignedToInnerTileSize = true; |
| 891 | info.sourceOffset = tileOffset; |
| 892 | info.sourceSize = tileSize; |
| 893 | info.resultOffset = zeroAttr; |
| 894 | info.destExpandedSize = tileSize; |
| 895 | return info; |
| 896 | } |
| 897 | |
| 898 | Location loc = unpackOp.getLoc(); |
| 899 | using AV = affine::AffineValueExpr; |
| 900 | affine::AffineBuilder ab(b, loc); |
| 901 | AffineExpr dim0, dim1, sym0; |
  bindDims(b.getContext(), dim0, dim1);
  bindSymbols(b.getContext(), sym0);
| 904 | |
| 905 | OpFoldResult innerTileSize = dimAndTileMapping[tileDim]; |
| 906 | |
| 907 | info.isAlignedToInnerTileSize = false; |
  FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
      presburger::BoundType::UB, tileSize,
      /*stopCondition=*/nullptr, /*closedUB=*/true);
  std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
  if (!failed(cstSize) && cstInnerSize) {
| 913 | if (*cstSize % *cstInnerSize == 0) |
| 914 | info.isAlignedToInnerTileSize = true; |
| 915 | |
| 916 | // If the tiling size equals to the inner tiling size, the outer dims are |
| 917 | // always 1. |
| 918 | if (*cstInnerSize == *cstSize) { |
      auto lhs = AV(dim0).bind(tileOffset);
      auto rhs = AV(dim1).bind(innerTileSize);
      info.sourceOffset = ab.floor(lhs, rhs);
| 922 | info.sourceSize = oneAttr; |
| 923 | info.resultOffset = zeroAttr; |
| 924 | info.destExpandedSize = tileSize; |
| 925 | return info; |
| 926 | } |
| 927 | } |
| 928 | |
| 929 | if (info.isAlignedToInnerTileSize) { |
    info.sourceOffset =
        ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
| 932 | info.resultOffset = zeroAttr; |
| 933 | info.destExpandedSize = tileSize; |
| 934 | |
    // The ceilDiv is needed here because there could be an incomplete tile
    // even in perfect tiling cases. E.g.,
    //   %0 = unpack tensor<33x2xf32> into tensor<64xf32>
    // If the tiling size is 32, there will be 3 tiles. Two of them have
    // size=32; one of them has size=2. The size is represented using an
| 940 | // affine_min op; we need ceilDiv. |
    info.sourceSize =
        ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
| 943 | return info; |
| 944 | } |
| 945 | |
  affine::DivModValue firstCoord = affine::getDivMod(
      b, loc, getValueOrCreateConstantIndexOp(b, loc, tileOffset),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  OpFoldResult tileExclusiveBound =
      ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
  affine::DivModValue lastCoord = affine::getDivMod(
      b, loc,
      getValueOrCreateConstantIndexOp(
          b, loc,
          ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));

  OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient),
                                       AV(dim1).bind(firstCoord.quotient));
  info.sourceSize =
      ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
| 962 | info.sourceOffset = firstCoord.quotient; |
| 963 | info.resultOffset = firstCoord.remainder; |
  // Do not create affine ops for the expanded size because the affine expr is
  // too complicated, which would trigger an issue in affine-op simplification.
| 966 | info.destExpandedSize = b.createOrFold<arith::MulIOp>( |
| 967 | loc, getValueOrCreateConstantIndexOp(b, loc, info.sourceSize), |
| 968 | getValueOrCreateConstantIndexOp(b, loc, innerTileSize)); |
| 969 | return info; |
| 970 | } |
| 971 | |
| 972 | struct UnPackOpTiling |
| 973 | : public TilingInterface::ExternalModel<UnPackOpTiling, linalg::UnPackOp> { |
| 974 | |
| 975 | SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { |
| 976 | auto unpackOp = cast<UnPackOp>(op); |
| 977 | SmallVector<utils::IteratorType> iteratorTypes( |
| 978 | unpackOp.getDestRank(), utils::IteratorType::parallel); |
| 979 | return iteratorTypes; |
| 980 | } |
| 981 | |
| 982 | SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const { |
| 983 | return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b); |
| 984 | } |
| 985 | |
  /// There are two cases in tiling unpack ops. If the tiling size is aligned
  /// to the inner tile size, the corresponding tiles of the source are all
  /// complete. Otherwise, there are incomplete tiles. We will need to expand
  /// the slice of the source to get complete tiles. The tiled unpack op
  /// unpacks more data from the source, so we'll need an extract_slice op to
  /// shift and truncate the output.
  /// Take Nn_to_N as an example. Say that N=32, n=8, and tiling_size=15. The
  /// coordinates of the second tile (i.e., result[15..31]) are
  /// [(1, 7), (2, 0), (2, 1) ... (3, 6), (3, 7)]. The first row and the last
  /// row are incomplete tiles. To represent the unpack op, we have to complete
  /// the rows. I.e., the input coordinates would start with (1, 0) and end
  /// with (3, 7). In this context, the tiled unpack produces (3 * n) elements
  /// because there are 3 rows in total. Followed by a tensor.extract_slice op,
  /// we can get the actual result.
| 1000 | FailureOr<TilingResult> |
| 1001 | getTiledImplementation(Operation *op, OpBuilder &b, |
| 1002 | ArrayRef<OpFoldResult> offsets, |
| 1003 | ArrayRef<OpFoldResult> sizes) const { |
| 1004 | auto unpackOp = cast<UnPackOp>(op); |
| 1005 | int64_t srcRank = unpackOp.getSourceRank(); |
| 1006 | int64_t destRank = unpackOp.getDestRank(); |
| 1007 | int64_t numInnerTiles = srcRank - destRank; |
| 1008 | Location loc = unpackOp.getLoc(); |
| 1009 | |
| 1010 | // The perfect tiling case indicates that the tiling sizes are multiple of |
| 1011 | // inner_tile_size. In this context, no extra data is needed when |
| 1012 | // representing the tiled unpack op. |
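    // E.g. (illustrative), with an inner tile size of 8, tile sizes of 16 or
    // 32 are perfect and need no extra source data, whereas a tile size of 15
    // is not and requires the expanded-slice path described above.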
| 1013 | bool isPerfectTilingCase = true; |
| 1014 | Attribute oneAttr = b.getIndexAttr(1); |
| 1015 | SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr); |
| 1016 | SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes; |
| 1017 | SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest; |
| 1018 | for (auto dim : llvm::seq<int64_t>(0, destRank)) { |
| 1019 | UnpackTileDimInfo info = |
| 1020 | getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]); |
| 1021 | if (!info.isAlignedToInnerTileSize) |
| 1022 | isPerfectTilingCase = false; |
| 1023 | sliceSrcIndices.push_back(info.sourceOffset); |
| 1024 | sliceSrcSizes.push_back(info.sourceSize); |
| 1025 | destExpandedSizes.push_back(info.destExpandedSize); |
| 1026 | resultOffsetsFromDest.push_back(info.resultOffset); |
| 1027 | } |
| 1028 | |
| 1029 | // The tiling is applied on destination dimensions. We have to apply the |
| 1030 | // interchange on source dimensions if outer_dims_perm is set. |
| 1031 | applyPermToRange(sliceSrcIndices, sliceSrcSizes, |
| 1032 | unpackOp.getOuterDimsPerm()); |
| 1033 | Attribute zeroAttr = b.getIndexAttr(0); |
    sliceSrcIndices.append(numInnerTiles, zeroAttr);
    sliceSrcSizes.append(unpackOp.getMixedTiles());
    sliceSrcStrides.append(numInnerTiles, oneAttr);
| 1037 | SmallVector<Operation *> generatedSlices; |
| 1038 | tensor::ExtractSliceOp sliceSource = b.create<tensor::ExtractSliceOp>( |
| 1039 | loc, unpackOp.getSource(), sliceSrcIndices, sliceSrcSizes, |
| 1040 | sliceSrcStrides); |
    generatedSlices.push_back(sliceSource);
| 1042 | |
| 1043 | SmallVector<OpFoldResult> destStrides(destRank, oneAttr); |
| 1044 | Value sliceDest; |
| 1045 | if (isPerfectTilingCase) { |
| 1046 | auto destSliceOp = b.create<tensor::ExtractSliceOp>( |
| 1047 | loc, unpackOp.getDest(), offsets, sizes, destStrides); |
| 1048 | sliceDest = destSliceOp; |
      generatedSlices.push_back(destSliceOp);
| 1050 | } else { |
| 1051 | sliceDest = b.create<tensor::EmptyOp>( |
| 1052 | loc, destExpandedSizes, unpackOp.getDestType().getElementType()); |
| 1053 | } |
| 1054 | |
| 1055 | SmallVector<Value> tiledOperands = {sliceSource.getResult(), sliceDest}; |
| 1056 | for (auto tile : unpackOp.getInnerTiles()) |
| 1057 | tiledOperands.push_back(tile); |
| 1058 | |
| 1059 | Operation *tiledUnpackOp = b.create<UnPackOp>( |
| 1060 | loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs()); |
| 1061 | |
| 1062 | if (isPerfectTilingCase) |
      return TilingResult{{tiledUnpackOp},
                          SmallVector<Value>(tiledUnpackOp->getResults()),
                          generatedSlices};
| 1066 | |
    auto extractSlice = b.create<tensor::ExtractSliceOp>(
        loc, tiledUnpackOp->getResult(0), resultOffsetsFromDest, sizes,
| 1069 | destStrides); |
| 1070 | return TilingResult{ |
| 1071 | {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices}; |
| 1072 | } |
| 1073 | |
| 1074 | LogicalResult |
| 1075 | getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, |
| 1076 | ArrayRef<OpFoldResult> offsets, |
| 1077 | ArrayRef<OpFoldResult> sizes, |
| 1078 | SmallVector<OpFoldResult> &resultOffsets, |
| 1079 | SmallVector<OpFoldResult> &resultSizes) const { |
    resultOffsets = llvm::to_vector(offsets);
    resultSizes = llvm::to_vector(sizes);
| 1082 | return success(); |
| 1083 | } |
| 1084 | |
| 1085 | FailureOr<TilingResult> |
| 1086 | generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, |
| 1087 | ArrayRef<OpFoldResult> offsets, |
| 1088 | ArrayRef<OpFoldResult> sizes) const { |
| 1089 | FailureOr<TilingResult> tilingResult = |
| 1090 | getTiledImplementation(op, b, offsets, sizes); |
    if (failed(tilingResult))
| 1092 | return failure(); |
| 1093 | return tilingResult.value(); |
| 1094 | } |
| 1095 | |
| 1096 | /// Method to return the position of iteration domain tile computed by the |
| 1097 | /// tiled operation. |
| 1098 | LogicalResult getIterationDomainTileFromOperandTile( |
| 1099 | Operation *op, OpBuilder &b, unsigned operandNumber, |
| 1100 | ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, |
| 1101 | SmallVectorImpl<OpFoldResult> &resultOffsets, |
| 1102 | SmallVectorImpl<OpFoldResult> &resultSizes) const { |
| 1103 | auto unPackOp = cast<UnPackOp>(op); |
| 1104 | // If the operand tile is the dest, then no adjustment is needed. |
| 1105 | if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) { |
      resultOffsets = llvm::to_vector(offsets);
      resultSizes = llvm::to_vector(sizes);
| 1108 | return success(); |
| 1109 | } |
| 1110 | Location loc = unPackOp.getLoc(); |
| 1111 | |
| 1112 | int64_t numTiles = unPackOp.getInnerDimsPos().size(); |
    auto destOffsets = offsets.drop_back(numTiles);
    auto destSizes = sizes.drop_back(numTiles);
| 1115 | // The tiling is applied on interchanged dimensions. We have to undo the |
| 1116 | // interchange to map sizes and offsets to the original input. |
| 1117 | int64_t outputRank = unPackOp.getDestRank(); |
| 1118 | ReifiedRankedShapedTypeDims reifiedReturnShapes; |
| 1119 | if (failed(reifyResultShapes(b, unPackOp, reifiedReturnShapes))) |
| 1120 | return failure(); |
| 1121 | SmallVector<OpFoldResult> outputMixedSizes = reifiedReturnShapes.front(); |
| 1122 | SmallVector<OpFoldResult> origOffsets(destOffsets); |
| 1123 | SmallVector<OpFoldResult> origSizes(destSizes); |
| 1124 | applyPermToRange(origOffsets, origSizes, |
| 1125 | invertPermutationVector(unPackOp.getOuterDimsPerm())); |
| 1126 | |
| 1127 | DenseMap<int64_t, OpFoldResult> dimAndTileMapping = |
| 1128 | unPackOp.getDimAndTileMapping(); |
| 1129 | |
| 1130 | for (auto dim : llvm::seq<int64_t>(0, outputRank)) { |
| 1131 | using AV = affine::AffineValueExpr; |
| 1132 | affine::AffineBuilder ab(b, loc); |
| 1133 | AffineExpr dim0, dim1, sym0; |
| 1134 | bindDims(b.getContext(), dim0, dim1); |
| 1135 | bindSymbols(b.getContext(), sym0); |
| 1136 | if (dimAndTileMapping.count(dim)) { |
| 1137 | // If the data dimension is tiled, the i-th index is the product of |
| 1138 | // offset_i and tile_i, and the i-th size is the product of sizes_i and |
| 1139 | // tile_i. The sizes must be clamped to the sizes of the unpack result. |
| 1140 | auto avOffset = AV(dim0).bind(origOffsets[dim]); |
| 1141 | auto avSize = AV(dim0).bind(origSizes[dim]); |
| 1142 | auto avTileSize = AV(sym0).bind(dimAndTileMapping[dim]); |
| 1143 | auto avResultSize = AV(dim0).bind(outputMixedSizes[dim]); |
| 1144 | resultOffsets.push_back(ab.mul(avOffset, avTileSize)); |
| 1145 | auto avResultOffset = AV(dim1).bind(resultOffsets.back()); |
| 1146 | resultSizes.push_back(ab.min({ab.mul(avSize, avTileSize), |
| 1147 | ab.sub(avResultSize, avResultOffset)})); |
| 1148 | } else { |
| 1149 | resultOffsets.push_back(origOffsets[dim]); |
| 1150 | resultSizes.push_back(origSizes[dim]); |
| 1151 | } |
| 1152 | } |
| 1153 | return success(); |
| 1154 | } |
| 1155 | |
| 1156 | /// Method to return the tiled implementation of tensor.unpack as a consumer. |
| 1157 | FailureOr<TilingResult> getTiledImplementationFromOperandTile( |
| 1158 | Operation *op, OpBuilder &b, unsigned operandNumber, |
| 1159 | ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const { |
| 1160 | auto unPackOp = cast<UnPackOp>(op); |
| 1161 | // tensor.unpack op is fusible (as a consumer) only if inner dims are not |
| 1162 | // tiled. |
| 1163 | int64_t numTiles = unPackOp.getInnerDimsPos().size(); |
| 1164 | for (auto iter : |
| 1165 | llvm::zip_equal(unPackOp.getMixedTiles(), sizes.take_back(numTiles))) { |
| 1166 | if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter))) |
| 1167 | return failure(); |
| 1168 | } |
| 1169 | |
| 1170 | Location loc = unPackOp.getLoc(); |
| 1171 | |
| 1172 | // Fetch offset/size for creating the slice of the dest operand of |
| 1173 | // unpack op. |
| 1174 | SmallVector<OpFoldResult> outputOffsets, outputSizes; |
    if (failed(getIterationDomainTileFromOperandTile(
            op, b, /*operandNumber=*/0, offsets, sizes, outputOffsets,
            outputSizes)))
| 1178 | return failure(); |
| 1179 | |
| 1180 | auto oneAttr = b.getI64IntegerAttr(1); |
| 1181 | int64_t outputRank = unPackOp.getDestRank(); |
| 1182 | SmallVector<OpFoldResult> strides(outputRank, oneAttr); |
| 1183 | |
| 1184 | SmallVector<Value> tiledOperands; |
| 1185 | // Create slice of the dest operand. |
    auto extractDestSlice = b.create<tensor::ExtractSliceOp>(
        loc, unPackOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(extractDestSlice);
| 1189 | |
| 1190 | strides.append(unPackOp.getSourceRank() - outputRank, oneAttr); |
| 1191 | // Create slice of the source operand. |
    auto extractSourceSlice = b.create<tensor::ExtractSliceOp>(
        loc, unPackOp.getSource(), offsets, sizes, strides);
    tiledOperands.insert(tiledOperands.begin(), extractSourceSlice);
| 1195 | for (auto tile : unPackOp.getInnerTiles()) |
| 1196 | tiledOperands.push_back(tile); |
| 1197 | |
| 1198 | // Create tiled unpack op. |
| 1199 | Operation *tiledUnPackOp = |
| 1200 | b.create<UnPackOp>(loc, TypeRange{extractDestSlice.getType()}, |
| 1201 | tiledOperands, op->getAttrs()); |
| 1202 | |
    return TilingResult{{tiledUnPackOp},
                        SmallVector<Value>(tiledUnPackOp->getResults()),
                        llvm::to_vector(ArrayRef<Operation *>{
                            extractSourceSlice, extractDestSlice})};
| 1207 | } |
| 1208 | }; |
| 1209 | |
| 1210 | } // namespace |
| 1211 | |
| 1212 | template <typename OpType> |
| 1213 | static void registerOne(MLIRContext *ctx) { |
| 1214 | OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx); |
| 1215 | OpType::template attachInterface<LinalgOpPartialReductionInterface<OpType>>( |
| 1216 | *ctx); |
| 1217 | } |
| 1218 | |
| 1219 | /// Variadic helper function. |
| 1220 | template <typename... OpTypes> |
| 1221 | static void registerAll(MLIRContext *ctx) { |
| 1222 | (registerOne<OpTypes>(ctx), ...); |
| 1223 | } |
| 1224 | |
| 1225 | #define GET_OP_LIST |
| 1226 | |
| 1227 | void mlir::linalg::registerTilingInterfaceExternalModels( |
| 1228 | DialectRegistry ®istry) { |
| 1229 | registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) { |
| 1230 | registerOne<linalg::GenericOp>(ctx); |
| 1231 | linalg::PackOp::attachInterface<PackOpTiling>(*ctx); |
| 1232 | linalg::UnPackOp::attachInterface<UnPackOpTiling>(*ctx); |
| 1233 | registerAll< |
| 1234 | #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" |
| 1235 | >(ctx); |
| 1236 | }); |
| 1237 | } |
| 1238 | |
| 1239 | void mlir::linalg::registerTilingInterfaceExternalModelsForPackUnPackOps( |
| 1240 | DialectRegistry ®istry) { |
  registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
| 1242 | linalg::PackOp::attachInterface<PackOpTiling>(*ctx); |
| 1243 | linalg::UnPackOp::attachInterface<UnPackOpTiling>(*ctx); |
| 1244 | }); |
| 1245 | } |
| 1246 | |