//===- Padding.cpp - Padding of Linalg ops --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

#define DEBUG_TYPE "linalg-padding"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")

/// Compute the padded shape of the given operand. The operand is padded to a
/// static bounding box according to the specified padding options.
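///
/// Illustrative example (hypothetical shapes, not taken from a test): for an
/// operand of type tensor<?x32xf32> whose dynamic dimension maps to a padding
/// dimension with padToMultipleOf = 16 and has a constant upper bound of 250,
/// the computed padded shape is {256, 32}.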
static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
                                        OpOperand *opOperand,
                                        const LinalgPaddingOptions &options,
                                        SmallVector<int64_t> &paddedShape,
                                        bool &alreadyHasRequestedShape) {
  AffineMap indexingMap = opToPad.getMatchingIndexingMap(opOperand);
  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);

  // Collect the shape dimensions that are a function of "paddingDimensions",
  // along with the multiple that they should be padded to ("1" if none).
  alreadyHasRequestedShape = true;
  DenseMap<int64_t, int64_t> shapeDimToMultiple;
  for (const auto &dimEn : enumerate(options.paddingDimensions)) {
    for (const auto &en : enumerate(indexingMap.getResults())) {
      if (en.value().isFunctionOfDim(dimEn.value())) {
        int64_t dimSize = shape[en.index()];
        if (options.padToMultipleOf.has_value()) {
          shapeDimToMultiple[en.index()] =
              (*options.padToMultipleOf)[dimEn.index()];
        } else {
          shapeDimToMultiple[en.index()] = 1;
        }
        if (ShapedType::isDynamic(dimSize)) {
          alreadyHasRequestedShape = false;
        } else if (dimSize % shapeDimToMultiple[en.index()] != 0) {
          alreadyHasRequestedShape = false;
        }
      }
    }
  }

  // Helper function to round a number up to a given multiple.
  auto ceil = [](int64_t val, int64_t multiple) {
    return ((val + multiple - 1) / multiple) * multiple;
  };
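  // For example: ceil(10, 16) == 16, ceil(32, 16) == 32, ceil(33, 16) == 48.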

  // Upper bound the sizes to obtain a static bounding box.
  paddedShape.assign(shape.begin(), shape.end());
  for (int64_t i = 0, e = shape.size(); i < e; ++i) {
    LLVM_DEBUG(DBGS() << "--compute padded size for dim " << i << "\n");
    // Skip dimensions that do not require padding.
    if (!shapeDimToMultiple.contains(i)) {
      LLVM_DEBUG(DBGS() << "----dim does not require padding, SKIP\n");
      continue;
    }
    // Otherwise, try to compute a constant upper bound for the size value.
    FailureOr<int64_t> upperBound =
        ValueBoundsConstraintSet::computeConstantBound(
            presburger::BoundType::UB,
            {opOperand->get(),
             /*dim=*/i},
            /*stopCondition=*/nullptr, /*closedUB=*/true);
    if (failed(upperBound)) {
      LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding");
      return failure();
    }
    paddedShape[i] = ceil(*upperBound, shapeDimToMultiple[i]);
    LLVM_DEBUG(DBGS() << "----new dim size: " << paddedShape[i] << "\n");
  }

  return success();
}

/// Pad the `opOperand` in the "paddingDimensions" using the padding value and
/// the nofold flag found in "paddingValues" and "packPaddings", respectively.
///
/// Exit early and return the `opOperand` value if it already has the requested
/// shape, i.e.:
/// - static shape
/// - nofold is not set
/// - dim sizes are multiples of "padToMultipleOf"
///
/// Otherwise, try to pad the shape dimensions that match the iterator
/// dimensions "paddingDimensions" and return the tensor::PadOp result if
/// padding succeeds or failure otherwise.
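///
/// Illustrative example (hypothetical types and sizes, not taken from a test):
/// for an operand of type tensor<250x32xf32> whose padded shape is computed as
/// {256, 32} and a zero padding value, the returned value is the
/// tensor<256x32xf32> result of the pad built by makeComposedPadHighOp.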
static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
    RewriterBase &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const LinalgPaddingOptions &options) {
  assert(
      (!options.padToMultipleOf.has_value() ||
       options.padToMultipleOf->size() == options.paddingDimensions.size()) &&
      "invalid number of elements in padToMultipleOf");

  // Compute padded shape.
  SmallVector<int64_t> paddedShape;
  bool alreadyHasRequestedShape = false;
  if (failed(computePaddedShape(opToPad, opOperand, options, paddedShape,
                                alreadyHasRequestedShape)))
    return rewriter.notifyMatchFailure(opToPad,
                                       "--failed to compute padded shape");

  // Return the unpadded operand if padding to a static shape is not needed and
  // if the nofold flag is not set.
  bool nofold = opOperand->getOperandNumber() < options.packPaddings.size()
                    ? options.packPaddings[opOperand->getOperandNumber()]
                    : false;
  if (!nofold && alreadyHasRequestedShape)
    return opOperand->get();

  // Fail if `paddingValues` specifies no padding value.
  if (opOperand->getOperandNumber() >= options.paddingValues.size()) {
    return rewriter.notifyMatchFailure(opToPad, "--no padding value specified");
  }
  Attribute paddingAttr = options.paddingValues[opOperand->getOperandNumber()];

  Value paddingValue;
  if (auto complexTy = dyn_cast<ComplexType>(
          getElementTypeOrSelf(opOperand->get().getType()))) {
    auto complexAttr = cast<ArrayAttr>(paddingAttr);
    paddingValue = rewriter.create<complex::ConstantOp>(opToPad.getLoc(),
                                                        complexTy, complexAttr);
  } else {
    paddingValue = rewriter.create<arith::ConstantOp>(
        opToPad.getLoc(), cast<TypedAttr>(paddingAttr));
  }

  // Pad the operand to the bounding box defined by `paddedShape`.
  auto paddedTensorType = RankedTensorType::get(
      paddedShape, getElementTypeOrSelf(opOperand->get()));
  LLVM_DEBUG(DBGS() << "--SUCCESS, makeComposedPadHighOp with type: "
                    << paddedTensorType);
  return makeComposedPadHighOp(rewriter, opToPad->getLoc(), paddedTensorType,
                               opOperand->get(), paddingValue, nofold);
}

LogicalResult
linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
                          const LinalgPaddingOptions &constOptions,
                          LinalgOp &paddedOp, SmallVector<Value> &replacements,
                          SmallVector<tensor::PadOp> &padOps) {
  LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
  Location loc = opToPad->getLoc();

  LinalgPaddingOptions options(constOptions);
  // Allow inference of pad values if they are not explicitly specified.
  // TODO: be mindful about the value depending on the actual operation.
  if (options.paddingValues.empty()) {
    SmallVector<Type> types(opToPad->getOperandTypes());
    llvm::append_range(types, opToPad->getResultTypes());
    for (Type t : types) {
      options.paddingValues.push_back(
          rewriter.getZeroAttr(getElementTypeOrSelf(t)));
    }
  }
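  // For example, operands and results with an f32 element type get an inferred
  // padding value of 0.0 : f32 here.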

  // TODO: there are cases where we may still want to pad to larger sizes.
  if (!opToPad.hasPureTensorSemantics())
    return rewriter.notifyMatchFailure(opToPad,
                                       "expected operation on tensors");

  OpBuilder::InsertionGuard g(rewriter);
  // Set IP after op because we also take the dims of the original output.
  rewriter.setInsertionPointAfter(opToPad);

  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad->getNumOperands());
  for (OpOperand &opOperand : opToPad->getOpOperands()) {
    FailureOr<Value> paddedOperand = padOperandToSmallestStaticBoundingBox(
        rewriter, opToPad, &opOperand, options);
    // Exit if `paddingDimensions` cannot be bounded statically.
    if (failed(paddedOperand)) {
      LLVM_DEBUG(DBGS() << "--operand cannot be bound statically : "
                        << opOperand.get() << " -> FAIL\n");
      return rewriter.notifyMatchFailure(opToPad,
                                         "operand cannot be bound statically");
    }
    newOperands.push_back(*paddedOperand);
    if (auto padOp = paddedOperand->getDefiningOp<tensor::PadOp>())
      padOps.push_back(padOp);
  }

  ReifiedRankedShapedTypeDims reifiedResultShapes;
  if (failed(reifyResultShapes(rewriter, opToPad, reifiedResultShapes))) {
    LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n");
    return rewriter.notifyMatchFailure(opToPad,
                                       "failed to reify result shapes");
  }
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumDpsInits()).getTypes();
  // clone **should** properly notify the rewriter.
  paddedOp = clone(rewriter, opToPad, resultTensorTypes, newOperands);
  LLVM_DEBUG(DBGS() << "--cloned padded op: " << paddedOp << "\n");

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubtensorResults;
  paddedSubtensorResults.reserve(opToPad->getNumResults());
  for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    paddedSubtensorResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
        strides));
  }
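  // E.g., continuing the illustrative shapes above, a padded
  // tensor<256x32xf32> result is sliced back to the original result sizes
  // recovered by reifyResultShapes (possibly dynamic, e.g. tensor<?x32xf32>).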

  if (options.copyBackOp == LinalgPaddingOptions::CopyBackOp::None) {
    replacements = std::move(paddedSubtensorResults);
    return success();
  }

  // Copy back unpadded results to the original destination (i.e., inits of the
  // linalg op), so that the destination buffer of the computation does not
  // change. If the padding folds away, this will materialize as a memcpy
  // between two identical buffers, which will then also fold away.
  assert(static_cast<int64_t>(paddedSubtensorResults.size()) ==
             opToPad.getNumDpsInits() &&
         "expected matching number of results");
  for (auto it :
       llvm::zip(paddedSubtensorResults, opToPad.getDpsInitsMutable())) {
    if (options.copyBackOp == LinalgPaddingOptions::CopyBackOp::LinalgCopy) {
      replacements.push_back(rewriter
                                 .create<linalg::CopyOp>(loc, std::get<0>(it),
                                                         std::get<1>(it).get())
                                 .getResult(0));
    } else if (options.copyBackOp ==
               LinalgPaddingOptions::CopyBackOp::
                   BufferizationMaterializeInDestination) {
      replacements.push_back(
          rewriter
              .create<bufferization::MaterializeInDestinationOp>(
                  loc, std::get<0>(it), std::get<1>(it).get())
              ->getResult(0));
    } else {
      llvm_unreachable("unsupported copy back op");
    }
  }
  return success();
}
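
// Illustrative usage sketch of rewriteAsPaddedOp from a hypothetical caller.
// All option values below are made up; only fields already used in this file
// are set:
//
//   LinalgPaddingOptions options;
//   options.paddingDimensions = {0, 1};
//   options.padToMultipleOf = SmallVector<int64_t>{16, 16};
//   options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;
//   LinalgOp paddedOp;
//   SmallVector<Value> replacements;
//   SmallVector<tensor::PadOp> padOps;
//   if (failed(rewriteAsPaddedOp(rewriter, linalgOp, options, paddedOp,
//                                replacements, padOps)))
//     return failure();
//   rewriter.replaceOp(linalgOp, replacements);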

FailureOr<LinalgOp>
mlir::linalg::padAndHoistLinalgOp(RewriterBase &rewriter, LinalgOp linalgOp,
                                  const LinalgPaddingOptions &options) {
  assert(options.copyBackOp == LinalgPaddingOptions::CopyBackOp::None &&
         "invalid options");

  if (!linalgOp.hasPureTensorSemantics())
    return rewriter.notifyMatchFailure(
        linalgOp, "only applies to Linalg ops with tensor semantics");

  // Pad the operation.
  LinalgOp paddedOp;
  SmallVector<Value> newResults;
  SmallVector<tensor::PadOp> padOps;
  if (failed(rewriteAsPaddedOp(rewriter, linalgOp, options, paddedOp,
                               newResults, padOps)))
    return rewriter.notifyMatchFailure(linalgOp,
                                       "failed to rewrite as a padded op");

  // Hoist the padding.
  for (const auto &en : enumerate(options.hoistPaddings)) {
    if (static_cast<int64_t>(en.index()) >= paddedOp->getNumOperands())
      break;
    OpOperand &opOperand = paddedOp->getOpOperand(en.index());
    auto padOp = opOperand.get().getDefiningOp<tensor::PadOp>();
    if (!padOp || en.value() == 0) {
      (void)rewriter.notifyMatchFailure(linalgOp, "not a tensor.pad -- skip");
      continue;
    }

    // Fail hoisting if the operand shape is not fully static.
    if (llvm::any_of(paddedOp.getShape(&opOperand), ShapedType::isDynamic)) {
      (void)rewriter.notifyMatchFailure(linalgOp,
                                        "non static padding shape -- skip");
      continue;
    }

    tensor::PadOp hoistedOp;
    SmallVector<GenericOp> transposeOps;
    SmallVector<int64_t> transposeVector =
        en.index() < options.transposePaddings.size()
            ? options.transposePaddings[en.index()]
            : SmallVector<int64_t>{};

    FailureOr<Value> newResult = hoistPaddingOnTensors(
        padOp, en.value(), transposeVector, hoistedOp, transposeOps);
    if (failed(newResult)) {
      (void)rewriter.notifyMatchFailure(linalgOp,
                                        "failed to apply hoistPadding");
      continue;
    }
    rewriter.replaceOp(padOp, *newResult);
  }

  // Replace the original operation to pad.
  rewriter.replaceOp(linalgOp, newResults);

  return paddedOp;
}
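
// Illustrative hoisting configuration (hypothetical values): with
// options.hoistPaddings = {1, 1} (and copyBackOp left as CopyBackOp::None, as
// required by the assert above), padAndHoistLinalgOp attempts to hoist the
// tensor.pad feeding each of the first two operands out of one enclosing loop,
// transposing the padded tensor according to options.transposePaddings when
// provided.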