//===- Split.cpp - Structured op splitting --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Affine/IR/AffineOps.h"
10#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
11#include "mlir/Dialect/Utils/StaticValueUtils.h"
12#include "mlir/IR/AffineExpr.h"
13#include "mlir/IR/Attributes.h"
14#include "mlir/IR/BuiltinAttributes.h"
15#include "mlir/IR/OpDefinition.h"
16#include "mlir/Interfaces/TilingInterface.h"
17
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/ADT/SmallVector.h"
20
21using namespace mlir;
22using namespace mlir::linalg;
23
24/// Creates a part of the given `op` split along the iteration space `dimension`
25/// with the given `size` and an optional `offset` (default 0). Makes slices
26/// of operands, using the input operands of the original op and the output
27/// operands provided as `resultOperands`. Expects `offsets` and `sizes` to
28/// define the shape of the iteration space of the original op. Returns the
29/// split-out op as well as the output operand values updated with the partial
30/// results produced by this op through `results`.
31static TilingInterface
32createSplitPart(RewriterBase &b, Location loc, TilingInterface op,
33 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
34 ValueRange resultOperands, unsigned dimension,
35 OpFoldResult size, OpFoldResult offset,
36 SmallVectorImpl<Value> &results) {
37 // Iteration space of the current part.
38 SmallVector<OpFoldResult> sizesCopy = llvm::to_vector(Range&: sizes);
39 SmallVector<OpFoldResult> offsetsCopy = llvm::to_vector(Range&: offsets);
40 sizesCopy[dimension] = size;
41 offsetsCopy[dimension] = offset;
42
43 // Create the part as it it were a single tile.
44 FailureOr<TilingResult> tilingResult =
45 op.getTiledImplementation(b, offsetsCopy, sizesCopy);
46
47 // Insert the results back and populate the `results` list.
48 for (auto [index, result] : llvm::enumerate(tilingResult->tiledValues)) {
49 SmallVector<OpFoldResult> resultOffsets, resultSizes;
50 if (failed(op.getResultTilePosition(b, index, offsetsCopy, sizesCopy,
51 resultOffsets, resultSizes)))
52 return nullptr;
53 SmallVector<OpFoldResult> resultStrides(resultOffsets.size(),
54 b.getIndexAttr(1));
55 Value inserted = b.create<tensor::InsertSliceOp>(
56 loc, result, resultOperands[index], resultOffsets, resultSizes,
57 resultStrides);
58 results.push_back(inserted);
59 }
60 // TODO: this part can be generalized maybe to not expect a single op.
61 assert(tilingResult->tiledOps.size() == 1 &&
62 "expected split part to return a single tiled operation");
63 return cast<TilingInterface>(tilingResult->tiledOps[0]);
64}
65
66std::pair<TilingInterface, TilingInterface>
67linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension,
68 OpFoldResult splitPoint) {
69 // Compute the iteration space.
70 SmallVector<Range> iterationSpace = op.getIterationDomain(rewriter);
71
72 // Bail out on dimension overflow.
73 if (dimension >= iterationSpace.size())
74 return std::make_pair(op, TilingInterface());
75
76 SmallVector<OpFoldResult> offsets = llvm::to_vector(Range: llvm::map_range(
77 C&: iterationSpace, F: [](const Range &range) { return range.offset; }));
78 SmallVector<OpFoldResult> sizes = llvm::to_vector(Range: llvm::map_range(
79 C&: iterationSpace, F: [](const Range &range) { return range.size; }));
80
81 // Adjust the split point so that it doesn't overflow the size.
82 AffineExpr d0, d1, d2;
83 bindDims(ctx: rewriter.getContext(), exprs&: d0, exprs&: d1, exprs&: d2);
84 OpFoldResult minSplitPoint = affine::makeComposedFoldedAffineMin(
85 b&: rewriter, loc: op.getLoc(),
86 map: AffineMap::inferFromExprList(exprsList: ArrayRef<AffineExpr>{d0, d1 + d2},
87 context: rewriter.getContext())
88 .front(),
89 operands: {splitPoint, offsets[dimension], sizes[dimension]});
90
91 // Compute the size of the second part. Return early if the second part would
92 // have an empty iteration space.
93 OpFoldResult remainingSize = affine::makeComposedFoldedAffineApply(
94 rewriter, op.getLoc(), d0 + d1 - d2,
95 {iterationSpace[dimension].offset, iterationSpace[dimension].size,
96 minSplitPoint});
97 if (auto attr = llvm::dyn_cast_if_present<Attribute>(remainingSize)) {
98 if (cast<IntegerAttr>(attr).getValue().isZero())
99 return {op, TilingInterface()};
100 }
101
102 // Compute destination tensors.
103 SmallVector<Value> destinationTensors;
104 LogicalResult destStatus = tensor::getOrCreateDestinations(
105 b&: rewriter, loc: op.getLoc(), op: op, result&: destinationTensors);
106 (void)destStatus;
107 assert(succeeded(destStatus) && "failed to get destination tensors");
108
109 // Create the first part.
110 SmallVector<Value> firstResults;
111 TilingInterface firstPart = createSplitPart(
112 rewriter, op.getLoc(), op, offsets, sizes, destinationTensors, dimension,
113 minSplitPoint, iterationSpace[dimension].offset, firstResults);
114
115 // Need to pretend that the original op now takes as operands firstResults,
116 // otherwise tiling interface implementation will take the wrong value to
117 // produce data tiles.
118 rewriter.modifyOpInPlace(op, [&]() {
119 unsigned numTotalOperands = op->getNumOperands();
120 unsigned numOutputOperands = firstResults.size();
121 op->setOperands(numTotalOperands - numOutputOperands, numOutputOperands,
122 firstResults);
123 });
124
125 // Create the second part.
126 OpFoldResult totalOffset = affine::makeComposedFoldedAffineApply(
127 rewriter, op.getLoc(), d0 + d1, {offsets[dimension], minSplitPoint});
128 SmallVector<Value> secondResults;
129 TilingInterface secondPart =
130 createSplitPart(rewriter, op.getLoc(), op, offsets, sizes, firstResults,
131 dimension, remainingSize, totalOffset, secondResults);
132
133 // Propagate any errors in part creation.
134 if (!firstPart || !secondPart)
135 return {TilingInterface(), TilingInterface()};
136
137 // Replace the original op with the results of the two newly created ops.
138 rewriter.replaceOp(op, secondResults);
139 return {firstPart, secondPart};
140}
141

// Source: mlir/lib/Dialect/Linalg/Transforms/Split.cpp