1//===- TensorTilingInterface.cpp - Tiling Interface models *- C++ ------*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
10#include "mlir/Dialect/Affine/IR/AffineOps.h"
11#include "mlir/Dialect/Affine/Utils.h"
12#include "mlir/Dialect/Arith/Utils/Utils.h"
13#include "mlir/Dialect/Linalg/Utils/Utils.h"
14#include "mlir/Dialect/SCF/IR/SCF.h"
15#include "mlir/Dialect/Tensor/IR/Tensor.h"
16#include "mlir/Interfaces/InferTypeOpInterface.h"
17#include "mlir/Interfaces/TilingInterface.h"
18
19using namespace mlir;
20using namespace mlir::tensor;
21
22namespace {
23
24struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
25
26 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
27 auto padOp = cast<PadOp>(Val: op);
28 SmallVector<utils::IteratorType> iteratorTypes(
29 padOp.getResultType().getRank(), utils::IteratorType::parallel);
30 return iteratorTypes;
31 }
32
33 SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
34 ReifiedRankedShapedTypeDims reifiedShapes;
35 (void)reifyResultShapes(b, op, reifiedReturnShapes&: reifiedShapes);
36 OpFoldResult zero = b.getIndexAttr(value: 0);
37 OpFoldResult one = b.getIndexAttr(value: 1);
38 // Initialize all the ranges to {zero, one, one}. All the `ub`s are
39 // overwritten.
40 SmallVector<Range> loopRanges(reifiedShapes[0].size(), {.offset: zero, .size: one, .stride: one});
41 for (const auto &ub : enumerate(First&: reifiedShapes[0]))
42 loopRanges[ub.index()].size = ub.value();
43 return loopRanges;
44 }
45
46 FailureOr<TilingResult>
47 getTiledImplementation(Operation *op, OpBuilder &b,
48 ArrayRef<OpFoldResult> offsets,
49 ArrayRef<OpFoldResult> sizes) const {
50 FailureOr<TilingResult> result =
51 tensor::bubbleUpPadSlice(b, padOp: cast<PadOp>(Val: op), offsets, sizes);
52 if (failed(Result: result))
53 return failure();
54 return result.value();
55 }
56
57 LogicalResult
58 getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
59 ArrayRef<OpFoldResult> offsets,
60 ArrayRef<OpFoldResult> sizes,
61 SmallVector<OpFoldResult> &resultOffsets,
62 SmallVector<OpFoldResult> &resultSizes) const {
63 resultOffsets.assign(in_start: offsets.begin(), in_end: offsets.end());
64 resultSizes.assign(in_start: sizes.begin(), in_end: sizes.end());
65 return success();
66 }
67
68 LogicalResult getIterationDomainTileFromResultTile(
69 Operation *op, OpBuilder &b, unsigned resultNumber,
70 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
71 SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
72 SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
73 iterDomainOffsets.assign(in_start: offsets.begin(), in_end: offsets.end());
74 iterDomainSizes.assign(in_start: sizes.begin(), in_end: sizes.end());
75 return success();
76 }
77
78 FailureOr<TilingResult>
79 generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
80 ArrayRef<OpFoldResult> offsets,
81 ArrayRef<OpFoldResult> sizes) const {
82 return getTiledImplementation(op, b, offsets, sizes);
83 }
84};
85
86} // namespace
87
88FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
89 tensor::PadOp padOp,
90 ArrayRef<OpFoldResult> offsets,
91 ArrayRef<OpFoldResult> sizes,
92 bool generateZeroSliceGuard) {
93 // Only constant padding value supported.
94 Value padValue = padOp.getConstantPaddingValue();
95 if (!padValue)
96 return failure();
97
98 // Helper variables and functions for various arithmetic operations. These
99 // are used extensively for computing new offset/length and padding values.
100 Location loc = padOp->getLoc();
101 AffineExpr dim0, dim1;
102 bindDims(ctx: b.getContext(), exprs&: dim0, exprs&: dim1);
103 // Subtract two integers.
104 auto subMap = AffineMap::get(dimCount: 2, symbolCount: 0, result: {dim0 - dim1});
105 auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
106 return affine::makeComposedFoldedAffineApply(b, loc, map: subMap, operands: {v1, v2});
107 };
108 // Take the minimum of two integers.
109 auto idMap = AffineMap::getMultiDimIdentityMap(numDims: 2, context: b.getContext());
110 auto min = [&](OpFoldResult v1, OpFoldResult v2) {
111 return affine::makeComposedFoldedAffineMin(b, loc, map: idMap, operands: {v1, v2});
112 };
113 // Take the maximum of two integers.
114 auto max = [&](OpFoldResult v1, OpFoldResult v2) {
115 return affine::makeComposedFoldedAffineMax(b, loc, map: idMap, operands: {v1, v2});
116 };
117 // Zero index-typed integer.
118 OpFoldResult zero = b.getIndexAttr(value: 0);
119
120 // Compute new offsets, lengths, low padding, high padding.
121 SmallVector<OpFoldResult> newOffsets, newLengths;
122 SmallVector<OpFoldResult> newLows, newHighs;
123 // Set to true if the original data source is not read at all.
124 bool hasZeroLen = false;
125 // Same as hasZeroLen, but for dynamic dimension sizes. This condition
126 // is true if the original data source turns out to be unused at runtime.
127 Value dynHasZeroLenCond;
128
129 int64_t rank = padOp.getSourceType().getRank();
130 // Only unit stride supported.
131 SmallVector<OpFoldResult> newStrides(rank, b.getIndexAttr(value: 1));
132 for (unsigned dim = 0; dim < rank; ++dim) {
133 auto low = padOp.getMixedLowPad()[dim];
134 bool hasLowPad = !isZeroInteger(v: low);
135 auto high = padOp.getMixedHighPad()[dim];
136 bool hasHighPad = !isZeroInteger(v: high);
137 auto offset = offsets[dim];
138 auto length = sizes[dim];
139 // If the dim has no padding, we dont need to calculate new values for that
140 // dim as the exisiting ones are correct even after the pattern.
141 if (!hasLowPad && !hasHighPad) {
142 newOffsets.push_back(Elt: offset);
143 newLengths.push_back(Elt: length);
144 newLows.push_back(Elt: low);
145 newHighs.push_back(Elt: high);
146 continue;
147 }
148
149 auto srcSize = tensor::getMixedSize(builder&: b, loc, value: padOp.getSource(), dim);
150
151 // The new amount of low padding is `low - offset`. Except for the case
152 // where none of the low padding is read. In that case, the new amount of
153 // low padding is zero.
154 //
155 // Optimization: If low = 0, then newLow = 0.
156 OpFoldResult newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
157 newLows.push_back(Elt: newLow);
158
159 // Start reading the data from position `offset - low`. Since the original
160 // read may have started in the low padding zone, this value could be
161 // negative. Therefore, start reading from:
162 //
163 // max(offset - low, 0)
164 //
165 // The original read could also have started in the high padding zone.
166 // In that case, set the offset to the end of source tensor. The new
167 // ExtractSliceOp length will be zero in that case. (Effectively reading
168 // no data from the source.)
169 //
170 // Optimization: If low = 0, then the formula can be simplified.
171 OpFoldResult newOffset = hasLowPad
172 ? min(max(sub(offset, low), zero), srcSize)
173 : min(offset, srcSize);
174 newOffsets.push_back(Elt: newOffset);
175
176 // The original ExtractSliceOp was reading until position `offset +
177 // length`. Therefore, the corresponding position within the source tensor
178 // is:
179 //
180 // offset + length - low
181 //
182 // In case the original ExtractSliceOp stopped reading within the low
183 // padding zone, this value can be negative. In that case, the end
184 // position of the read should be zero. (Similar to newOffset.)
185 //
186 // The original read could also have stopped in the high padding zone.
187 // In that case, set the end positition of the read should be the end of
188 // the source tensor. (Similar to newOffset.)
189 // srcSize - newOffset represents how much length we have available
190 // and length - newLow represents how much length we want at most.
191 // Note that there are many ways to order this indexing math to compute
192 // newLength, but we want to make sure that the final affine.min ops in the
193 // sequence are bounding the index to as small a value as possible. If
194 // ValueBoundsOpInterface is used, this calculation will get upper bounds
195 // from the affine.min ops, so we want to use the smallest known value to
196 // set the bound at the end of the computation sequence. In this case, the
197 // index will be upper bounded by length - newLow.
198 OpFoldResult newLength = min(sub(srcSize, newOffset), sub(length, newLow));
199 // Optimization: If low = 0, then newLow = 0. then newLength >= 0 assuming
200 // length >= 0.
201 if (hasLowPad)
202 newLength = max(newLength, zero);
203 newLengths.push_back(Elt: newLength);
204
205 // Check if newLength is zero. In that case, no SubTensorOp should be
206 // executed.
207 if (isZeroInteger(v: newLength)) {
208 hasZeroLen = true;
209 } else if (!hasZeroLen) {
210 Value check = b.create<arith::CmpIOp>(
211 location: loc, args: arith::CmpIPredicate::eq,
212 args: getValueOrCreateConstantIndexOp(b, loc, ofr: newLength),
213 args: getValueOrCreateConstantIndexOp(b, loc, ofr: zero));
214 dynHasZeroLenCond =
215 dynHasZeroLenCond
216 ? b.create<arith::OrIOp>(location: loc, args&: check, args&: dynHasZeroLenCond)
217 : check;
218 }
219
220 // The amount of high padding is simply the number of elements remaining,
221 // so that the result has the same length as the original ExtractSliceOp.
222 // As an optimization, if the original high padding is zero, then the new
223 // high padding must also be zero.
224 OpFoldResult newHigh =
225 hasHighPad ? sub(sub(length, newLength), newLow) : zero;
226 newHighs.push_back(Elt: newHigh);
227 }
228
229 // The shape of the result can be obtained from the sizes passed in.
230 SmallVector<Value> dynDims;
231 SmallVector<int64_t> shape;
232 dispatchIndexOpFoldResults(ofrs: sizes, dynamicVec&: dynDims, staticVec&: shape);
233 RankedTensorType resultType =
234 RankedTensorType::get(shape, elementType: padOp.getResultType().getElementType());
235
236 // Insert cast to ensure that types match. (May be folded away.)
237 auto castResult = [&](Value val) -> Value {
238 if (resultType == val.getType())
239 return val;
240 return b.create<tensor::CastOp>(location: loc, args&: resultType, args&: val);
241 };
242
243 // In cases where the original data source is unused: Emit a GenerateOp and
244 // do not generate a SliceOp. (The result shape of the SliceOp would
245 // have a dimension of size 0, the semantics of which is unclear.)
246 auto createGenerateOp = [&]() {
247 // Create GenerateOp.
248 auto generateOp = b.create<tensor::GenerateOp>(
249 location: loc, args&: resultType, args&: dynDims,
250 args: [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
251 builder.create<tensor::YieldOp>(location: gLoc, args&: padValue);
252 });
253 return generateOp;
254 };
255
256 // Emit a SliceOp and a PadOp. Should not be used in cases where
257 // the result shape of the new SliceOp has a zero dimension.
258 auto createPadOfExtractSlice = [&]() {
259 // Create pad(extract_slice(x)).
260 auto newSliceOp = b.create<tensor::ExtractSliceOp>(
261 location: loc, args: padOp.getSource(), args&: newOffsets, args&: newLengths, args&: newStrides);
262 auto newPadOp = b.create<PadOp>(
263 location: loc, args: Type(), args&: newSliceOp, args&: newLows, args&: newHighs,
264 /*nofold=*/args: padOp.getNofold(),
265 args: getPrunedAttributeList(op: padOp, elidedAttrs: PadOp::getAttributeNames()));
266
267 // Copy region to new PadOp.
268 IRMapping bvm;
269 padOp.getRegion().cloneInto(dest: &newPadOp.getRegion(), mapper&: bvm);
270
271 // Cast result and return.
272 return std::make_tuple(args&: newPadOp, args&: newSliceOp);
273 };
274
275 // Rewrite extract_slice(pad(x)) into a GenerateOp it is statically known that
276 // the original data source x is not used.
277 if (hasZeroLen) {
278 Operation *generateOp = createGenerateOp();
279 return TilingResult{.tiledOps: {generateOp},
280 .tiledValues: {castResult(generateOp->getResult(idx: 0))},
281 /*generatedSlices=*/{}};
282 }
283
284 // If there are dynamic dimensions: Generate an scf.if check to avoid
285 // creating SliceOps with result dimensions of size 0 at runtime.
286 if (generateZeroSliceGuard && dynHasZeroLenCond) {
287 Operation *thenOp;
288 Operation *elseOp;
289 Operation *sliceOp;
290 auto result = b.create<scf::IfOp>(
291 location: loc, args&: dynHasZeroLenCond,
292 /*thenBuilder=*/
293 args: [&](OpBuilder &b, Location loc) {
294 thenOp = createGenerateOp();
295 b.create<scf::YieldOp>(location: loc, args: castResult(thenOp->getResult(idx: 0)));
296 },
297 /*elseBuilder=*/
298 args: [&](OpBuilder &b, Location loc) {
299 std::tie(args&: elseOp, args&: sliceOp) = createPadOfExtractSlice();
300 b.create<scf::YieldOp>(location: loc, args: castResult(elseOp->getResult(idx: 0)));
301 });
302 return TilingResult{
303 .tiledOps: {elseOp}, .tiledValues: SmallVector<Value>(result->getResults()), .generatedSlices: {sliceOp}};
304 }
305
306 auto [newPadOp, sliceOp] = createPadOfExtractSlice();
307 return TilingResult{
308 .tiledOps: {newPadOp}, .tiledValues: {castResult(newPadOp->getResult(idx: 0))}, .generatedSlices: {sliceOp}};
309}
310
311void mlir::tensor::registerTilingInterfaceExternalModels(
312 DialectRegistry &registry) {
313 registry.addExtension(extensionFn: +[](MLIRContext *ctx, TensorDialect *dialect) {
314 tensor::PadOp::attachInterface<PadOpTiling>(context&: *ctx);
315 });
316}
317

// Source: mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp