1 | //===- Split.cpp - Structured op splitting --------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
10 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
11 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
12 | #include "mlir/IR/AffineExpr.h" |
13 | #include "mlir/IR/Attributes.h" |
14 | #include "mlir/IR/BuiltinAttributes.h" |
15 | #include "mlir/IR/OpDefinition.h" |
16 | #include "mlir/Interfaces/TilingInterface.h" |
17 | |
18 | #include "llvm/ADT/STLExtras.h" |
19 | #include "llvm/ADT/SmallVector.h" |
20 | |
21 | using namespace mlir; |
22 | using namespace mlir::linalg; |
23 | |
24 | /// Creates a part of the given `op` split along the iteration space `dimension` |
25 | /// with the given `size` and an optional `offset` (default 0). Makes slices |
26 | /// of operands, using the input operands of the original op and the output |
27 | /// operands provided as `resultOperands`. Expects `offsets` and `sizes` to |
28 | /// define the shape of the iteration space of the original op. Returns the |
29 | /// split-out op as well as the output operand values updated with the partial |
30 | /// results produced by this op through `results`. |
31 | static TilingInterface |
32 | createSplitPart(RewriterBase &b, Location loc, TilingInterface op, |
33 | ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, |
34 | ValueRange resultOperands, unsigned dimension, |
35 | OpFoldResult size, OpFoldResult offset, |
36 | SmallVectorImpl<Value> &results) { |
37 | // Iteration space of the current part. |
38 | SmallVector<OpFoldResult> sizesCopy = llvm::to_vector(Range&: sizes); |
39 | SmallVector<OpFoldResult> offsetsCopy = llvm::to_vector(Range&: offsets); |
40 | sizesCopy[dimension] = size; |
41 | offsetsCopy[dimension] = offset; |
42 | |
43 | // Create the part as it it were a single tile. |
44 | FailureOr<TilingResult> tilingResult = |
45 | op.getTiledImplementation(b, offsetsCopy, sizesCopy); |
46 | |
47 | // Insert the results back and populate the `results` list. |
48 | for (auto [index, result] : llvm::enumerate(tilingResult->tiledValues)) { |
49 | SmallVector<OpFoldResult> resultOffsets, resultSizes; |
50 | if (failed(op.getResultTilePosition(b, index, offsetsCopy, sizesCopy, |
51 | resultOffsets, resultSizes))) |
52 | return nullptr; |
53 | SmallVector<OpFoldResult> resultStrides(resultOffsets.size(), |
54 | b.getIndexAttr(1)); |
55 | Value inserted = b.create<tensor::InsertSliceOp>( |
56 | loc, result, resultOperands[index], resultOffsets, resultSizes, |
57 | resultStrides); |
58 | results.push_back(inserted); |
59 | } |
60 | // TODO: this part can be generalized maybe to not expect a single op. |
61 | assert(tilingResult->tiledOps.size() == 1 && |
62 | "expected split part to return a single tiled operation" ); |
63 | return cast<TilingInterface>(tilingResult->tiledOps[0]); |
64 | } |
65 | |
66 | std::pair<TilingInterface, TilingInterface> |
67 | linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, |
68 | OpFoldResult splitPoint) { |
69 | // Compute the iteration space. |
70 | SmallVector<Range> iterationSpace = op.getIterationDomain(rewriter); |
71 | |
72 | // Bail out on dimension overflow. |
73 | if (dimension >= iterationSpace.size()) |
74 | return std::make_pair(op, TilingInterface()); |
75 | |
76 | SmallVector<OpFoldResult> offsets = llvm::to_vector(Range: llvm::map_range( |
77 | C&: iterationSpace, F: [](const Range &range) { return range.offset; })); |
78 | SmallVector<OpFoldResult> sizes = llvm::to_vector(Range: llvm::map_range( |
79 | C&: iterationSpace, F: [](const Range &range) { return range.size; })); |
80 | |
81 | // Adjust the split point so that it doesn't overflow the size. |
82 | AffineExpr d0, d1, d2; |
83 | bindDims(ctx: rewriter.getContext(), exprs&: d0, exprs&: d1, exprs&: d2); |
84 | OpFoldResult minSplitPoint = affine::makeComposedFoldedAffineMin( |
85 | b&: rewriter, loc: op.getLoc(), |
86 | map: AffineMap::inferFromExprList(exprsList: ArrayRef<AffineExpr>{d0, d1 + d2}, |
87 | context: rewriter.getContext()) |
88 | .front(), |
89 | operands: {splitPoint, offsets[dimension], sizes[dimension]}); |
90 | |
91 | // Compute the size of the second part. Return early if the second part would |
92 | // have an empty iteration space. |
93 | OpFoldResult remainingSize = affine::makeComposedFoldedAffineApply( |
94 | rewriter, op.getLoc(), d0 + d1 - d2, |
95 | {iterationSpace[dimension].offset, iterationSpace[dimension].size, |
96 | minSplitPoint}); |
97 | if (auto attr = llvm::dyn_cast_if_present<Attribute>(remainingSize)) { |
98 | if (cast<IntegerAttr>(attr).getValue().isZero()) |
99 | return {op, TilingInterface()}; |
100 | } |
101 | |
102 | // Compute destination tensors. |
103 | SmallVector<Value> destinationTensors; |
104 | LogicalResult destStatus = tensor::getOrCreateDestinations( |
105 | b&: rewriter, loc: op.getLoc(), op: op, result&: destinationTensors); |
106 | (void)destStatus; |
107 | assert(succeeded(destStatus) && "failed to get destination tensors" ); |
108 | |
109 | // Create the first part. |
110 | SmallVector<Value> firstResults; |
111 | TilingInterface firstPart = createSplitPart( |
112 | rewriter, op.getLoc(), op, offsets, sizes, destinationTensors, dimension, |
113 | minSplitPoint, iterationSpace[dimension].offset, firstResults); |
114 | |
115 | // Need to pretend that the original op now takes as operands firstResults, |
116 | // otherwise tiling interface implementation will take the wrong value to |
117 | // produce data tiles. |
118 | rewriter.modifyOpInPlace(op, [&]() { |
119 | unsigned numTotalOperands = op->getNumOperands(); |
120 | unsigned numOutputOperands = firstResults.size(); |
121 | op->setOperands(numTotalOperands - numOutputOperands, numOutputOperands, |
122 | firstResults); |
123 | }); |
124 | |
125 | // Create the second part. |
126 | OpFoldResult totalOffset = affine::makeComposedFoldedAffineApply( |
127 | rewriter, op.getLoc(), d0 + d1, {offsets[dimension], minSplitPoint}); |
128 | SmallVector<Value> secondResults; |
129 | TilingInterface secondPart = |
130 | createSplitPart(rewriter, op.getLoc(), op, offsets, sizes, firstResults, |
131 | dimension, remainingSize, totalOffset, secondResults); |
132 | |
133 | // Propagate any errors in part creation. |
134 | if (!firstPart || !secondPart) |
135 | return {TilingInterface(), TilingInterface()}; |
136 | |
137 | // Replace the original op with the results of the two newly created ops. |
138 | rewriter.replaceOp(op, secondResults); |
139 | return {firstPart, secondPart}; |
140 | } |
141 | |