//===- Padding.cpp - Padding of Linalg ops --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

#define DEBUG_TYPE "linalg-padding"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")

/// Compute the padded shape of the given operand. The operand is padded to a
/// static bounding box according to the specified padding options.
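/// For example (illustrative numbers only): a shape dimension of size 10 that
/// maps to a padded iteration dimension with a pad-to-multiple of 8 is padded
/// to 16, and a dynamic dimension is padded to the smallest constant upper
/// bound that is a multiple of 8, if such a bound can be computed.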
static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
                                        OpOperand *opOperand,
                                        const LinalgPaddingOptions &options,
                                        SmallVector<int64_t> &paddedShape,
                                        bool &alreadyHasRequestedShape) {
  AffineMap indexingMap = opToPad.getMatchingIndexingMap(opOperand);
  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);

  // Collect the shape dimensions that are a function of "paddingDimensions",
  // along with the multiple that they should be padded to ("1" if none).
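  // E.g. (illustrative) for a matmul padded in iteration dims {0, 2} (m and
  // k): the LHS, indexed by (m, k), gets both of its shape dims recorded; the
  // RHS, indexed by (k, n), and the output, indexed by (m, n), each get only
  // their first shape dim recorded.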
  alreadyHasRequestedShape = true;
  DenseMap<int64_t, int64_t> shapeDimToMultiple;
  for (const auto &dimEn : enumerate(options.paddingDimensions)) {
    for (const auto &en : enumerate(indexingMap.getResults())) {
      if (en.value().isFunctionOfDim(dimEn.value())) {
        int64_t dimSize = shape[en.index()];
        if (options.padToMultipleOf.has_value()) {
          shapeDimToMultiple[en.index()] =
              (*options.padToMultipleOf)[dimEn.index()];
        } else {
          shapeDimToMultiple[en.index()] = 1;
        }
        if (ShapedType::isDynamic(dimSize)) {
          alreadyHasRequestedShape = false;
        } else if (dimSize % shapeDimToMultiple[en.index()] != 0) {
          alreadyHasRequestedShape = false;
        }
      }
    }
  }

  // Helper function to round a number up to a given multiple.
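  // E.g. (illustrative): ceil(10, 8) == 16 and ceil(16, 8) == 16.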
  auto ceil = [](int64_t val, int64_t multiple) {
    return ((val + multiple - 1) / multiple) * multiple;
  };

  // Upper bound the sizes to obtain a static bounding box.
  paddedShape.assign(shape.begin(), shape.end());
  for (int64_t i = 0, e = shape.size(); i < e; ++i) {
    LLVM_DEBUG(DBGS() << "--compute padded size for dim " << i << "\n");
    // Skip dimensions that do not require padding.
    if (!shapeDimToMultiple.contains(i)) {
      LLVM_DEBUG(DBGS() << "----dim does not require padding, SKIP\n");
      continue;
    }
    // Otherwise, try to compute a constant upper bound for the size value.
    FailureOr<int64_t> upperBound =
        ValueBoundsConstraintSet::computeConstantBound(
            presburger::BoundType::UB,
            {opOperand->get(),
             /*dim=*/i},
            /*stopCondition=*/nullptr, /*closedUB=*/true);
    if (failed(upperBound)) {
      LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding");
      return failure();
    }
    paddedShape[i] = ceil(*upperBound, shapeDimToMultiple[i]);
    LLVM_DEBUG(DBGS() << "----new dim size: " << paddedShape[i] << "\n");
  }

  return success();
}

/// Pad the `opOperand` in the "paddingDimensions" using the padding value and
/// the nofold flag found in "paddingValues" and "packPaddings", respectively.
///
/// Exit early and return the `opOperand` value if it already has the requested
/// shape, i.e.:
/// - static shape
/// - nofold is not set
/// - dim sizes are multiples of "padToMultipleOf"
///
/// Otherwise, try to pad the shape dimensions that match the iterator
/// dimensions "paddingDimensions" and return the tensor::PadOp result if
/// padding succeeds or failure otherwise.
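///
/// A sketch of the kind of IR this creates for a static 10x12 operand padded
/// to multiples of 16 (illustrative only; the padding value and the high pad
/// sizes depend on the options and on the computed upper bounds):
///
///   %cst = arith.constant 0.0 : f32
///   %padded = tensor.pad %operand low[0, 0] high[6, 4] {
///   ^bb0(%i: index, %j: index):
///     tensor.yield %cst : f32
///   } : tensor<10x12xf32> to tensor<16x16xf32>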
static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
    RewriterBase &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const LinalgPaddingOptions &options) {
  assert(
      (!options.padToMultipleOf.has_value() ||
       options.padToMultipleOf->size() == options.paddingDimensions.size()) &&
      "invalid number of elements in padToMultipleOf");

  // Compute padded shape.
  SmallVector<int64_t> paddedShape;
  bool alreadyHasRequestedShape = false;
  if (failed(computePaddedShape(opToPad, opOperand, options, paddedShape,
                                alreadyHasRequestedShape)))
    return rewriter.notifyMatchFailure(opToPad,
                                       "--failed to compute padded shape");

  // Return the unpadded operand if padding to a static shape is not needed and
  // if the nofold flag is not set.
  bool nofold = opOperand->getOperandNumber() < options.packPaddings.size()
                    ? options.packPaddings[opOperand->getOperandNumber()]
                    : false;
  if (!nofold && alreadyHasRequestedShape)
    return opOperand->get();

  // Fail if `paddingValues` specifies no padding value.
  if (opOperand->getOperandNumber() >= options.paddingValues.size()) {
    return rewriter.notifyMatchFailure(opToPad, "--no padding value specified");
  }
  Attribute paddingAttr = options.paddingValues[opOperand->getOperandNumber()];

  Value paddingValue;
  if (auto complexTy = dyn_cast<ComplexType>(
          getElementTypeOrSelf(opOperand->get().getType()))) {
    auto complexAttr = cast<ArrayAttr>(paddingAttr);
    paddingValue = rewriter.create<complex::ConstantOp>(opToPad.getLoc(),
                                                        complexTy, complexAttr);
  } else {
    paddingValue = rewriter.create<arith::ConstantOp>(
        opToPad.getLoc(), cast<TypedAttr>(paddingAttr));
  }

  // Pad the operand to the bounding box defined by `paddedShape`.
  auto paddedTensorType = RankedTensorType::get(
      paddedShape, getElementTypeOrSelf(opOperand->get()));
  LLVM_DEBUG(DBGS() << "--SUCCESS, makeComposedPadHighOp with type: "
                    << paddedTensorType);
  return makeComposedPadHighOp(rewriter, opToPad->getLoc(), paddedTensorType,
                               opOperand->get(), paddingValue, nofold);
}

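// A minimal usage sketch for the entry point below (illustrative; assumes the
// caller already has a `rewriter` and a matched `linalgOp`, and the option
// values are made up):
//
//   LinalgPaddingOptions options;
//   options.paddingDimensions = {0, 1, 2};
//   options.padToMultipleOf = SmallVector<int64_t>{16, 16, 16};
//   options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;
//   LinalgOp paddedOp;
//   SmallVector<Value> replacements;
//   SmallVector<tensor::PadOp> padOps;
//   if (failed(rewriteAsPaddedOp(rewriter, linalgOp, options, paddedOp,
//                                replacements, padOps)))
//     return failure();
//   // `replacements` holds the values to substitute for linalgOp's results;
//   // `padOps` collects the tensor.pad ops that were created.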
LogicalResult
linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
                          const LinalgPaddingOptions &constOptions,
                          LinalgOp &paddedOp, SmallVector<Value> &replacements,
                          SmallVector<tensor::PadOp> &padOps) {
  LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
  Location loc = opToPad->getLoc();

  LinalgPaddingOptions options(constOptions);
  // Allow inference of pad values if they are not explicitly specified.
  // TODO: be mindful about the value depending on the actual operation.
  if (options.paddingValues.empty()) {
    SmallVector<Type> types(opToPad->getOperandTypes());
    llvm::append_range(types, opToPad->getResultTypes());
    for (Type t : types) {
      options.paddingValues.push_back(
          rewriter.getZeroAttr(getElementTypeOrSelf(t)));
    }
  }

  // TODO: there are cases where we may still want to pad to larger sizes.
  if (!opToPad.hasPureTensorSemantics())
    return rewriter.notifyMatchFailure(opToPad,
                                       "expected operation on tensors");

  OpBuilder::InsertionGuard g(rewriter);
  // Set IP after op because we also take the dims of the original output.
  rewriter.setInsertionPointAfter(opToPad);

  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad->getNumOperands());
  for (OpOperand &opOperand : opToPad->getOpOperands()) {
    FailureOr<Value> paddedOperand = padOperandToSmallestStaticBoundingBox(
        rewriter, opToPad, &opOperand, options);
    // Exit if `paddingDimensions` cannot be bounded statically.
    if (failed(paddedOperand)) {
      LLVM_DEBUG(DBGS() << "--operand cannot be bound statically : "
                        << opOperand.get() << " -> FAIL\n");
      return rewriter.notifyMatchFailure(opToPad,
                                         "operand cannot be bound statically");
    }
    newOperands.push_back(*paddedOperand);
    if (auto padOp = paddedOperand->getDefiningOp<tensor::PadOp>())
      padOps.push_back(padOp);
  }

  ReifiedRankedShapedTypeDims reifiedResultShapes;
  if (failed(reifyResultShapes(rewriter, opToPad, reifiedResultShapes))) {
    LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n");
    return rewriter.notifyMatchFailure(opToPad,
                                       "failed to reify result shapes");
  }
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumDpsInits()).getTypes();
  // clone **should** properly notify the rewriter.
  paddedOp = clone(rewriter, opToPad, resultTensorTypes, newOperands);
  LLVM_DEBUG(DBGS() << "--cloned padded op: " << paddedOp << "\n");

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubtensorResults;
  paddedSubtensorResults.reserve(opToPad->getNumResults());
  for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    paddedSubtensorResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
        strides));
  }

  if (options.copyBackOp == LinalgPaddingOptions::CopyBackOp::None) {
    replacements = std::move(paddedSubtensorResults);
    return success();
  }

  // Copy back unpadded results to the original destination (i.e., inits of the
  // linalg op), so that the destination buffer of the computation does not
  // change. If the padding folds away, this will materialize as a memcpy
  // between two identical buffers, which will then also fold away.
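  // Illustrative shape of the replacement chain for one result, assuming
  // CopyBackOp::LinalgCopy (names are made up):
  //   %slice  = tensor.extract_slice %padded_result[0, 0] [%d0, %d1] [1, 1] ...
  //   %copied = linalg.copy ins(%slice : ...) outs(%original_init : ...)
  // and %copied replaces the corresponding result of the original op.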
  assert(static_cast<int64_t>(paddedSubtensorResults.size()) ==
             opToPad.getNumDpsInits() &&
         "expected matching number of results");
  for (auto it :
       llvm::zip(paddedSubtensorResults, opToPad.getDpsInitsMutable())) {
    if (options.copyBackOp == LinalgPaddingOptions::CopyBackOp::LinalgCopy) {
      replacements.push_back(rewriter
                                 .create<linalg::CopyOp>(loc, std::get<0>(it),
                                                         std::get<1>(it).get())
                                 .getResult(0));
    } else if (options.copyBackOp ==
               LinalgPaddingOptions::CopyBackOp::
                   BufferizationMaterializeInDestination) {
      replacements.push_back(
          rewriter
              .create<bufferization::MaterializeInDestinationOp>(
                  loc, std::get<0>(it), std::get<1>(it).get())
              ->getResult(0));
    } else {
      llvm_unreachable("unsupported copy back op");
    }
  }
  return success();
}

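// A minimal usage sketch for padAndHoistLinalgOp (illustrative; note that it
// requires CopyBackOp::None, and `hoistPaddings` gives, per operand, the
// number of loops to hoist the corresponding tensor.pad above):
//
//   LinalgPaddingOptions options;
//   options.paddingDimensions = {0, 1};
//   options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None;
//   options.hoistPaddings = {2, 2, 0};
//   FailureOr<LinalgOp> padded =
//       padAndHoistLinalgOp(rewriter, linalgOp, options);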
FailureOr<LinalgOp>
mlir::linalg::padAndHoistLinalgOp(RewriterBase &rewriter, LinalgOp linalgOp,
                                  const LinalgPaddingOptions &options) {
  assert(options.copyBackOp == LinalgPaddingOptions::CopyBackOp::None &&
         "invalid options");

  if (!linalgOp.hasPureTensorSemantics())
    return rewriter.notifyMatchFailure(
        linalgOp, "only applies to Linalg ops with tensor semantics");

  // Pad the operation.
  LinalgOp paddedOp;
  SmallVector<Value> newResults;
  SmallVector<tensor::PadOp> padOps;
  if (failed(rewriteAsPaddedOp(rewriter, linalgOp, options, paddedOp,
                               newResults, padOps)))
    return rewriter.notifyMatchFailure(linalgOp,
                                       "failed to rewrite as a padded op");

  // Hoist the padding.
  for (const auto &en : enumerate(options.hoistPaddings)) {
    if (static_cast<int64_t>(en.index()) >= paddedOp->getNumOperands())
      break;
    OpOperand &opOperand = paddedOp->getOpOperand(en.index());
    auto padOp = opOperand.get().getDefiningOp<tensor::PadOp>();
    if (!padOp || en.value() == 0) {
      (void)rewriter.notifyMatchFailure(linalgOp, "not a tensor.pad -- skip");
      continue;
    }
    // Skip hoisting for this operand if its shape is not fully static.
    if (llvm::any_of(paddedOp.getShape(&opOperand), ShapedType::isDynamic)) {
      (void)rewriter.notifyMatchFailure(linalgOp,
                                        "non static padding shape -- skip");
      continue;
    }

    tensor::PadOp hoistedOp;
    SmallVector<GenericOp> transposeOps;
    SmallVector<int64_t> transposeVector =
        en.index() < options.transposePaddings.size()
            ? options.transposePaddings[en.index()]
            : SmallVector<int64_t>{};

    FailureOr<Value> newResult = hoistPaddingOnTensors(
        padOp, en.value(), transposeVector, hoistedOp, transposeOps);
    if (failed(newResult)) {
      (void)rewriter.notifyMatchFailure(linalgOp,
                                        "failed to apply hoistPadding");
      continue;
    }
    rewriter.replaceOp(padOp, *newResult);
  }

  // Replace the original operation with the padded results.
  rewriter.replaceOp(linalgOp, newResults);

  return paddedOp;
}