//===- Padding.cpp - Padding of Linalg ops --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
10
11#include "mlir/Dialect/Affine/IR/AffineOps.h"
12#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13#include "mlir/Dialect/Complex/IR/Complex.h"
14#include "mlir/Dialect/Linalg/IR/Linalg.h"
15#include "mlir/Dialect/Tensor/IR/Tensor.h"
16#include "mlir/Interfaces/ValueBoundsOpInterface.h"
17
18#define DEBUG_TYPE "linalg-padding"
19
20using namespace mlir;
21using namespace mlir::linalg;
22
23#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
24#define DBGSNL() (llvm::dbgs() << "\n")
25
26namespace {
27/// Helper class for storing padding information.
28struct PaddingInfo {
29 PaddingInfo(int64_t padToMultipleOf = 1, OpFoldResult size = {})
30 : padToMultipleOf(padToMultipleOf), size(size) {}
31 /// Pad the tensor to a multiple of.
32 int64_t padToMultipleOf = 1;
33 /// The size used for padding.
34 OpFoldResult size = {};
35};
36
37/// Helper class for storing and computing the padded shape.
38struct PaddedShape {
39 /// Initializes the shape information and on success it returns whether the
40 /// shape of the operand will change. Returns failure if the operand cannot be
41 /// padded.
42 FailureOr<bool> initialize(linalg::LinalgOp opToPad, OpOperand *opOperand,
43 const LinalgPaddingOptions &options);
44
45 /// Computs the padded shape.
46 void computePadding(OpBuilder &builder, Value operand);
47
48 /// Returns the new tensor type.
49 RankedTensorType getType(Type elemTy) {
50 return RankedTensorType::get(shape, elementType: elemTy);
51 }
52
53 SmallVector<Value> dynDims;
54
55private:
56 SmallVector<int64_t> shape;
57 DenseMap<int64_t, PaddingInfo> dimToInfo;
58};
59} // namespace
60
61FailureOr<bool> PaddedShape::initialize(linalg::LinalgOp opToPad,
62 OpOperand *opOperand,
63 const LinalgPaddingOptions &options) {
64 AffineMap indexingMap = opToPad.getMatchingIndexingMap(opOperand);
65
66 // Initialize the padded shape.
67 llvm::append_range(C&: shape, R: opToPad.getShape(opOperand));
68
69 // Collect the shape dimensions that are a function of "paddingDimensions",
70 // along with the multiple that they should be padded to ("1" if none).
71 bool alreadyHasRequestedShape = true;
72 for (const auto &dimEn : enumerate(First: options.paddingDimensions)) {
73 for (const auto &en : enumerate(First: indexingMap.getResults())) {
74 if (en.value().isFunctionOfDim(position: dimEn.value())) {
75 PaddingInfo paddingInfo;
76 int64_t dimSize = shape[en.index()];
77 if (options.padToMultipleOf.has_value()) {
78 paddingInfo.padToMultipleOf =
79 (*options.padToMultipleOf)[dimEn.index()];
80 } else {
81 paddingInfo.padToMultipleOf = 1;
82 }
83
84 // Check if the user provided a size in the options.
85 paddingInfo.size =
86 options.getSizeToPadTo(operandIndex: opOperand->getOperandNumber(), dimIndex: en.index());
87
88 // Set the padding info.
89 dimToInfo[en.index()] = paddingInfo;
90 if (ShapedType::isDynamic(dValue: dimSize) ||
91 dimSize % paddingInfo.padToMultipleOf != 0 ||
92 !paddingInfo.size.isNull()) {
93 alreadyHasRequestedShape = false;
94 }
95 }
96 }
97 }
98
99 // Upper bound the sizes to obtain a static bounding box.
100 for (int64_t i = 0, e = shape.size(); i < e; ++i) {
101 LLVM_DEBUG(DBGS() << "--computing un-padded size for dim " << i << "\n");
102 // Skip dimensions that do not require padding.
103 if (!dimToInfo.contains(Val: i)) {
104 LLVM_DEBUG(DBGS() << "----dim does not require padding, SKIP\n");
105 continue;
106 }
107 PaddingInfo &info = dimToInfo[i];
108 if (info.size) {
109 LLVM_DEBUG(DBGS() << "----the user provided the size: " << info.size
110 << "\n");
111 continue;
112 }
113 // Otherwise, try to compute a constant upper bound for the size value.
114 FailureOr<int64_t> upperBound =
115 ValueBoundsConstraintSet::computeConstantBound(
116 type: presburger::BoundType::UB,
117 var: {opOperand->get(),
118 /*dim=*/i},
119 /*stopCondition=*/nullptr, /*closedUB=*/true);
120 if (failed(Result: upperBound)) {
121 LLVM_DEBUG(
122 DBGS() << "----could not compute a bounding box for padding\n");
123 return failure();
124 }
125 info.size =
126 IntegerAttr::get(type: IndexType::get(context: opToPad.getContext()), value: *upperBound);
127 LLVM_DEBUG(DBGS() << "----new un-padded size: " << info.size << "\n");
128 }
129 return alreadyHasRequestedShape;
130}
131
132void PaddedShape::computePadding(OpBuilder &builder, Value operand) {
133 Location loc = operand.getLoc();
134 AffineExpr sizeSym = builder.getAffineSymbolExpr(position: 0);
135
136 // Compute the padding for each dimension.
137 for (auto &&[i, dim] : llvm::enumerate(First&: shape)) {
138 LLVM_DEBUG(DBGS() << "--computing padded size for dim " << i << "\n");
139
140 // Get the padding info or default info for the shape dimension.
141 PaddingInfo paddingInfo = dimToInfo.lookup(Val: i);
142
143 // Skip dimensions that do not require padding.
144 if (paddingInfo.size.isNull()) {
145 LLVM_DEBUG(DBGS() << "----dim does not require padding, SKIP\n");
146
147 // We still need to push the size as `makeComposedPadHighOp` expects a
148 // range with all the dynamic sizes, whether they're being padded or not.
149 if (ShapedType::isDynamic(dValue: dim)) {
150 dynDims.push_back(
151 Elt: cast<Value>(Val: tensor::getMixedSize(builder, loc, value: operand, dim: i)));
152 }
153 continue;
154 }
155
156 // Compute the padded size to be a multiple of `padToMultipleOf`.
157 AffineExpr szExpr = (sizeSym).ceilDiv(v: paddingInfo.padToMultipleOf) *
158 paddingInfo.padToMultipleOf;
159 OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
160 b&: builder, loc, expr: szExpr, operands: paddingInfo.size);
161 assert(paddedSize && "invalid arguments to affine apply");
162
163 if (auto cstSzAttr = dyn_cast<Attribute>(Val&: paddedSize)) {
164 // Update the shape as the size is static.
165 dim = cast<IntegerAttr>(Val&: cstSzAttr).getValue().getZExtValue();
166 } else {
167 // Add a dynamic dimension.
168 dim = ShapedType::kDynamic;
169 dynDims.push_back(Elt: cast<Value>(Val&: paddedSize));
170 }
171 LLVM_DEBUG(DBGS() << "----new dim size: " << paddedSize << "\n");
172 }
173}
174
175/// Pad the `opOperand` in the "paddingDimensions" using the padding value and
176/// the nofold flag found in "paddingValues" and "nofoldFlags", respectively.
177///
178/// Exit early and return the `opOperand` value if it already has the requested
179/// shape. i.e.:
180/// - static shape
181/// - nofold is not set
182/// - dim sizes are multiples of "padToMultipleOf"
183///
184/// Otherwise, try to pad the shape dimensions that match the iterator
185/// dimensions "paddingDimensions" and return the tensor::PadOp result if
186/// padding succeeds or failure otherwise.
187static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
188 RewriterBase &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
189 const LinalgPaddingOptions &options) {
190 assert(
191 (!options.padToMultipleOf.has_value() ||
192 options.padToMultipleOf->size() == options.paddingDimensions.size()) &&
193 "invalid number of elements in padToMultipleOf");
194
195 // Initialize the padded shape and get whether it requires padding.
196 PaddedShape shape;
197 FailureOr<bool> alreadyHasRequestedShape =
198 shape.initialize(opToPad, opOperand, options);
199 if (failed(Result: alreadyHasRequestedShape)) {
200 return rewriter.notifyMatchFailure(arg&: opToPad,
201 msg: "--failed to compute padded shape");
202 }
203
204 // Return the un-padded operand if padding to a static shape is not needed and
205 // if the nofold flag is not set.
206 bool nofold = opOperand->getOperandNumber() < options.nofoldFlags.size()
207 ? bool(options.nofoldFlags[opOperand->getOperandNumber()])
208 : false;
209 if (!nofold && *alreadyHasRequestedShape)
210 return opOperand->get();
211
212 // Fail if `paddingValues` specifies no padding value.
213 if (opOperand->getOperandNumber() >= options.paddingValues.size()) {
214 return rewriter.notifyMatchFailure(arg&: opToPad, msg: "--no padding value specified");
215 }
216 Attribute paddingAttr = options.paddingValues[opOperand->getOperandNumber()];
217
218 Value paddingValue;
219 if (auto complexTy = dyn_cast<ComplexType>(
220 Val: getElementTypeOrSelf(type: opOperand->get().getType()))) {
221 auto complexAttr = cast<ArrayAttr>(Val&: paddingAttr);
222 paddingValue = rewriter.create<complex::ConstantOp>(location: opToPad.getLoc(),
223 args&: complexTy, args&: complexAttr);
224 } else {
225 paddingValue = rewriter.create<arith::ConstantOp>(
226 location: opToPad.getLoc(), args: cast<TypedAttr>(Val&: paddingAttr));
227 }
228
229 // Computes the padded shape.
230 if (!*alreadyHasRequestedShape)
231 shape.computePadding(builder&: rewriter, operand: opOperand->get());
232
233 // Pad the operand to the bounding box defined by `paddedShape`.
234 RankedTensorType paddedTensorType =
235 shape.getType(elemTy: getElementTypeOrSelf(val: opOperand->get()));
236 LLVM_DEBUG(DBGS() << "--SUCCESS, makeComposedPadHighOp with type: "
237 << paddedTensorType);
238 return makeComposedPadHighOp(b&: rewriter, loc: opToPad->getLoc(), type: paddedTensorType,
239 source: opOperand->get(), padding: paddingValue, nofold,
240 typeDynDims: shape.dynDims);
241}
242
243LogicalResult
244linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
245 const LinalgPaddingOptions &constOptions,
246 LinalgOp &paddedOp, SmallVector<Value> &replacements,
247 SmallVector<tensor::PadOp> &padOps) {
248 LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
249 Location loc = opToPad->getLoc();
250
251 LinalgPaddingOptions options(constOptions);
252 // Allow inference of pad values if they are not explicitly specified.
253 // TODO: be mindful about the value depending on the actual operation.
254 if (options.paddingValues.empty()) {
255 SmallVector<Type> types(opToPad->getOperandTypes());
256 llvm::append_range(C&: types, R: opToPad->getResultTypes());
257 for (Type t : types) {
258 options.paddingValues.push_back(
259 Elt: rewriter.getZeroAttr(type: getElementTypeOrSelf(type: t)));
260 }
261 }
262
263 // TODO: there are cases where we may still want to pad to larger sizes.
264 if (!opToPad.hasPureTensorSemantics())
265 return rewriter.notifyMatchFailure(arg&: opToPad,
266 msg: "expected operation on tensors");
267
268 OpBuilder::InsertionGuard g(rewriter);
269 // Set IP after op because we also take the dims of the original output.
270 rewriter.setInsertionPointAfter(opToPad);
271
272 // Make a copy of the shaped operands and update it.
273 SmallVector<Value> newOperands;
274 newOperands.reserve(N: opToPad->getNumOperands());
275 for (OpOperand &opOperand : opToPad->getOpOperands()) {
276 FailureOr<Value> paddedOperand = padOperandToSmallestStaticBoundingBox(
277 rewriter, opToPad, opOperand: &opOperand, options);
278 // Exit if `paddingDimensions` cannot be bounded statically.
279 if (failed(Result: paddedOperand)) {
280 LLVM_DEBUG(DBGS() << "--operand cannot be bound statically : "
281 << opOperand.get() << " -> FAIL\n");
282 return rewriter.notifyMatchFailure(arg&: opToPad,
283 msg: "operand cannot be bound statically");
284 }
285 newOperands.push_back(Elt: *paddedOperand);
286 if (auto padOp = paddedOperand->getDefiningOp<tensor::PadOp>())
287 padOps.push_back(Elt: padOp);
288 }
289
290 ReifiedRankedShapedTypeDims reifiedResultShapes;
291 if (failed(Result: reifyResultShapes(b&: rewriter, op: opToPad, reifiedReturnShapes&: reifiedResultShapes))) {
292 LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n");
293 return rewriter.notifyMatchFailure(arg&: opToPad,
294 msg: "failed to reify result shapes");
295 }
296 assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
297 "expected same number of results");
298
299 // Clone `opToPad` to operate on the statically padded shapes.
300 auto resultTensorTypes =
301 ValueRange(newOperands).take_back(n: opToPad.getNumDpsInits()).getTypes();
302 // clone **should** properly notify the rewriter.
303 paddedOp = clone(b&: rewriter, op: opToPad, newResultTypes: resultTensorTypes, newOperands);
304 LLVM_DEBUG(DBGS() << "--cloned padded op: " << paddedOp << "\n");
305
306 // Recover the slice out of the new static results. This keeps the original
307 // linalg op around because it uses the dims of the original results.
308 SmallVector<Value> paddedSubtensorResults;
309 paddedSubtensorResults.reserve(N: opToPad->getNumResults());
310 for (const auto &en : llvm::enumerate(First: paddedOp->getResults())) {
311 Value paddedResult = en.value();
312 int64_t resultNumber = en.index();
313 int64_t rank = cast<RankedTensorType>(Val: paddedResult.getType()).getRank();
314 SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(value: 0));
315 SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(value: 1));
316 paddedSubtensorResults.push_back(Elt: rewriter.create<tensor::ExtractSliceOp>(
317 location: loc, args&: paddedResult, args&: offsets, args&: reifiedResultShapes[resultNumber],
318 args&: strides));
319 }
320
321 if (options.copyBackOp == LinalgPaddingOptions::CopyBackOp::None) {
322 replacements = std::move(paddedSubtensorResults);
323 return success();
324 }
325
326 // Copy back unpadded results to the original destination (i.e., inits of the
327 // linalg op), so that the destination buffer of the computation does not
328 // change. If the padding folds away, this will materialize as a memcpy
329 // between two identical buffers, which will then also fold away.
330 assert(static_cast<int64_t>(paddedSubtensorResults.size()) ==
331 opToPad.getNumDpsInits() &&
332 "expected matching number of results");
333 for (auto it :
334 llvm::zip(t&: paddedSubtensorResults, u: opToPad.getDpsInitsMutable())) {
335 if (options.copyBackOp == LinalgPaddingOptions::CopyBackOp::LinalgCopy) {
336 replacements.push_back(Elt: rewriter
337 .create<linalg::CopyOp>(location: loc, args&: std::get<0>(t&: it),
338 args: std::get<1>(t&: it).get())
339 .getResult(i: 0));
340 } else if (options.copyBackOp ==
341 LinalgPaddingOptions::CopyBackOp::
342 BufferizationMaterializeInDestination) {
343 replacements.push_back(
344 Elt: rewriter
345 .create<bufferization::MaterializeInDestinationOp>(
346 location: loc, args&: std::get<0>(t&: it), args: std::get<1>(t&: it).get())
347 ->getResult(idx: 0));
348 } else {
349 llvm_unreachable("unsupported copy back op");
350 }
351 }
352 return success();
353}
354
355FailureOr<LinalgOp>
356mlir::linalg::padAndHoistLinalgOp(RewriterBase &rewriter, LinalgOp linalgOp,
357 const LinalgPaddingOptions &options) {
358 assert(options.copyBackOp == LinalgPaddingOptions::CopyBackOp::None &&
359 "invalid options");
360
361 if (!linalgOp.hasPureTensorSemantics())
362 return rewriter.notifyMatchFailure(
363 arg&: linalgOp, msg: "only applies to Linalg ops with tensor semantics");
364
365 // Pad the operation.
366 LinalgOp paddedOp;
367 SmallVector<Value> newResults;
368 SmallVector<tensor::PadOp> padOps;
369 if (failed(Result: rewriteAsPaddedOp(rewriter, opToPad: linalgOp, constOptions: options, paddedOp,
370 replacements&: newResults, padOps)))
371 return rewriter.notifyMatchFailure(arg&: linalgOp,
372 msg: "failed to rewrite as a padded op");
373
374 // Hoist the padding.
375 for (const auto &en : enumerate(First: options.hoistPaddings)) {
376 if (static_cast<int64_t>(en.index()) >= paddedOp->getNumOperands())
377 break;
378 OpOperand &opOperand = paddedOp->getOpOperand(idx: en.index());
379 auto padOp = opOperand.get().getDefiningOp<tensor::PadOp>();
380 if (!padOp || en.value() == 0) {
381 (void)rewriter.notifyMatchFailure(arg&: linalgOp, msg: "not a tensor.pad -- skip");
382 continue;
383 }
384
385 // Fail hoisting if the operand shape is not fully static.
386 if (llvm::any_of(Range: paddedOp.getShape(opOperand: &opOperand), P: ShapedType::isDynamic)) {
387 (void)rewriter.notifyMatchFailure(arg&: linalgOp,
388 msg: "non static padding shape -- skip");
389 continue;
390 }
391
392 tensor::PadOp hoistedOp;
393 SmallVector<TransposeOp> transposeOps;
394 SmallVector<int64_t> transposeVector =
395 en.index() < options.transposePaddings.size()
396 ? options.transposePaddings[en.index()]
397 : SmallVector<int64_t>{};
398
399 FailureOr<Value> newResult = hoistPaddingOnTensors(
400 opToHoist: padOp, numLoops: en.value(), transposeVector, hoistedOp, transposeOps);
401 if (failed(Result: newResult)) {
402 (void)rewriter.notifyMatchFailure(arg&: linalgOp,
403 msg: "failed to apply hoistPadding");
404 continue;
405 }
406 rewriter.replaceOp(op: padOp, newValues: *newResult);
407 }
408
409 // Replace the original operation to pad.
410 rewriter.replaceOp(op: linalgOp, newValues: newResults);
411
412 return paddedOp;
413}
414

source code of mlir/lib/Dialect/Linalg/Transforms/Padding.cpp