//===- TensorTilingInterface.cpp - Tiling Interface models ------*- C++ -*-===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h" |
10 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
11 | #include "mlir/Dialect/Affine/Utils.h" |
12 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
13 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
14 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
15 | #include "mlir/Dialect/SCF/IR/SCF.h" |
16 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
17 | #include "mlir/Dialect/Tensor/Utils/Utils.h" |
18 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
19 | #include "mlir/Interfaces/TilingInterface.h" |
20 | #include "mlir/Interfaces/ValueBoundsOpInterface.h" |
21 | |
22 | using namespace mlir; |
23 | using namespace mlir::tensor; |
24 | |
25 | namespace { |
26 | |
27 | struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> { |
28 | |
29 | SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { |
30 | auto padOp = cast<PadOp>(op); |
31 | SmallVector<utils::IteratorType> iteratorTypes( |
32 | padOp.getResultType().getRank(), utils::IteratorType::parallel); |
33 | return iteratorTypes; |
34 | } |
35 | |
36 | SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const { |
37 | ReifiedRankedShapedTypeDims reifiedShapes; |
38 | (void)reifyResultShapes(b, op, reifiedShapes); |
39 | OpFoldResult zero = b.getIndexAttr(0); |
40 | OpFoldResult one = b.getIndexAttr(1); |
41 | // Initialize all the ranges to {zero, one, one}. All the `ub`s are |
42 | // overwritten. |
43 | SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one}); |
44 | for (const auto &ub : enumerate(reifiedShapes[0])) |
45 | loopRanges[ub.index()].size = ub.value(); |
46 | return loopRanges; |
47 | } |
48 | |
49 | FailureOr<TilingResult> |
50 | getTiledImplementation(Operation *op, OpBuilder &b, |
51 | ArrayRef<OpFoldResult> offsets, |
52 | ArrayRef<OpFoldResult> sizes) const { |
53 | FailureOr<TilingResult> result = |
54 | tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes); |
55 | if (failed(result)) |
56 | return failure(); |
57 | return result.value(); |
58 | } |
59 | |
60 | LogicalResult |
61 | getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, |
62 | ArrayRef<OpFoldResult> offsets, |
63 | ArrayRef<OpFoldResult> sizes, |
64 | SmallVector<OpFoldResult> &resultOffsets, |
65 | SmallVector<OpFoldResult> &resultSizes) const { |
66 | resultOffsets.assign(offsets.begin(), offsets.end()); |
67 | resultSizes.assign(sizes.begin(), sizes.end()); |
68 | return success(); |
69 | } |
70 | }; |
71 | |
72 | template <typename OpTy> |
73 | static SmallVector<Range> getPackUnPackIterationDomain(OpTy op, |
74 | OpBuilder &builder) { |
75 | static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value, |
76 | "applies to only pack or unpack operations" ); |
77 | OpBuilder::InsertionGuard g(builder); |
78 | int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank() |
79 | : op.getDestRank(); |
80 | OpFoldResult zero = builder.getIndexAttr(0); |
81 | OpFoldResult one = builder.getIndexAttr(1); |
82 | ReifiedRankedShapedTypeDims resultShape; |
83 | (void)reifyResultShapes(builder, op, resultShape); |
84 | SmallVector<Range> loopBounds(rank); |
85 | for (auto dim : llvm::seq<int64_t>(0, rank)) { |
86 | loopBounds[dim].offset = zero; |
87 | loopBounds[dim].stride = one; |
88 | loopBounds[dim].size = resultShape[0][dim]; |
89 | } |
90 | return loopBounds; |
91 | } |
92 | |
93 | static void applyPermToRange(SmallVector<OpFoldResult> &offsets, |
94 | SmallVector<OpFoldResult> &sizes, |
95 | ArrayRef<int64_t> permutation) { |
96 | if (permutation.empty()) |
97 | return; |
98 | applyPermutationToVector<OpFoldResult>(offsets, permutation); |
99 | applyPermutationToVector<OpFoldResult>(sizes, permutation); |
100 | } |
101 | |
102 | struct PackOpTiling |
103 | : public TilingInterface::ExternalModel<PackOpTiling, PackOp> { |
104 | |
105 | SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { |
    // Note that here we only consider untiled dimensions and outer tiled data
    // dimensions; the inner tiled data dimensions are materialized when
    // building the body of the operation.
109 | auto packOp = cast<PackOp>(op); |
110 | SmallVector<utils::IteratorType> iteratorTypes( |
111 | packOp.getSourceRank(), utils::IteratorType::parallel); |
112 | return iteratorTypes; |
113 | } |
114 | |
115 | SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const { |
116 | return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b); |
117 | } |
118 | |
119 | FailureOr<TilingResult> |
120 | getTiledImplementation(Operation *op, OpBuilder &b, |
121 | ArrayRef<OpFoldResult> offsets, |
122 | ArrayRef<OpFoldResult> sizes) const { |
123 | auto packOp = cast<PackOp>(op); |
124 | Location loc = packOp.getLoc(); |
125 | |
126 | // The tiling is applied on interchanged dimensions. We have to undo the |
127 | // interchange to map sizes and offsets to the original input. |
128 | int64_t inputRank = packOp.getSourceRank(); |
129 | SmallVector<OpFoldResult> origOffsets(offsets.begin(), offsets.end()); |
130 | SmallVector<OpFoldResult> origSizes(sizes.begin(), sizes.end()); |
131 | applyPermToRange(origOffsets, origSizes, |
132 | invertPermutationVector(packOp.getOuterDimsPerm())); |
133 | |
134 | DenseMap<int64_t, OpFoldResult> dimAndTileMapping = |
135 | packOp.getDimAndTileMapping(); |
136 | SmallVector<OpFoldResult> srcDimValues = |
137 | tensor::getMixedSizes(b, loc, packOp.getSource()); |
138 | SmallVector<OpFoldResult> inputIndices, inputSizes; |
139 | for (auto dim : llvm::seq<int64_t>(0, inputRank)) { |
140 | using AV = affine::AffineValueExpr; |
141 | affine::AffineBuilder ab(b, loc); |
142 | AffineExpr dim0, dim1, sym; |
143 | bindDims(b.getContext(), dim0, dim1); |
144 | bindSymbols(b.getContext(), sym); |
145 | if (dimAndTileMapping.count(dim)) { |
146 | // If the data dimension is tiled, the i-th index is the product of |
147 | // offset_i and tile_i, and the i-th size is the product of sizes_i and |
148 | // tile_i. |
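        // E.g., with tile_i = 8, offset_i = 2, and sizes_i = 3, the tile
        // starts at source index 2 * 8 = 16 and covers 3 * 8 = 24 elements.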
149 | auto avOffset = AV(dim0).bind(origOffsets[dim]); |
150 | auto avSize = AV(dim0).bind(origSizes[dim]); |
151 | auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]); |
152 | inputIndices.push_back(ab.mul(avOffset, avTileSize)); |
153 | inputSizes.push_back(ab.mul(avSize, avTileSize)); |
154 | } else { |
155 | inputIndices.push_back(origOffsets[dim]); |
156 | inputSizes.push_back(origSizes[dim]); |
157 | } |
158 | |
159 | // Limit the size of the input operand for incomplete tiles. |
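      // Continuing the example above: if the source dim has only 20 elements,
      // the min below clamps the read to 20 - 16 = 4 elements; the padding
      // value fills the rest of the tile.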
160 | if (packOp.getPaddingValue()) { |
161 | OpFoldResult dimSize = srcDimValues[dim]; |
162 | auto avDimSize = AV(dim0).bind(dimSize); |
163 | auto avInputIdx = AV(dim1).bind(inputIndices.back()); |
164 | inputSizes.back() = |
165 | ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)}); |
166 | } |
167 | } |
168 | |
169 | auto oneAttr = b.getI64IntegerAttr(1); |
170 | SmallVector<OpFoldResult> strides(inputRank, oneAttr); |
171 | |
172 | SmallVector<Value> tiledOperands; |
173 | tiledOperands.push_back(b.create<ExtractSliceOp>( |
174 | loc, packOp.getSource(), inputIndices, inputSizes, strides)); |
175 | |
176 | SmallVector<OpFoldResult> outputOffsets, outputSizes; |
177 | if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets, |
178 | outputSizes))) |
      return failure();
180 | |
181 | strides.append(packOp.getDestRank() - inputRank, oneAttr); |
    auto extractSlice = b.create<ExtractSliceOp>(
183 | loc, packOp.getDest(), outputOffsets, outputSizes, strides); |
184 | tiledOperands.push_back(extractSlice); |
185 | |
186 | if (auto val = packOp.getPaddingValue()) |
187 | tiledOperands.push_back(val); |
188 | for (auto tile : packOp.getInnerTiles()) |
189 | tiledOperands.push_back(tile); |
190 | |
191 | Operation *tiledPackOp = b.create<PackOp>( |
192 | loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs()); |
193 | |
194 | return TilingResult{{tiledPackOp}, |
195 | SmallVector<Value>(tiledPackOp->getResults())}; |
196 | } |
197 | |
198 | LogicalResult |
199 | getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, |
200 | ArrayRef<OpFoldResult> offsets, |
201 | ArrayRef<OpFoldResult> sizes, |
202 | SmallVector<OpFoldResult> &resultOffsets, |
203 | SmallVector<OpFoldResult> &resultSizes) const { |
204 | // The iteration domain is over outer dimensions of packed layout. In this |
205 | // context, the outer dimensions of `resultOffsets` are `offsets`. The |
206 | // inner dimensions of `resultOffsets` are zeros because tiling is not |
207 | // applied to them. |
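    // E.g., for a pack from a 2-D source into a 4-D dest, iteration-domain
    // offsets (i, j) become result offsets (i, j, 0, 0), and the sizes are
    // extended with the two inner tile sizes of the packed layout.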
208 | auto packOp = cast<PackOp>(op); |
209 | int64_t inputRank = packOp.getSourceRank(); |
210 | int64_t outputRank = packOp.getDestRank(); |
211 | auto zeroAttr = b.getI64IntegerAttr(0); |
212 | resultOffsets.assign(offsets.begin(), offsets.end()); |
213 | resultOffsets.append(outputRank - inputRank, zeroAttr); |
214 | |
215 | ReifiedRankedShapedTypeDims outputShape; |
216 | (void)reifyResultShapes(b, packOp, outputShape); |
217 | resultSizes.assign(sizes.begin(), sizes.end()); |
218 | for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank)) |
219 | resultSizes.push_back(outputShape[0][dataTileDim]); |
220 | |
221 | return success(); |
222 | } |
223 | |
224 | FailureOr<TilingResult> |
225 | generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, |
226 | ArrayRef<OpFoldResult> offsets, |
227 | ArrayRef<OpFoldResult> sizes) const { |
228 | auto packOp = cast<PackOp>(op); |
229 | int64_t numTiles = packOp.getInnerDimsPos().size(); |
230 | |
    // A tensor.pack op is fusible (as a producer) only if full inner tiles
    // are iterated over or the inner dims are not tiled. Otherwise, it would
    // generate a sequence of non-trivial ops (for partial tiles).
234 | for (auto offset : offsets.take_back(numTiles)) |
235 | if (!isConstantIntValue(offset, 0)) |
236 | return failure(); |
237 | |
238 | for (auto iter : |
239 | llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles))) |
240 | if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter))) |
241 | return failure(); |
242 | |
243 | FailureOr<TilingResult> tilingResult = getTiledImplementation( |
244 | op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles)); |
245 | if (failed(tilingResult)) |
246 | return failure(); |
247 | return tilingResult.value(); |
248 | } |
249 | }; |
250 | |
251 | struct UnpackTileDimInfo { |
252 | bool isAlignedToInnerTileSize; |
253 | OpFoldResult sourceOffset; |
254 | OpFoldResult sourceSize; |
255 | OpFoldResult resultOffset; |
256 | OpFoldResult destExpandedSize; |
257 | }; |
258 | |
/// Returns the information needed for tiling an unpack op on `tileDim` with
/// the given `tileOffset` and `tileSize`. For more details, see the comment
/// on `getTiledImplementation`.
262 | static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp, |
263 | int64_t tileDim, |
264 | OpFoldResult tileOffset, |
265 | OpFoldResult tileSize) { |
266 | UnpackTileDimInfo info; |
267 | Attribute zeroAttr = b.getIndexAttr(0); |
268 | Attribute oneAttr = b.getIndexAttr(1); |
269 | DenseMap<int64_t, OpFoldResult> dimAndTileMapping = |
270 | unpackOp.getDimAndTileMapping(); |
  // The dimension is not one of the packed data dimensions.
272 | if (!dimAndTileMapping.count(tileDim)) { |
273 | info.isAlignedToInnerTileSize = true; |
274 | info.sourceOffset = tileOffset; |
275 | info.sourceSize = tileSize; |
276 | info.resultOffset = zeroAttr; |
277 | info.destExpandedSize = tileSize; |
278 | return info; |
279 | } |
280 | |
281 | Location loc = unpackOp.getLoc(); |
282 | using AV = affine::AffineValueExpr; |
283 | affine::AffineBuilder ab(b, loc); |
284 | AffineExpr dim0, dim1, sym0; |
285 | bindDims(b.getContext(), dim0, dim1); |
286 | bindSymbols(b.getContext(), sym0); |
287 | |
288 | OpFoldResult innerTileSize = dimAndTileMapping[tileDim]; |
289 | |
290 | info.isAlignedToInnerTileSize = false; |
291 | FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound( |
292 | presburger::BoundType::UB, tileSize, |
293 | /*stopCondition=*/nullptr, /*closedUB=*/true); |
  std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
295 | if (!failed(cstSize) && cstInnerSize) { |
296 | if (*cstSize % *cstInnerSize == 0) |
297 | info.isAlignedToInnerTileSize = true; |
298 | |
    // If the tiling size equals the inner tiling size, the outer dims are
    // always 1.
301 | if (*cstInnerSize == *cstSize) { |
      auto lhs = AV(dim0).bind(tileOffset);
      auto rhs = AV(dim1).bind(innerTileSize);
      info.sourceOffset = ab.floor(lhs, rhs);
305 | info.sourceSize = oneAttr; |
306 | info.resultOffset = zeroAttr; |
307 | info.destExpandedSize = tileSize; |
308 | return info; |
309 | } |
310 | } |
311 | |
312 | if (info.isAlignedToInnerTileSize) { |
    info.sourceOffset =
        ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
315 | info.resultOffset = zeroAttr; |
316 | info.destExpandedSize = tileSize; |
317 | |
    // The ceilDiv is needed here because there could be an incomplete tile
    // even in perfect tiling cases. E.g.,
    //   %0 = unpack tensor<33x2xf32> into tensor<66xf32>
    // If the tiling size is 32, there will be 3 tiles. Two of them have
    // size=32; one of them has size=2. The size is represented using an
    // affine_min op; we need ceilDiv.
    info.sourceSize =
        ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
326 | return info; |
327 | } |
328 | |
  affine::DivModValue firstCoord = affine::getDivMod(
      b, loc, getValueOrCreateConstantIndexOp(b, loc, tileOffset),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  OpFoldResult tileExclusiveBound =
      ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
  affine::DivModValue lastCoord = affine::getDivMod(
      b, loc,
      getValueOrCreateConstantIndexOp(
          b, loc,
          ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));

  OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient),
                                       AV(dim1).bind(firstCoord.quotient));
  info.sourceSize =
      ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
345 | info.sourceOffset = firstCoord.quotient; |
346 | info.resultOffset = firstCoord.remainder; |
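  // E.g., with inner tile size 8, tileOffset = 15, and tileSize = 15:
  // firstCoord = divmod(15, 8) = (1, 7) and lastCoord = divmod(29, 8) =
  // (3, 5), so the slice reads source rows [1, 4) (sourceSize = 3) and the
  // unpacked data must be shifted by resultOffset = 7.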
  // Do not create affine ops for the expanded size because the affine
  // expression would be too complicated and trigger an issue in affine op
  // simplification.
349 | info.destExpandedSize = b.createOrFold<arith::MulIOp>( |
350 | loc, getValueOrCreateConstantIndexOp(b, loc, info.sourceSize), |
351 | getValueOrCreateConstantIndexOp(b, loc, innerTileSize)); |
352 | return info; |
353 | } |
354 | |
355 | struct UnPackOpTiling |
356 | : public TilingInterface::ExternalModel<UnPackOpTiling, UnPackOp> { |
357 | |
358 | SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { |
359 | auto unpackOp = cast<UnPackOp>(op); |
360 | SmallVector<utils::IteratorType> iteratorTypes( |
361 | unpackOp.getDestRank(), utils::IteratorType::parallel); |
362 | return iteratorTypes; |
363 | } |
364 | |
365 | SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const { |
366 | return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b); |
367 | } |
368 | |
  /// There are two cases in tiling unpack ops. If the tiling size is aligned
  /// to the inner tile size, the corresponding tiles of the source are all
  /// complete. Otherwise, there are incomplete tiles. We will need to expand
  /// the slice of the source to get complete tiles. The tiled unpack op
  /// unpacks more data from the source, so we need an extract_slice op to
  /// shift and truncate the output.
  /// Take Nn_to_N as an example. Say that N=32, n=8, and tiling_size=15. The
  /// coordinates of the second tile (i.e., result[15..29]) are
  /// [(1, 7), (2, 0), (2, 1) ... (3, 4), (3, 5)]. The first row and the last
  /// row are incomplete tiles. To represent the unpack op, we have to
  /// complete the rows, i.e., the input coordinates start with (1, 0) and end
  /// with (3, 7). In this context, the tiled unpack op produces (3 * n)
  /// elements because there are 3 rows in total. A trailing
  /// tensor.extract_slice op then yields the actual result.
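  /// A sketch of the resulting IR for this example (hypothetical values,
  /// simplified):
  ///   %src = tensor.extract_slice %in[1, 0] [3, 8] [1, 1]
  ///       : tensor<4x8xf32> to tensor<3x8xf32>
  ///   %empty = tensor.empty() : tensor<24xf32>
  ///   %unpack = tensor.unpack %src inner_dims_pos = [0] inner_tiles = [8]
  ///       into %empty : tensor<3x8xf32> -> tensor<24xf32>
  ///   %res = tensor.extract_slice %unpack[7] [15] [1]
  ///       : tensor<24xf32> to tensor<15xf32>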
383 | FailureOr<TilingResult> |
384 | getTiledImplementation(Operation *op, OpBuilder &b, |
385 | ArrayRef<OpFoldResult> offsets, |
386 | ArrayRef<OpFoldResult> sizes) const { |
387 | auto unpackOp = cast<UnPackOp>(op); |
388 | int64_t srcRank = unpackOp.getSourceRank(); |
389 | int64_t destRank = unpackOp.getDestRank(); |
390 | int64_t numInnerTiles = srcRank - destRank; |
391 | Location loc = unpackOp.getLoc(); |
392 | |
    // The perfect tiling case indicates that the tiling sizes are multiples of
394 | // inner_tile_size. In this context, no extra data is needed when |
395 | // representing the tiled unpack op. |
396 | bool isPerfectTilingCase = true; |
397 | Attribute oneAttr = b.getIndexAttr(1); |
398 | SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr); |
399 | SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes; |
400 | SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest; |
401 | for (auto dim : llvm::seq<int64_t>(0, destRank)) { |
402 | UnpackTileDimInfo info = |
403 | getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]); |
404 | if (!info.isAlignedToInnerTileSize) |
405 | isPerfectTilingCase = false; |
406 | sliceSrcIndices.push_back(info.sourceOffset); |
407 | sliceSrcSizes.push_back(info.sourceSize); |
408 | destExpandedSizes.push_back(info.destExpandedSize); |
409 | resultOffsetsFromDest.push_back(info.resultOffset); |
410 | } |
411 | |
412 | // The tiling is applied on destination dimensions. We have to apply the |
413 | // interchange on source dimensions if outer_dims_perm is set. |
414 | applyPermToRange(sliceSrcIndices, sliceSrcSizes, |
415 | unpackOp.getOuterDimsPerm()); |
416 | Attribute zeroAttr = b.getIndexAttr(0); |
417 | sliceSrcIndices.append(numInnerTiles, zeroAttr); |
418 | sliceSrcSizes.append(unpackOp.getMixedTiles()); |
419 | sliceSrcStrides.append(numInnerTiles, oneAttr); |
420 | Value sliceSource = |
421 | b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices, |
422 | sliceSrcSizes, sliceSrcStrides); |
423 | |
424 | SmallVector<OpFoldResult> destStrides(destRank, oneAttr); |
425 | Value sliceDest; |
426 | if (isPerfectTilingCase) { |
427 | sliceDest = b.create<ExtractSliceOp>(loc, unpackOp.getDest(), offsets, |
428 | sizes, destStrides); |
429 | } else { |
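      // For the Nn_to_N example in the comment above (n = 8, tiling_size =
      // 15), this creates an empty dest of 3 * 8 = 24 elements; the trailing
      // extract_slice then takes 15 elements starting at offset 7.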
430 | sliceDest = b.create<EmptyOp>(loc, destExpandedSizes, |
431 | unpackOp.getDestType().getElementType()); |
432 | } |
433 | |
434 | SmallVector<Value> tiledOperands = {sliceSource, sliceDest}; |
435 | for (auto tile : unpackOp.getInnerTiles()) |
436 | tiledOperands.push_back(tile); |
437 | |
438 | Operation *tiledUnpackOp = b.create<UnPackOp>( |
439 | loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs()); |
440 | |
441 | if (isPerfectTilingCase) |
442 | return TilingResult{{tiledUnpackOp}, |
443 | SmallVector<Value>(tiledUnpackOp->getResults())}; |
444 | |
    auto extractSlice =
446 | b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0), |
447 | resultOffsetsFromDest, sizes, destStrides); |
448 | return TilingResult{{tiledUnpackOp}, {extractSlice.getResult()}}; |
449 | } |
450 | |
451 | LogicalResult |
452 | getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, |
453 | ArrayRef<OpFoldResult> offsets, |
454 | ArrayRef<OpFoldResult> sizes, |
455 | SmallVector<OpFoldResult> &resultOffsets, |
456 | SmallVector<OpFoldResult> &resultSizes) const { |
457 | resultOffsets = llvm::to_vector(offsets); |
458 | resultSizes = llvm::to_vector(sizes); |
459 | return success(); |
460 | } |
461 | |
462 | FailureOr<TilingResult> |
463 | generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, |
464 | ArrayRef<OpFoldResult> offsets, |
465 | ArrayRef<OpFoldResult> sizes) const { |
466 | FailureOr<TilingResult> tilingResult = |
467 | getTiledImplementation(op, b, offsets, sizes); |
468 | if (failed(tilingResult)) |
469 | return failure(); |
470 | return tilingResult.value(); |
471 | } |
472 | }; |
473 | |
474 | } // namespace |
475 | |
476 | FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b, |
477 | tensor::PadOp padOp, |
478 | ArrayRef<OpFoldResult> offsets, |
479 | ArrayRef<OpFoldResult> sizes, |
480 | bool generateZeroSliceGuard) { |
481 | // Only constant padding value supported. |
482 | Value padValue = padOp.getConstantPaddingValue(); |
483 | if (!padValue) |
484 | return failure(); |
485 | |
486 | // Helper variables and functions for various arithmetic operations. These |
487 | // are used extensively for computing new offset/length and padding values. |
488 | Location loc = padOp->getLoc(); |
489 | AffineExpr dim0, dim1; |
  bindDims(b.getContext(), dim0, dim1);
491 | // Add two integers. |
  auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
493 | auto add = [&](OpFoldResult v1, OpFoldResult v2) { |
494 | return affine::makeComposedFoldedAffineApply(b, loc, addMap, {v1, v2}); |
495 | }; |
496 | // Subtract two integers. |
  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
498 | auto sub = [&](OpFoldResult v1, OpFoldResult v2) { |
499 | return affine::makeComposedFoldedAffineApply(b, loc, subMap, {v1, v2}); |
500 | }; |
501 | // Take the minimum of two integers. |
  auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
503 | auto min = [&](OpFoldResult v1, OpFoldResult v2) { |
504 | return affine::makeComposedFoldedAffineMin(b, loc, idMap, {v1, v2}); |
505 | }; |
506 | // Take the maximum of two integers. |
507 | auto max = [&](OpFoldResult v1, OpFoldResult v2) { |
508 | return affine::makeComposedFoldedAffineMax(b, loc, idMap, {v1, v2}); |
509 | }; |
510 | // Zero index-typed integer. |
511 | OpFoldResult zero = b.getIndexAttr(0); |
512 | |
513 | // Compute new offsets, lengths, low padding, high padding. |
514 | SmallVector<OpFoldResult> newOffsets, newLengths, newStrides; |
515 | SmallVector<OpFoldResult> newLows, newHighs; |
516 | // Set to true if the original data source is not read at all. |
517 | bool hasZeroLen = false; |
518 | // Same as hasZeroLen, but for dynamic dimension sizes. This condition |
519 | // is true if the original data source turns out to be unused at runtime. |
520 | Value dynHasZeroLenCond; |
521 | |
522 | int64_t rank = padOp.getSourceType().getRank(); |
523 | for (unsigned dim = 0; dim < rank; ++dim) { |
524 | auto low = padOp.getMixedLowPad()[dim]; |
525 | bool hasLowPad = !isConstantIntValue(low, 0); |
526 | auto high = padOp.getMixedHighPad()[dim]; |
527 | bool hasHighPad = !isConstantIntValue(high, 0); |
528 | auto offset = offsets[dim]; |
529 | auto length = sizes[dim]; |
    auto srcSize = tensor::getMixedSize(b, loc, padOp.getSource(), dim);
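
    // As a running example for the formulas below, assume srcSize = 10,
    // low = 3, high = 2 (so the padded result has 15 elements), and a tile
    // with offset = 1 and length = 13.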
531 | |
532 | // The new amount of low padding is `low - offset`. Except for the case |
533 | // where none of the low padding is read. In that case, the new amount of |
534 | // low padding is zero. |
535 | // |
536 | // Optimization: If low = 0, then newLow = 0. |
537 | OpFoldResult newLow = hasLowPad ? max(zero, sub(low, offset)) : zero; |
    newLows.push_back(newLow);
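    // Running example: newLow = max(0, 3 - 1) = 2.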
539 | |
540 | // Start reading the data from position `offset - low`. Since the original |
541 | // read may have started in the low padding zone, this value could be |
542 | // negative. Therefore, start reading from: |
543 | // |
544 | // max(offset - low, 0) |
545 | // |
546 | // The original read could also have started in the high padding zone. |
547 | // In that case, set the offset to the end of source tensor. The new |
548 | // ExtractSliceOp length will be zero in that case. (Effectively reading |
549 | // no data from the source.) |
550 | // |
551 | // Optimization: If low = 0, then the formula can be simplified. |
552 | OpFoldResult newOffset = hasLowPad |
553 | ? min(max(sub(offset, low), zero), srcSize) |
554 | : min(offset, srcSize); |
    newOffsets.push_back(newOffset);
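    // Running example: newOffset = min(max(1 - 3, 0), 10) = 0.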
556 | |
557 | // The original ExtractSliceOp was reading until position `offset + |
558 | // length`. Therefore, the corresponding position within the source tensor |
559 | // is: |
560 | // |
561 | // offset + length - low |
562 | // |
563 | // In case the original ExtractSliceOp stopped reading within the low |
564 | // padding zone, this value can be negative. In that case, the end |
565 | // position of the read should be zero. (Similar to newOffset.) |
566 | // |
567 | // The original read could also have stopped in the high padding zone. |
    // In that case, the end position of the read should be the end of
569 | // the source tensor. (Similar to newOffset.) |
570 | // |
571 | // endLoc = min(max(offset - low + length, 0), srcSize) |
572 | // |
573 | // The new ExtractSliceOp length is `endLoc - newOffset`. |
574 | // |
575 | // Optimization: If low = 0, then the formula can be simplified. |
576 | OpFoldResult endLoc = |
577 | hasLowPad ? min(max(add(sub(offset, low), length), zero), srcSize) |
578 | : min(add(offset, length), srcSize); |
579 | OpFoldResult newLength = sub(endLoc, newOffset); |
    newLengths.push_back(newLength);
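    // Running example: endLoc = min(max(1 - 3 + 13, 0), 10) = 10, so
    // newLength = 10 - 0 = 10 (the entire source is read).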
581 | |
582 | // Check if newLength is zero. In that case, no SubTensorOp should be |
583 | // executed. |
    if (isConstantIntValue(newLength, 0)) {
585 | hasZeroLen = true; |
586 | } else if (!hasZeroLen) { |
587 | Value check = b.create<arith::CmpIOp>( |
588 | loc, arith::CmpIPredicate::eq, |
589 | getValueOrCreateConstantIndexOp(b, loc, newLength), |
590 | getValueOrCreateConstantIndexOp(b, loc, zero)); |
591 | dynHasZeroLenCond = |
592 | dynHasZeroLenCond |
593 | ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond) |
594 | : check; |
595 | } |
596 | |
597 | // The amount of high padding is simply the number of elements remaining, |
598 | // so that the result has the same length as the original ExtractSliceOp. |
599 | // As an optimization, if the original high padding is zero, then the new |
600 | // high padding must also be zero. |
601 | OpFoldResult newHigh = |
602 | hasHighPad ? sub(sub(length, newLength), newLow) : zero; |
    newHighs.push_back(newHigh);
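    // Running example: newHigh = (13 - 10) - 2 = 1, i.e. the tile covers one
    // element of the original high padding.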
604 | |
605 | // Only unit stride supported. |
606 | newStrides.push_back(b.getIndexAttr(1)); |
607 | } |
608 | |
609 | // The shape of the result can be obtained from the sizes passed in. |
610 | SmallVector<Value> dynDims; |
611 | SmallVector<int64_t> shape; |
  dispatchIndexOpFoldResults(sizes, dynDims, shape);
613 | RankedTensorType resultType = |
614 | RankedTensorType::get(shape, padOp.getResultType().getElementType()); |
615 | |
616 | // Insert cast to ensure that types match. (May be folded away.) |
617 | auto castResult = [&](Value val) -> Value { |
618 | if (resultType == val.getType()) |
619 | return val; |
620 | return b.create<tensor::CastOp>(loc, resultType, val); |
621 | }; |
622 | |
623 | // In cases where the original data source is unused: Emit a GenerateOp and |
624 | // do not generate a SliceOp. (The result shape of the SliceOp would |
625 | // have a dimension of size 0, the semantics of which is unclear.) |
626 | auto createGenerateOp = [&]() { |
627 | // Create GenerateOp. |
628 | auto generateOp = b.create<tensor::GenerateOp>( |
629 | loc, resultType, dynDims, |
630 | [&](OpBuilder &builder, Location gLoc, ValueRange indices) { |
631 | builder.create<tensor::YieldOp>(gLoc, padValue); |
632 | }); |
633 | return generateOp; |
634 | }; |
635 | |
636 | // Emit a SliceOp and a PadOp. Should not be used in cases where |
637 | // the result shape of the new SliceOp has a zero dimension. |
  auto createPadOfExtractSlice = [&]() {
639 | // Create pad(extract_slice(x)). |
640 | Value newSliceOp = b.create<tensor::ExtractSliceOp>( |
641 | loc, padOp.getSource(), newOffsets, newLengths, newStrides); |
642 | auto newPadOp = b.create<PadOp>( |
643 | loc, Type(), newSliceOp, newLows, newHighs, |
644 | /*nofold=*/padOp.getNofold(), |
645 | getPrunedAttributeList(padOp, PadOp::getAttributeNames())); |
646 | |
647 | // Copy region to new PadOp. |
648 | IRMapping bvm; |
649 | padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm); |
650 | |
    // Return the new PadOp; callers cast the result via `castResult`.
652 | return newPadOp; |
653 | }; |
654 | |
  // Rewrite extract_slice(pad(x)) into a GenerateOp if it is statically known
  // that the original data source x is not used.
657 | if (hasZeroLen) { |
658 | Operation *generateOp = createGenerateOp(); |
    return TilingResult{{generateOp}, {castResult(generateOp->getResult(0))}};
660 | } |
661 | |
662 | // If there are dynamic dimensions: Generate an scf.if check to avoid |
663 | // creating SliceOps with result dimensions of size 0 at runtime. |
664 | if (generateZeroSliceGuard && dynHasZeroLenCond) { |
665 | Operation *thenOp; |
666 | Operation *elseOp; |
667 | auto result = b.create<scf::IfOp>( |
668 | loc, dynHasZeroLenCond, |
669 | /*thenBuilder=*/ |
670 | [&](OpBuilder &b, Location loc) { |
671 | thenOp = createGenerateOp(); |
672 | b.create<scf::YieldOp>(loc, castResult(thenOp->getResult(0))); |
673 | }, |
674 | /*elseBuilder=*/ |
675 | [&](OpBuilder &b, Location loc) { |
676 | elseOp = createPadOfExtractSlice(); |
677 | b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0))); |
678 | }); |
    return TilingResult{{elseOp}, SmallVector<Value>(result->getResults())};
680 | } |
681 | |
682 | Operation *newPadOp = createPadOfExtractSlice(); |
  return TilingResult{{newPadOp}, {castResult(newPadOp->getResult(0))}};
684 | } |
685 | |
686 | void mlir::tensor::registerTilingInterfaceExternalModels( |
687 | DialectRegistry ®istry) { |
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
689 | tensor::PadOp::attachInterface<PadOpTiling>(*ctx); |
690 | tensor::PackOp::attachInterface<PackOpTiling>(*ctx); |
691 | tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx); |
692 | }); |
693 | } |
694 | |
695 | void mlir::tensor::registerTilingInterfaceExternalModelsForPackUnPackOps( |
696 | DialectRegistry ®istry) { |
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
698 | tensor::PackOp::attachInterface<PackOpTiling>(*ctx); |
699 | tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx); |
700 | }); |
701 | } |
702 | |