//===- TensorTilingInterface.cpp - Tiling Interface models -*- C++ -*-----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {

struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto padOp = cast<PadOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        padOp.getResultType().getRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    ReifiedRankedShapedTypeDims reifiedShapes;
    (void)reifyResultShapes(b, op, reifiedShapes);
    OpFoldResult zero = b.getIndexAttr(0);
    OpFoldResult one = b.getIndexAttr(1);
    // Initialize all the ranges to {zero, one, one}. All the `ub`s are
    // overwritten.
    SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one});
    for (const auto &ub : enumerate(reifiedShapes[0]))
      loopRanges[ub.index()].size = ub.value();
    return loopRanges;
  }

  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> result =
        tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes);
    if (failed(result))
      return failure();
    return result.value();
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultSizes.assign(sizes.begin(), sizes.end());
    return success();
  }
};

template <typename OpTy>
static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
                                                       OpBuilder &builder) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  OpBuilder::InsertionGuard g(builder);
  int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
                                                     : op.getDestRank();
  OpFoldResult zero = builder.getIndexAttr(0);
  OpFoldResult one = builder.getIndexAttr(1);
  ReifiedRankedShapedTypeDims resultShape;
  (void)reifyResultShapes(builder, op, resultShape);
  SmallVector<Range> loopBounds(rank);
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].stride = one;
    loopBounds[dim].size = resultShape[0][dim];
  }
  return loopBounds;
}

static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
                             SmallVector<OpFoldResult> &sizes,
                             ArrayRef<int64_t> permutation) {
  if (permutation.empty())
    return;
  applyPermutationToVector<OpFoldResult>(offsets, permutation);
  applyPermutationToVector<OpFoldResult>(sizes, permutation);
}

struct PackOpTiling
    : public TilingInterface::ExternalModel<PackOpTiling, PackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    // Note that here we only consider untiled dimensions and outer tiled data
    // dimensions, the inner tiled data dimensions are materialized when
    // building the body of the operation.
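    // For example (hypothetical shapes), packing tensor<128x256xf32> into
    // tensor<4x8x32x32xf32> with inner tiles [32, 32] has a 2-D iteration
    // domain over the 128 and 256 dims; the two 32-sized inner dims are
    // materialized later and are not part of the domain.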
    auto packOp = cast<PackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        packOp.getSourceRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);
  }

  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();

    // The tiling is applied on interchanged dimensions. We have to undo the
    // interchange to map sizes and offsets to the original input.
    int64_t inputRank = packOp.getSourceRank();
    SmallVector<OpFoldResult> origOffsets(offsets.begin(), offsets.end());
    SmallVector<OpFoldResult> origSizes(sizes.begin(), sizes.end());
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(packOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    SmallVector<OpFoldResult> srcDimValues =
        tensor::getMixedSizes(b, loc, packOp.getSource());
    SmallVector<OpFoldResult> inputIndices, inputSizes;
    for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym);
      if (dimAndTileMapping.count(dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of sizes_i and
        // tile_i.
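        // For example (hypothetical values): with tile_i = 32, offset_i = 2,
        // and sizes_i = 4, the slice of the source starts at index 2 * 32 = 64
        // and spans 4 * 32 = 128 elements along this dimension.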
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        inputIndices.push_back(ab.mul(avOffset, avTileSize));
        inputSizes.push_back(ab.mul(avSize, avTileSize));
      } else {
        inputIndices.push_back(origOffsets[dim]);
        inputSizes.push_back(origSizes[dim]);
      }

      // Limit the size of the input operand for incomplete tiles.
      if (packOp.getPaddingValue()) {
        OpFoldResult dimSize = srcDimValues[dim];
        auto avDimSize = AV(dim0).bind(dimSize);
        auto avInputIdx = AV(dim1).bind(inputIndices.back());
        inputSizes.back() =
            ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
      }
    }

    auto oneAttr = b.getI64IntegerAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    tiledOperands.push_back(b.create<ExtractSliceOp>(
        loc, packOp.getSource(), inputIndices, inputSizes, strides));

    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
                                     outputSizes)))
      return {};

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto extractSlice = b.create<ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(extractSlice);

    if (auto val = packOp.getPaddingValue())
      tiledOperands.push_back(val);
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledPackOp = b.create<PackOp>(
        loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());

    return TilingResult{{tiledPackOp},
                        SmallVector<Value>(tiledPackOp->getResults())};
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    // The iteration domain is over outer dimensions of packed layout. In this
    // context, the outer dimensions of `resultOffsets` are `offsets`. The
    // inner dimensions of `resultOffsets` are zeros because tiling is not
    // applied to them.
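    // For example (hypothetical ranks and values): for a pack with source
    // rank 2 and dest rank 4, offsets [%i, %j] and sizes [2, 4] map to result
    // offsets [%i, %j, 0, 0] and result sizes [2, 4, tile0, tile1], where
    // tile0/tile1 are the reified inner tile sizes.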
    auto packOp = cast<PackOp>(op);
    int64_t inputRank = packOp.getSourceRank();
    int64_t outputRank = packOp.getDestRank();
    auto zeroAttr = b.getI64IntegerAttr(0);
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultOffsets.append(outputRank - inputRank, zeroAttr);

    ReifiedRankedShapedTypeDims outputShape;
    (void)reifyResultShapes(b, packOp, outputShape);
    resultSizes.assign(sizes.begin(), sizes.end());
    for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
      resultSizes.push_back(outputShape[0][dataTileDim]);

    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    int64_t numTiles = packOp.getInnerDimsPos().size();

    // tensor.pack op is fusible (as a producer) only if full inner tiles are
    // iterated or inner dims are not tiled. Otherwise, it will generate a
    // sequence of non-trivial ops (for partial tiles).
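    // Concretely (hypothetical values), with inner tiles [32, 32] fusion is
    // only attempted when the two trailing offsets are 0 and the two trailing
    // sizes equal 32; any other combination returns failure below.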
    for (auto offset : offsets.take_back(numTiles))
      if (!isConstantIntValue(offset, 0))
        return failure();

    for (auto iter :
         llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles)))
      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
        return failure();

    FailureOr<TilingResult> tilingResult = getTiledImplementation(
        op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }
};

struct UnpackTileDimInfo {
  bool isAlignedToInnerTileSize;
  OpFoldResult sourceOffset;
  OpFoldResult sourceSize;
  OpFoldResult resultOffset;
  OpFoldResult destExpandedSize;
};

/// Returns the information needed for tiling an unpack op along `tileDim` with
/// the given `tileOffset` and `tileSize`. For more details, see the comment on
/// `getTiledImplementation`.
static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
                                              int64_t tileDim,
                                              OpFoldResult tileOffset,
                                              OpFoldResult tileSize) {
  UnpackTileDimInfo info;
  Attribute zeroAttr = b.getIndexAttr(0);
  Attribute oneAttr = b.getIndexAttr(1);
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  // The dimension is not one of the packed data dimensions.
  if (!dimAndTileMapping.count(tileDim)) {
    info.isAlignedToInnerTileSize = true;
    info.sourceOffset = tileOffset;
    info.sourceSize = tileSize;
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;
    return info;
  }

  Location loc = unpackOp.getLoc();
  using AV = affine::AffineValueExpr;
  affine::AffineBuilder ab(b, loc);
  AffineExpr dim0, dim1, sym0;
  bindDims(b.getContext(), dim0, dim1);
  bindSymbols(b.getContext(), sym0);

  OpFoldResult innerTileSize = dimAndTileMapping[tileDim];

  info.isAlignedToInnerTileSize = false;
  FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
      presburger::BoundType::UB, tileSize,
      /*stopCondition=*/nullptr, /*closedUB=*/true);
  std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
  if (!failed(cstSize) && cstInnerSize) {
    if (*cstSize % *cstInnerSize == 0)
      info.isAlignedToInnerTileSize = true;

    // If the tiling size equals the inner tile size, the outer dims are
    // always 1.
    if (*cstInnerSize == *cstSize) {
      auto lhs = AV(dim0).bind(tileOffset);
      auto rhs = AV(dim1).bind(innerTileSize);
      info.sourceOffset = ab.floor(lhs, rhs);
      info.sourceSize = oneAttr;
      info.resultOffset = zeroAttr;
      info.destExpandedSize = tileSize;
      return info;
    }
  }

  if (info.isAlignedToInnerTileSize) {
    info.sourceOffset =
        ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;

    // The ceilDiv is needed here because there could be an incomplete tile
    // even in perfect tiling cases. E.g.,
    //   %0 = unpack tensor<33x2xf32> into tensor<66xf32>
    // If the tiling size is 32, there will be 3 tiles. Two of them have
    // size=32; one of them has size=2. The size is represented using an
    // affine_min op; we need ceilDiv.
    info.sourceSize =
        ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
    return info;
  }
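
  // Unaligned case: compute the outer/inner coordinates of the first and last
  // element covered by the tile. For example (hypothetical values), with an
  // inner tile size of 8, tileOffset = 15 and tileSize = 15 cover elements
  // [15, 29], i.e. coordinates (1, 7) through (3, 5); the source slice then
  // spans outer indices 1..3 (sourceSize = 3) and the unpacked result is
  // offset by resultOffset = 7 within the expanded destination.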
  affine::DivModValue firstCoord = affine::getDivMod(
      b, loc, getValueOrCreateConstantIndexOp(b, loc, tileOffset),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  OpFoldResult tileExclusiveBound =
      ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
  affine::DivModValue lastCoord = affine::getDivMod(
      b, loc,
      getValueOrCreateConstantIndexOp(
          b, loc,
          ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));

  OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient),
                                       AV(dim1).bind(firstCoord.quotient));
  info.sourceSize =
      ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
  info.sourceOffset = firstCoord.quotient;
  info.resultOffset = firstCoord.remainder;
  // Do not create an affine op for the expanded size because the affine
  // expression would be too complicated, which could trigger an issue in
  // affine op simplification.
  info.destExpandedSize = b.createOrFold<arith::MulIOp>(
      loc, getValueOrCreateConstantIndexOp(b, loc, info.sourceSize),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  return info;
}

struct UnPackOpTiling
    : public TilingInterface::ExternalModel<UnPackOpTiling, UnPackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto unpackOp = cast<UnPackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        unpackOp.getDestRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b);
  }

  /// There are two cases when tiling unpack ops. If the tiling size is aligned
  /// to the inner tile size, the corresponding tiles of the source are all
  /// complete. Otherwise, there are incomplete tiles; we need to expand the
  /// slice of the source to get complete tiles. The tiled unpack op then
  /// unpacks more data from the source, so we need a tensor.extract_slice op
  /// to shift and truncate the output.
  /// Take Nn_to_N as an example. Say that N=32, n=8, and tiling_size=15. The
  /// coordinates of the second tile (i.e., result[15..29]) are
  /// [(1, 7), (2, 0), (2, 1) ... (3, 4), (3, 5)]. The first row and the last
  /// row are incomplete tiles. To represent the unpack op, we have to complete
  /// the rows, i.e., the input coordinates start at (1, 0) and end at (3, 7).
  /// In this context, the tiled unpack produces (3 * n) elements because there
  /// are 3 rows in total. Followed by a tensor.extract_slice op, we can get
  /// the actual result.
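  /// For that example, the produced IR is roughly of the following form
  /// (illustrative sketch only; exact types, attributes, and folded constants
  /// may differ):
  ///   %src   = tensor.extract_slice %source[1, 0] [3, 8] [1, 1]
  ///            : tensor<4x8xf32> to tensor<3x8xf32>
  ///   %empty = tensor.empty() : tensor<24xf32>
  ///   %unpck = tensor.unpack %src inner_dims_pos = [0] inner_tiles = [8]
  ///            into %empty : tensor<3x8xf32> -> tensor<24xf32>
  ///   %res   = tensor.extract_slice %unpck[7] [15] [1]
  ///            : tensor<24xf32> to tensor<15xf32>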
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto unpackOp = cast<UnPackOp>(op);
    int64_t srcRank = unpackOp.getSourceRank();
    int64_t destRank = unpackOp.getDestRank();
    int64_t numInnerTiles = srcRank - destRank;
    Location loc = unpackOp.getLoc();

    // The perfect tiling case indicates that the tiling sizes are multiples of
    // inner_tile_size. In this context, no extra data is needed when
    // representing the tiled unpack op.
    bool isPerfectTilingCase = true;
    Attribute oneAttr = b.getIndexAttr(1);
    SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr);
    SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes;
    SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest;
    for (auto dim : llvm::seq<int64_t>(0, destRank)) {
      UnpackTileDimInfo info =
          getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
      if (!info.isAlignedToInnerTileSize)
        isPerfectTilingCase = false;
      sliceSrcIndices.push_back(info.sourceOffset);
      sliceSrcSizes.push_back(info.sourceSize);
      destExpandedSizes.push_back(info.destExpandedSize);
      resultOffsetsFromDest.push_back(info.resultOffset);
    }

    // The tiling is applied on destination dimensions. We have to apply the
    // interchange on source dimensions if outer_dims_perm is set.
    applyPermToRange(sliceSrcIndices, sliceSrcSizes,
                     unpackOp.getOuterDimsPerm());
    Attribute zeroAttr = b.getIndexAttr(0);
    sliceSrcIndices.append(numInnerTiles, zeroAttr);
    sliceSrcSizes.append(unpackOp.getMixedTiles());
    sliceSrcStrides.append(numInnerTiles, oneAttr);
    Value sliceSource =
        b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
                                 sliceSrcSizes, sliceSrcStrides);

    SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
    Value sliceDest;
    if (isPerfectTilingCase) {
      sliceDest = b.create<ExtractSliceOp>(loc, unpackOp.getDest(), offsets,
                                           sizes, destStrides);
    } else {
      sliceDest = b.create<EmptyOp>(loc, destExpandedSizes,
                                    unpackOp.getDestType().getElementType());
    }

    SmallVector<Value> tiledOperands = {sliceSource, sliceDest};
    for (auto tile : unpackOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledUnpackOp = b.create<UnPackOp>(
        loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs());

    if (isPerfectTilingCase)
      return TilingResult{{tiledUnpackOp},
                          SmallVector<Value>(tiledUnpackOp->getResults())};

    auto extractSlice =
        b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
                                 resultOffsetsFromDest, sizes, destStrides);
    return TilingResult{{tiledUnpackOp}, {extractSlice.getResult()}};
  }

  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets = llvm::to_vector(offsets);
    resultSizes = llvm::to_vector(sizes);
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> tilingResult =
        getTiledImplementation(op, b, offsets, sizes);
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }
};

} // namespace
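
// Conceptually, bubbleUpPadSlice rewrites a tile of a tensor.pad result into
// pad(extract_slice(source)), or into a tensor.generate when the source is
// not read at all. For example (hypothetical shapes and tile parameters),
// given
//   %0 = tensor.pad %src low[3] high[5] { ... }
//        : tensor<64xf32> to tensor<72xf32>
// the tile of %0 at offset 0 with size 16 becomes, roughly,
//   %slice = tensor.extract_slice %src[0] [13] [1]
//            : tensor<64xf32> to tensor<13xf32>
//   %tile  = tensor.pad %slice low[3] high[0] { ... }
//            : tensor<13xf32> to tensor<16xf32>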
FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
                                                 tensor::PadOp padOp,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes,
                                                 bool generateZeroSliceGuard) {
  // Only constant padding value supported.
  Value padValue = padOp.getConstantPaddingValue();
  if (!padValue)
    return failure();

  // Helper variables and functions for various arithmetic operations. These
  // are used extensively for computing new offset/length and padding values.
  Location loc = padOp->getLoc();
  AffineExpr dim0, dim1;
  bindDims(b.getContext(), dim0, dim1);
  // Add two integers.
  auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
  auto add = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineApply(b, loc, addMap, {v1, v2});
  };
  // Subtract two integers.
  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineApply(b, loc, subMap, {v1, v2});
  };
  // Take the minimum of two integers.
  auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
  auto min = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMin(b, loc, idMap, {v1, v2});
  };
  // Take the maximum of two integers.
  auto max = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMax(b, loc, idMap, {v1, v2});
  };
  // Zero index-typed integer.
  OpFoldResult zero = b.getIndexAttr(0);

  // Compute new offsets, lengths, low padding, high padding.
  SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
  SmallVector<OpFoldResult> newLows, newHighs;
  // Set to true if the original data source is not read at all.
  bool hasZeroLen = false;
  // Same as hasZeroLen, but for dynamic dimension sizes. This condition
  // is true if the original data source turns out to be unused at runtime.
  Value dynHasZeroLenCond;

  int64_t rank = padOp.getSourceType().getRank();
  for (unsigned dim = 0; dim < rank; ++dim) {
    auto low = padOp.getMixedLowPad()[dim];
    bool hasLowPad = !isConstantIntValue(low, 0);
    auto high = padOp.getMixedHighPad()[dim];
    bool hasHighPad = !isConstantIntValue(high, 0);
    auto offset = offsets[dim];
    auto length = sizes[dim];
    auto srcSize = tensor::getMixedSize(b, loc, padOp.getSource(), dim);

    // The new amount of low padding is `low - offset`, except for the case
    // where none of the low padding is read. In that case, the new amount of
    // low padding is zero.
    //
    // Optimization: If low = 0, then newLow = 0.
    OpFoldResult newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
    newLows.push_back(newLow);

    // Start reading the data from position `offset - low`. Since the original
    // read may have started in the low padding zone, this value could be
    // negative. Therefore, start reading from:
    //
    //   max(offset - low, 0)
    //
    // The original read could also have started in the high padding zone.
    // In that case, set the offset to the end of the source tensor. The new
    // ExtractSliceOp length will be zero in that case. (Effectively reading
    // no data from the source.)
    //
    // Optimization: If low = 0, then the formula can be simplified.
    OpFoldResult newOffset = hasLowPad
                                 ? min(max(sub(offset, low), zero), srcSize)
                                 : min(offset, srcSize);
    newOffsets.push_back(newOffset);

    // The original ExtractSliceOp was reading until position `offset +
    // length`. Therefore, the corresponding position within the source tensor
    // is:
    //
    //   offset + length - low
    //
    // In case the original ExtractSliceOp stopped reading within the low
    // padding zone, this value can be negative. In that case, the end
    // position of the read should be zero. (Similar to newOffset.)
    //
    // The original read could also have stopped in the high padding zone.
    // In that case, the end position of the read should be the end of the
    // source tensor. (Similar to newOffset.)
    //
    //   endLoc = min(max(offset - low + length, 0), srcSize)
    //
    // The new ExtractSliceOp length is `endLoc - newOffset`.
    //
    // Optimization: If low = 0, then the formula can be simplified.
    OpFoldResult endLoc =
        hasLowPad ? min(max(add(sub(offset, low), length), zero), srcSize)
                  : min(add(offset, length), srcSize);
    OpFoldResult newLength = sub(endLoc, newOffset);
    newLengths.push_back(newLength);

    // Check if newLength is zero. In that case, no ExtractSliceOp should be
    // executed.
    if (isConstantIntValue(newLength, 0)) {
      hasZeroLen = true;
    } else if (!hasZeroLen) {
      Value check = b.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq,
          getValueOrCreateConstantIndexOp(b, loc, newLength),
          getValueOrCreateConstantIndexOp(b, loc, zero));
      dynHasZeroLenCond =
          dynHasZeroLenCond
              ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
              : check;
    }

    // The amount of high padding is simply the number of elements remaining,
    // so that the result has the same length as the original ExtractSliceOp.
    // As an optimization, if the original high padding is zero, then the new
    // high padding must also be zero.
    OpFoldResult newHigh =
        hasHighPad ? sub(sub(length, newLength), newLow) : zero;
    newHighs.push_back(newHigh);

    // Only unit stride supported.
    newStrides.push_back(b.getIndexAttr(1));
  }
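  // To illustrate the per-dimension computation above (hypothetical numbers):
  // with srcSize = 64, low = 3, high = 5, a tile at offset = 56 with
  // length = 16 reads source elements [53, 64) and gets newLow = 0,
  // newOffset = 53, newLength = 11, and newHigh = 5, i.e. the tile is
  // re-padded only at the high end.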

  // The shape of the result can be obtained from the sizes passed in.
  SmallVector<Value> dynDims;
  SmallVector<int64_t> shape;
  dispatchIndexOpFoldResults(sizes, dynDims, shape);
  RankedTensorType resultType =
      RankedTensorType::get(shape, padOp.getResultType().getElementType());

  // Insert cast to ensure that types match. (May be folded away.)
  auto castResult = [&](Value val) -> Value {
    if (resultType == val.getType())
      return val;
    return b.create<tensor::CastOp>(loc, resultType, val);
  };

  // In cases where the original data source is unused: Emit a GenerateOp and
  // do not generate a SliceOp. (The result shape of the SliceOp would
  // have a dimension of size 0, the semantics of which is unclear.)
  auto createGenerateOp = [&]() {
    // Create GenerateOp.
    auto generateOp = b.create<tensor::GenerateOp>(
        loc, resultType, dynDims,
        [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
          builder.create<tensor::YieldOp>(gLoc, padValue);
        });
    return generateOp;
  };

  // Emit a SliceOp and a PadOp. Should not be used in cases where
  // the result shape of the new SliceOp has a zero dimension.
  auto createPadOfExtractSlice = [&]() {
    // Create pad(extract_slice(x)).
    Value newSliceOp = b.create<tensor::ExtractSliceOp>(
        loc, padOp.getSource(), newOffsets, newLengths, newStrides);
    auto newPadOp = b.create<PadOp>(
        loc, Type(), newSliceOp, newLows, newHighs,
        /*nofold=*/padOp.getNofold(),
        getPrunedAttributeList(padOp, PadOp::getAttributeNames()));

    // Copy region to new PadOp.
    IRMapping bvm;
    padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);

    // Cast result and return.
    return newPadOp;
  };

  // Rewrite extract_slice(pad(x)) into a GenerateOp if it is statically known
  // that the original data source x is not used.
  if (hasZeroLen) {
    Operation *generateOp = createGenerateOp();
    return TilingResult{{generateOp}, {castResult(generateOp->getResult(0))}};
  }

  // If there are dynamic dimensions: Generate an scf.if check to avoid
  // creating SliceOps with result dimensions of size 0 at runtime.
  if (generateZeroSliceGuard && dynHasZeroLenCond) {
    Operation *thenOp;
    Operation *elseOp;
    auto result = b.create<scf::IfOp>(
        loc, dynHasZeroLenCond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          thenOp = createGenerateOp();
          b.create<scf::YieldOp>(loc, castResult(thenOp->getResult(0)));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          elseOp = createPadOfExtractSlice();
          b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0)));
        });
    return TilingResult{{elseOp}, SmallVector<Value>(result->getResults())};
  }

  Operation *newPadOp = createPadOfExtractSlice();
  return TilingResult{{newPadOp}, {castResult(newPadOp->getResult(0))}};
}

void mlir::tensor::registerTilingInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}

void mlir::tensor::registerTilingInterfaceExternalModelsForPackUnPackOps(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}
