| 1 | //===- Split.cpp - Structured op splitting --------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 10 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
| 11 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
| 12 | #include "mlir/IR/AffineExpr.h" |
| 13 | #include "mlir/IR/Attributes.h" |
| 14 | #include "mlir/IR/BuiltinAttributes.h" |
| 15 | #include "mlir/IR/OpDefinition.h" |
| 16 | #include "mlir/Interfaces/TilingInterface.h" |
| 17 | |
| 18 | #include "llvm/ADT/STLExtras.h" |
| 19 | #include "llvm/ADT/SmallVector.h" |
| 20 | |
| 21 | using namespace mlir; |
| 22 | using namespace mlir::linalg; |
| 23 | |
| 24 | /// Creates a part of the given `op` split along the iteration space `dimension` |
| 25 | /// with the given `size` and an optional `offset` (default 0). Makes slices |
| 26 | /// of operands, using the input operands of the original op and the output |
| 27 | /// operands provided as `resultOperands`. Expects `offsets` and `sizes` to |
| 28 | /// define the shape of the iteration space of the original op. Returns the |
| 29 | /// split-out op as well as the output operand values updated with the partial |
| 30 | /// results produced by this op through `results`. |
| 31 | static TilingInterface |
| 32 | createSplitPart(RewriterBase &b, Location loc, TilingInterface op, |
| 33 | ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, |
| 34 | ValueRange resultOperands, unsigned dimension, |
| 35 | OpFoldResult size, OpFoldResult offset, |
| 36 | SmallVectorImpl<Value> &results) { |
| 37 | // Iteration space of the current part. |
| 38 | SmallVector<OpFoldResult> sizesCopy = llvm::to_vector(Range&: sizes); |
| 39 | SmallVector<OpFoldResult> offsetsCopy = llvm::to_vector(Range&: offsets); |
| 40 | sizesCopy[dimension] = size; |
| 41 | offsetsCopy[dimension] = offset; |
| 42 | |
| 43 | // Create the part as if it were a single tile. |
| 44 | FailureOr<TilingResult> tilingResult = |
| 45 | op.getTiledImplementation(b, offsetsCopy, sizesCopy); |
| 46 | |
| 47 | // Insert the results back and populate the `results` list. |
| 48 | for (auto [index, result] : llvm::enumerate(tilingResult->tiledValues)) { |
| 49 | SmallVector<OpFoldResult> resultOffsets, resultSizes; |
| 50 | if (failed(op.getResultTilePosition(b, index, offsetsCopy, sizesCopy, |
| 51 | resultOffsets, resultSizes))) |
| 52 | return nullptr; |
| 53 | SmallVector<OpFoldResult> resultStrides(resultOffsets.size(), |
| 54 | b.getIndexAttr(1)); |
| 55 | Value inserted = b.create<tensor::InsertSliceOp>( |
| 56 | loc, result, resultOperands[index], resultOffsets, resultSizes, |
| 57 | resultStrides); |
| 58 | results.push_back(inserted); |
| 59 | } |
| 60 | // TODO: this part can be generalized maybe to not expect a single op. |
| 61 | assert(tilingResult->tiledOps.size() == 1 && |
| 62 | "expected split part to return a single tiled operation" ); |
| 63 | return cast<TilingInterface>(tilingResult->tiledOps[0]); |
| 64 | } |
| 65 | |
| 66 | std::pair<TilingInterface, TilingInterface> |
| 67 | linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, |
| 68 | OpFoldResult splitPoint) { |
| 69 | // Compute the iteration space. |
| 70 | SmallVector<Range> iterationSpace = op.getIterationDomain(rewriter); |
| 71 | |
| 72 | // Bail out on dimension overflow. |
| 73 | if (dimension >= iterationSpace.size()) |
| 74 | return std::make_pair(op, TilingInterface()); |
| 75 | |
| 76 | SmallVector<OpFoldResult> offsets = llvm::to_vector(Range: llvm::map_range( |
| 77 | C&: iterationSpace, F: [](const Range &range) { return range.offset; })); |
| 78 | SmallVector<OpFoldResult> sizes = llvm::to_vector(Range: llvm::map_range( |
| 79 | C&: iterationSpace, F: [](const Range &range) { return range.size; })); |
| 80 | |
| 81 | // Adjust the split point so that it doesn't overflow the size. |
| 82 | AffineExpr d0, d1, d2; |
| 83 | bindDims(ctx: rewriter.getContext(), exprs&: d0, exprs&: d1, exprs&: d2); |
| 84 | OpFoldResult minSplitPoint = affine::makeComposedFoldedAffineMin( |
| 85 | b&: rewriter, loc: op.getLoc(), |
| 86 | map: AffineMap::inferFromExprList(exprsList: ArrayRef<AffineExpr>{d0, d1 + d2}, |
| 87 | context: rewriter.getContext()) |
| 88 | .front(), |
| 89 | operands: {splitPoint, offsets[dimension], sizes[dimension]}); |
| 90 | |
| 91 | // Compute the size of the second part. Return early if the second part would |
| 92 | // have an empty iteration space. |
| 93 | OpFoldResult remainingSize = affine::makeComposedFoldedAffineApply( |
| 94 | rewriter, op.getLoc(), d0 + d1 - d2, |
| 95 | {iterationSpace[dimension].offset, iterationSpace[dimension].size, |
| 96 | minSplitPoint}); |
| 97 | if (auto attr = llvm::dyn_cast_if_present<Attribute>(remainingSize)) { |
| 98 | if (cast<IntegerAttr>(attr).getValue().isZero()) |
| 99 | return {op, TilingInterface()}; |
| 100 | } |
| 101 | |
| 102 | // Compute destination tensors. |
| 103 | SmallVector<Value> destinationTensors; |
| 104 | LogicalResult destStatus = tensor::getOrCreateDestinations( |
| 105 | b&: rewriter, loc: op.getLoc(), op: op, result&: destinationTensors); |
| 106 | (void)destStatus; |
| 107 | assert(succeeded(destStatus) && "failed to get destination tensors" ); |
| 108 | |
| 109 | // Create the first part. |
| 110 | SmallVector<Value> firstResults; |
| 111 | TilingInterface firstPart = createSplitPart( |
| 112 | rewriter, op.getLoc(), op, offsets, sizes, destinationTensors, dimension, |
| 113 | minSplitPoint, iterationSpace[dimension].offset, firstResults); |
| 114 | |
| 115 | // Need to pretend that the original op now takes as operands firstResults, |
| 116 | // otherwise tiling interface implementation will take the wrong value to |
| 117 | // produce data tiles. |
| 118 | rewriter.modifyOpInPlace(op, [&]() { |
| 119 | unsigned numTotalOperands = op->getNumOperands(); |
| 120 | unsigned numOutputOperands = firstResults.size(); |
| 121 | op->setOperands(numTotalOperands - numOutputOperands, numOutputOperands, |
| 122 | firstResults); |
| 123 | }); |
| 124 | |
| 125 | // Create the second part. |
| 126 | OpFoldResult totalOffset = affine::makeComposedFoldedAffineApply( |
| 127 | rewriter, op.getLoc(), d0 + d1, {offsets[dimension], minSplitPoint}); |
| 128 | SmallVector<Value> secondResults; |
| 129 | TilingInterface secondPart = |
| 130 | createSplitPart(rewriter, op.getLoc(), op, offsets, sizes, firstResults, |
| 131 | dimension, remainingSize, totalOffset, secondResults); |
| 132 | |
| 133 | // Propagate any errors in part creation. |
| 134 | if (!firstPart || !secondPart) |
| 135 | return {TilingInterface(), TilingInterface()}; |
| 136 | |
| 137 | // Replace the original op with the results of the two newly created ops. |
| 138 | rewriter.replaceOp(op, secondResults); |
| 139 | return {firstPart, secondPart}; |
| 140 | } |
| 141 | |