1//===- PaddingTilingInterface.cpp - Padding of TilingInterface ops --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
10
11#include "mlir/Dialect/Affine/IR/AffineOps.h"
12#include "mlir/Dialect/Complex/IR/Complex.h"
13#include "mlir/Dialect/Tensor/IR/Tensor.h"
14#include "mlir/Dialect/Utils/StaticValueUtils.h"
15#include "mlir/IR/AffineExpr.h"
16#include "mlir/IR/BuiltinAttributes.h"
17#include "mlir/IR/BuiltinTypeInterfaces.h"
18#include "mlir/IR/BuiltinTypes.h"
19#include "mlir/IR/OpDefinition.h"
20#include "mlir/IR/Value.h"
21#include "mlir/Interfaces/TilingInterface.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/Support/Casting.h"
24
25#define DEBUG_TYPE "pad-tiling-interface"
26
27using namespace mlir;
28using namespace mlir::linalg;
29using namespace mlir::tensor;
30
31#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
32#define DBGSNL() (llvm::dbgs() << "\n")
33
34/// Form a "full-rank" padding specification so that the application is easy.
35static SmallVector<OpFoldResult>
36getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
37 const PadTilingInterfaceOptions &options) {
38 SmallVector<OpFoldResult> paddingSizes;
39 // Complete the padding specification to specify all dimensions.
40 for (size_t idx = 0, e = indexingSizes.size(); idx != e; ++idx) {
41 // Complete to zero if needed.
42 paddingSizes.push_back(Elt: options.paddingSizes.size() > idx
43 ? options.paddingSizes[idx]
44 : b.getIndexAttr(value: 0));
45 // If a dimension is zero (either specified or completed), replace by:
46 // - 1 if we are padding to the next multiple of.
47 // - indexingSizes[idx] otherwise
48 if (isZeroInteger(v: paddingSizes[idx])) {
49 paddingSizes[idx] =
50 options.padToMultipleOf ? b.getIndexAttr(value: 1) : indexingSizes[idx];
51 }
52 LLVM_DEBUG(DBGS() << "----idx: " << idx << " : " << paddingSizes[idx]
53 << "\n");
54 }
55 return paddingSizes;
56}
57
/// Compute the padded shape of the given value `v` of `RankedTensorType` given
///   - `indexingSizes` a list of OpFoldResult.
///   - an `indexingMap` that encodes how the shape of `v` varies with
///     increases in `indexingSizes`.
/// The `indexingMap` + `indexingSizes` encoding suits StructuredOps.
/// The implementation below iteratively combines increases from contributing
/// dimensions using affine.apply operations.
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
SmallVector<OpFoldResult> linalg::computePaddedShape(
    RewriterBase &rewriter, TypedValue<RankedTensorType> v,
    AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
    const PadTilingInterfaceOptions &options) {
  Location loc = v.getLoc();
  SmallVector<OpFoldResult> paddedShape;
  auto tensorType = cast<RankedTensorType>(Val: v.getType());
  paddedShape.resize_for_overwrite(N: tensorType.getRank());
  // The map must produce exactly one result per tensor dimension so that each
  // operand dimension can be related to the iteration space.
  assert(tensorType.getRank() == indexingMap.getNumResults() &&
         "expect the number of results of the affine map to match the tensor "
         "rank");

  // "Full-rank" padding specification.
  SmallVector<OpFoldResult> paddingSizes =
      getFullRankPaddingSizes(b&: rewriter, indexingSizes, options);

  // For each dimension in the operand's shape, iterate over indexingSizes and
  // add the various term contributions.
  for (const auto &enResults : enumerate(First: indexingMap.getResults())) {
    int64_t resultIndex = enResults.index();
    // Restrict the indexing map to the single result describing this operand
    // dimension.
    AffineMap partialIndexingMap = indexingMap.getSubMap(
        resultPos: ArrayRef<unsigned>{static_cast<unsigned>(resultIndex)});

    LLVM_DEBUG(DBGS() << "----resultIndex: " << resultIndex
                      << " with partialIndexingMap: " << partialIndexingMap
                      << "\n");

    // Find all padding dimensions that contribute to this operand dimension
    // and compute the padded term contribution to the final padded shape.
    SmallVector<OpFoldResult> terms;
    for (size_t paddingDim = 0, e = paddingSizes.size(); paddingDim != e;
         ++paddingDim) {
      OpFoldResult paddingSize = paddingSizes[paddingDim];
      LLVM_DEBUG(DBGS() << "------try apply padding of dim: " << paddingDim
                        << " to: " << paddingSize << "\n");
      // Skip iteration dimensions this operand dimension does not depend on.
      if (!enResults.value().isFunctionOfDim(position: paddingDim))
        continue;

      LLVM_DEBUG(DBGS() << "------apply padding of dim: " << paddingDim
                        << " to: " << paddingSize << "\n");

      // Project non-'paddingDim' dimensions and compress the result.
      llvm::SmallBitVector projectedDims(partialIndexingMap.getNumDims(), true);
      projectedDims.flip(Idx: paddingDim);
      AffineMap projectedMap =
          mlir::projectDims(map: partialIndexingMap, projectedDimensions: projectedDims,
                            /*compressDims=*/compressDimsFlag: true);

      // If we are padding to the next multiple of, compose with ceil(sz) * sz.
      if (options.padToMultipleOf) {
        // Build ceil(d0 / s0) * s0 where d0 is the iteration size and s0 the
        // multiple to pad to, then compose it under the projected map.
        AffineExpr d0, s0;
        bindDims(ctx: rewriter.getContext(), exprs&: d0);
        bindSymbols(ctx: rewriter.getContext(), exprs&: s0);
        AffineMap ceilMap = AffineMap::get(dimCount: 1, symbolCount: 1, result: d0.ceilDiv(other: s0) * s0);
        AffineMap composedMap = projectedMap.compose(map: ceilMap);
        OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
            b&: rewriter, loc, map: composedMap,
            operands: {indexingSizes[paddingDim], paddingSize},
            /*composeAffineMin=*/true);
        terms.push_back(Elt: paddingDimOfr);
      } else {
        // Otherwise just set to paddingSize.
        OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
            b&: rewriter, loc, map: projectedMap, operands: paddingSize);
        terms.push_back(Elt: paddingDimOfr);
      }

      LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n");
    }

    // If there are no terms, just return the dim.
    if (terms.empty()) {
      paddedShape[resultIndex] =
          createFoldedDimOp(b&: rewriter, loc, val: v, dim: resultIndex);
      continue;
    }

    // Sum individual terms' contributions.
    SmallVector<AffineExpr> dims(terms.size());
    bindDimsList(ctx: rewriter.getContext(), exprs: MutableArrayRef{dims});
    AffineExpr sumExpr = dims.front();
    for (unsigned i = 1; i < dims.size(); ++i)
      sumExpr = sumExpr + dims[i];
    // Fold the sum of all contributing terms into the final padded size.
    OpFoldResult paddedDimOfr =
        affine::makeComposedFoldedAffineApply(b&: rewriter, loc, expr: sumExpr, operands: terms);
    paddedShape[resultIndex] = paddedDimOfr;
  }

  return paddedShape;
}
158
159FailureOr<SmallVector<OpFoldResult>>
160linalg::computeIndexingMapOpInterfacePaddedShape(
161 RewriterBase &rewriter, OpOperand &operandToPad,
162 ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options) {
163 auto transferOp =
164 llvm::dyn_cast<IndexingMapOpInterface>(Val: operandToPad.getOwner());
165 if (!transferOp)
166 return failure();
167
168 // clang-format off
169 assert(llvm::all_of(iterationDomain, [&rewriter](Range r) {
170 return r.offset == OpFoldResult(rewriter.getIndexAttr(0)) &&
171 r.stride == OpFoldResult(rewriter.getIndexAttr(1));
172 }) && "expected 0-offset 1-stride loop ranges");
173 // clang-format on
174 SmallVector<OpFoldResult> loopUpperBounds;
175 loopUpperBounds.reserve(N: iterationDomain.size());
176 for (const Range &range : iterationDomain)
177 loopUpperBounds.push_back(Elt: range.size);
178
179 AffineMap indexingMap = transferOp.getMatchingIndexingMap(opOperand: &operandToPad);
180 return computePaddedShape(
181 rewriter, v: cast<TypedValue<RankedTensorType>>(Val: operandToPad.get()),
182 indexingMap, indexingSizes: loopUpperBounds, options);
183}
184
185/// Pad a single operand to `paddedShape` using `paddingValueAttr` as padding
186/// Value.
187static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
188 TypedValue<RankedTensorType> v,
189 ArrayRef<OpFoldResult> paddedShape,
190 Attribute paddingValueAttr) {
191 Value paddingValue;
192 if (auto complexTy =
193 dyn_cast<ComplexType>(Val: getElementTypeOrSelf(type: v.getType()))) {
194 auto complexAttr = cast<ArrayAttr>(Val&: paddingValueAttr);
195 paddingValue = rewriter.create<complex::ConstantOp>(location: opToPad.getLoc(),
196 args&: complexTy, args&: complexAttr);
197 } else {
198 paddingValue = rewriter.create<arith::ConstantOp>(
199 location: opToPad.getLoc(), args: cast<TypedAttr>(Val&: paddingValueAttr));
200 }
201
202 // Pad the operand to the bounding box defined by `paddedShape`.
203 SmallVector<int64_t> tensorShape;
204 SmallVector<Value> dynDims;
205 for (OpFoldResult ofr : paddedShape) {
206 std::optional<int64_t> cst = getConstantIntValue(ofr);
207 tensorShape.push_back(Elt: cst.has_value() ? *cst : ShapedType::kDynamic);
208 if (!cst.has_value())
209 dynDims.push_back(Elt: ofr.dyn_cast<Value>());
210 }
211 // TODO: use dispatchIndexOpFoldResults(paddedShape, dynDims, paddedShape);
212
213 auto paddedTensorType =
214 RankedTensorType::get(shape: tensorShape, elementType: getElementTypeOrSelf(val: v));
215 LLVM_DEBUG(DBGS() << "--SUCCESS, makeComposedPadHighOp with type: "
216 << paddedTensorType);
217 return makeComposedPadHighOp(b&: rewriter, loc: opToPad.getLoc(), type: paddedTensorType, source: v,
218 padding: paddingValue, /*nofold=*/false, typeDynDims: dynDims);
219}
220
/// Pad each ranked-tensor operand of `opToPad` to the shape computed by
/// `computePaddingSizeFun`, clone the op onto the padded operands, and
/// extract-slice the original result extents back out. The created
/// tensor.pad ops are collected in `padOps`; the padded op is returned and
/// the original op is replaced by the extracted slices.
FailureOr<TilingInterface>
linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
                          const PadTilingInterfaceOptions &constOptions,
                          SmallVector<tensor::PadOp> &padOps,
                          PadSizeComputationFunction computePaddingSizeFun) {
  LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");

  Location loc = opToPad.getLoc();
  // Work on a local copy so the inferred padding values below do not leak
  // back to the caller's options.
  PadTilingInterfaceOptions options(constOptions);
  // Allow inference of pad values if they are not explicitly specified.
  // TODO: be mindful about the value depending on the actual operation.
  if (options.paddingValues.empty()) {
    SmallVector<Type> types(opToPad->getOperandTypes());
    llvm::append_range(C&: types, R: opToPad->getResultTypes());
    for (Type t : types) {
      options.paddingValues.push_back(
          Elt: rewriter.getZeroAttr(type: getElementTypeOrSelf(type: t)));
    }
  }

  // Padding is only supported in tensor land; bail out on memref operands.
  if (llvm::any_of(Range: opToPad->getOperands(),
                   P: [](Value v) { return isa<MemRefType>(Val: v.getType()); })) {
    return rewriter.notifyMatchFailure(arg&: opToPad,
                                       msg: "expected operation on tensors");
  }

  OpBuilder::InsertionGuard g(rewriter);
  // Set IP after opToPad because we also take the dims of opToPad's output.
  rewriter.setInsertionPointAfter(opToPad);

  // 1. Get the loopUpperBounds from the TilingInterface.
  SmallVector<Range> iterationDomain = opToPad.getIterationDomain(b&: rewriter);

  // 2. For each operand.
  SmallVector<Value> newOperands;
  newOperands.reserve(N: opToPad->getNumOperands());
  for (OpOperand &opOperand : opToPad->getOpOperands()) {
    Value operand = opOperand.get();
    LLVM_DEBUG(DBGS() << "--start padding oprd: " << operand << "\n");

    // 2.a. Skip scalar-like operands.
    Type operandType = operand.getType();
    if (!isa<RankedTensorType>(Val: operandType)) {
      assert((!isa<ShapedType>(operandType) || isa<VectorType>(operandType)) &&
             "Unexpected non-vector ShapedType");
      newOperands.push_back(Elt: operand);
      continue;
    }
    // 2.b. Compute padded shape.
    FailureOr<SmallVector<OpFoldResult>> maybePaddedShape =
        computePaddingSizeFun(rewriter, opOperand, iterationDomain, options);
    if (failed(Result: maybePaddedShape)) {
      return rewriter.notifyMatchFailure(arg&: opToPad, msg: "could not pad op");
    }

    // 2.c. Expect proper `paddingValues`.
    // TODO: we may want to allow garbage padding in the future, in which case
    // we would just not assert.
    if (opOperand.getOperandNumber() >= options.paddingValues.size()) {
      return rewriter.notifyMatchFailure(arg&: opToPad,
                                         msg: "--no padding value specified");
    }
    Attribute paddingValueAttr =
        options.paddingValues[opOperand.getOperandNumber()];

    // 2.d. Perform actual padding.
    Value paddedOperand = padOperand(
        rewriter, opToPad, v: cast<TypedValue<RankedTensorType>>(Val&: operand),
        paddedShape: *maybePaddedShape, paddingValueAttr);
    LLVM_DEBUG(DBGS() << "--done padding operand: " << paddedOperand << "\n");

    // 2.e. Record the new operand and the tensor.pad that produced it, if any
    // (padding may have folded away entirely).
    newOperands.push_back(Elt: paddedOperand);
    if (auto padOp = paddedOperand.getDefiningOp<tensor::PadOp>())
      padOps.push_back(Elt: padOp);
  }

  // 3. Form the resulting tensor::ExtractSliceOp.
  ReifiedRankedShapedTypeDims reifiedResultShapes;
  if (failed(Result: reifyResultShapes(b&: rewriter, op: opToPad, reifiedReturnShapes&: reifiedResultShapes))) {
    LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n");
    return rewriter.notifyMatchFailure(arg&: opToPad,
                                       msg: "failed to reify result shapes");
  }
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  // The result types are taken from the trailing (init/output) operands.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(n: opToPad->getNumResults()).getTypes();
  // clone **should** properly notify the rewriter.
  TilingInterface paddedOp =
      clone(b&: rewriter, op: opToPad, newResultTypes: resultTensorTypes, newOperands);
  LLVM_DEBUG(DBGS() << "--cloned padded op: " << paddedOp << "\n");

  // Recover the slice out of the new static results. This keeps the original
  // opToPad around because it uses the dims of the original results.
  SmallVector<Value> paddedSubtensorResults;
  paddedSubtensorResults.reserve(N: opToPad->getNumResults());
  for (const auto &en : llvm::enumerate(First: paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = cast<RankedTensorType>(Val: paddedResult.getType()).getRank();
    // Slice back the original (reified) extent with 0 offsets and unit
    // strides.
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(value: 0));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(value: 1));
    paddedSubtensorResults.push_back(Elt: rewriter.create<tensor::ExtractSliceOp>(
        location: loc, args&: paddedResult, args&: offsets, args&: reifiedResultShapes[resultNumber],
        args&: strides));
  }

  rewriter.replaceOp(op: opToPad, newValues: paddedSubtensorResults);

  return paddedOp;
}
335

source code of mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp