//===- TensorTilingInterface.cpp - Tiling Interface models *- C++ ------*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

#include <tuple>

using namespace mlir;
using namespace mlir::tensor;

26 | namespace { |
27 | |
28 | struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> { |
29 | |
30 | SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { |
31 | auto padOp = cast<PadOp>(op); |
32 | SmallVector<utils::IteratorType> iteratorTypes( |
33 | padOp.getResultType().getRank(), utils::IteratorType::parallel); |
34 | return iteratorTypes; |
35 | } |
36 | |
37 | SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const { |
38 | ReifiedRankedShapedTypeDims reifiedShapes; |
39 | (void)reifyResultShapes(b, op, reifiedShapes); |
40 | OpFoldResult zero = b.getIndexAttr(0); |
41 | OpFoldResult one = b.getIndexAttr(1); |
42 | // Initialize all the ranges to {zero, one, one}. All the `ub`s are |
43 | // overwritten. |
44 | SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one}); |
45 | for (const auto &ub : enumerate(reifiedShapes[0])) |
46 | loopRanges[ub.index()].size = ub.value(); |
47 | return loopRanges; |
48 | } |
49 | |
50 | FailureOr<TilingResult> |
51 | getTiledImplementation(Operation *op, OpBuilder &b, |
52 | ArrayRef<OpFoldResult> offsets, |
53 | ArrayRef<OpFoldResult> sizes) const { |
54 | FailureOr<TilingResult> result = |
55 | tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes); |
56 | if (failed(result)) |
57 | return failure(); |
58 | return result.value(); |
59 | } |
60 | |
61 | LogicalResult |
62 | getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, |
63 | ArrayRef<OpFoldResult> offsets, |
64 | ArrayRef<OpFoldResult> sizes, |
65 | SmallVector<OpFoldResult> &resultOffsets, |
66 | SmallVector<OpFoldResult> &resultSizes) const { |
67 | resultOffsets.assign(offsets.begin(), offsets.end()); |
68 | resultSizes.assign(sizes.begin(), sizes.end()); |
69 | return success(); |
70 | } |
71 | |
72 | LogicalResult getIterationDomainTileFromResultTile( |
73 | Operation *op, OpBuilder &b, unsigned resultNumber, |
74 | ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, |
75 | SmallVectorImpl<OpFoldResult> &iterDomainOffsets, |
76 | SmallVectorImpl<OpFoldResult> &iterDomainSizes) const { |
77 | iterDomainOffsets.assign(offsets.begin(), offsets.end()); |
78 | iterDomainSizes.assign(sizes.begin(), sizes.end()); |
79 | return success(); |
80 | } |
81 | |
82 | FailureOr<TilingResult> |
83 | generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, |
84 | ArrayRef<OpFoldResult> offsets, |
85 | ArrayRef<OpFoldResult> sizes) const { |
86 | return getTiledImplementation(op, b, offsets, sizes); |
87 | } |
88 | }; |
89 | |
90 | } // namespace |
91 | |
92 | FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b, |
93 | tensor::PadOp padOp, |
94 | ArrayRef<OpFoldResult> offsets, |
95 | ArrayRef<OpFoldResult> sizes, |
96 | bool generateZeroSliceGuard) { |
97 | // Only constant padding value supported. |
98 | Value padValue = padOp.getConstantPaddingValue(); |
99 | if (!padValue) |
100 | return failure(); |
101 | |
102 | // Helper variables and functions for various arithmetic operations. These |
103 | // are used extensively for computing new offset/length and padding values. |
104 | Location loc = padOp->getLoc(); |
105 | AffineExpr dim0, dim1; |
106 | bindDims(ctx: b.getContext(), exprs&: dim0, exprs&: dim1); |
107 | // Subtract two integers. |
108 | auto subMap = AffineMap::get(dimCount: 2, symbolCount: 0, result: {dim0 - dim1}); |
109 | auto sub = [&](OpFoldResult v1, OpFoldResult v2) { |
110 | return affine::makeComposedFoldedAffineApply(b, loc, subMap, {v1, v2}); |
111 | }; |
112 | // Take the minimum of two integers. |
113 | auto idMap = AffineMap::getMultiDimIdentityMap(numDims: 2, context: b.getContext()); |
114 | auto min = [&](OpFoldResult v1, OpFoldResult v2) { |
115 | return affine::makeComposedFoldedAffineMin(b, loc, idMap, {v1, v2}); |
116 | }; |
117 | // Take the maximum of two integers. |
118 | auto max = [&](OpFoldResult v1, OpFoldResult v2) { |
119 | return affine::makeComposedFoldedAffineMax(b, loc, idMap, {v1, v2}); |
120 | }; |
121 | // Zero index-typed integer. |
122 | OpFoldResult zero = b.getIndexAttr(0); |
123 | |
124 | // Compute new offsets, lengths, low padding, high padding. |
125 | SmallVector<OpFoldResult> newOffsets, newLengths; |
126 | SmallVector<OpFoldResult> newLows, newHighs; |
127 | // Set to true if the original data source is not read at all. |
128 | bool hasZeroLen = false; |
129 | // Same as hasZeroLen, but for dynamic dimension sizes. This condition |
130 | // is true if the original data source turns out to be unused at runtime. |
131 | Value dynHasZeroLenCond; |
132 | |
133 | int64_t rank = padOp.getSourceType().getRank(); |
134 | // Only unit stride supported. |
135 | SmallVector<OpFoldResult> newStrides(rank, b.getIndexAttr(1)); |
136 | for (unsigned dim = 0; dim < rank; ++dim) { |
137 | auto low = padOp.getMixedLowPad()[dim]; |
138 | bool hasLowPad = !isZeroInteger(low); |
139 | auto high = padOp.getMixedHighPad()[dim]; |
140 | bool hasHighPad = !isZeroInteger(high); |
141 | auto offset = offsets[dim]; |
142 | auto length = sizes[dim]; |
143 | // If the dim has no padding, we dont need to calculate new values for that |
144 | // dim as the exisiting ones are correct even after the pattern. |
145 | if (!hasLowPad && !hasHighPad) { |
146 | newOffsets.push_back(Elt: offset); |
147 | newLengths.push_back(Elt: length); |
148 | newLows.push_back(Elt: low); |
149 | newHighs.push_back(Elt: high); |
150 | continue; |
151 | } |
152 | |
153 | auto srcSize = tensor::getMixedSize(builder&: b, loc, value: padOp.getSource(), dim); |
154 | |
155 | // The new amount of low padding is `low - offset`. Except for the case |
156 | // where none of the low padding is read. In that case, the new amount of |
157 | // low padding is zero. |
158 | // |
159 | // Optimization: If low = 0, then newLow = 0. |
160 | OpFoldResult newLow = hasLowPad ? max(zero, sub(low, offset)) : zero; |
161 | newLows.push_back(Elt: newLow); |
162 | |
163 | // Start reading the data from position `offset - low`. Since the original |
164 | // read may have started in the low padding zone, this value could be |
165 | // negative. Therefore, start reading from: |
166 | // |
167 | // max(offset - low, 0) |
168 | // |
169 | // The original read could also have started in the high padding zone. |
170 | // In that case, set the offset to the end of source tensor. The new |
171 | // ExtractSliceOp length will be zero in that case. (Effectively reading |
172 | // no data from the source.) |
173 | // |
174 | // Optimization: If low = 0, then the formula can be simplified. |
175 | OpFoldResult newOffset = hasLowPad |
176 | ? min(max(sub(offset, low), zero), srcSize) |
177 | : min(offset, srcSize); |
178 | newOffsets.push_back(Elt: newOffset); |
179 | |
180 | // The original ExtractSliceOp was reading until position `offset + |
181 | // length`. Therefore, the corresponding position within the source tensor |
182 | // is: |
183 | // |
184 | // offset + length - low |
185 | // |
186 | // In case the original ExtractSliceOp stopped reading within the low |
187 | // padding zone, this value can be negative. In that case, the end |
188 | // position of the read should be zero. (Similar to newOffset.) |
189 | // |
190 | // The original read could also have stopped in the high padding zone. |
191 | // In that case, set the end positition of the read should be the end of |
192 | // the source tensor. (Similar to newOffset.) |
193 | // srcSize - newOffset represents how much length we have available |
194 | // and length - newLow represents how much length we want at most. |
195 | // Note that there are many ways to order this indexing math to compute |
196 | // newLength, but we want to make sure that the final affine.min ops in the |
197 | // sequence are bounding the index to as small a value as possible. If |
198 | // ValueBoundsOpInterface is used, this calculation will get upper bounds |
199 | // from the affine.min ops, so we want to use the smallest known value to |
200 | // set the bound at the end of the computation sequence. In this case, the |
201 | // index will be upper bounded by length - newLow. |
202 | OpFoldResult newLength = min(sub(srcSize, newOffset), sub(length, newLow)); |
203 | // Optimization: If low = 0, then newLow = 0. then newLength >= 0 assuming |
204 | // length >= 0. |
205 | if (hasLowPad) |
206 | newLength = max(newLength, zero); |
207 | newLengths.push_back(Elt: newLength); |
208 | |
209 | // Check if newLength is zero. In that case, no SubTensorOp should be |
210 | // executed. |
211 | if (isZeroInteger(v: newLength)) { |
212 | hasZeroLen = true; |
213 | } else if (!hasZeroLen) { |
214 | Value check = b.create<arith::CmpIOp>( |
215 | loc, arith::CmpIPredicate::eq, |
216 | getValueOrCreateConstantIndexOp(b, loc, newLength), |
217 | getValueOrCreateConstantIndexOp(b, loc, zero)); |
218 | dynHasZeroLenCond = |
219 | dynHasZeroLenCond |
220 | ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond) |
221 | : check; |
222 | } |
223 | |
224 | // The amount of high padding is simply the number of elements remaining, |
225 | // so that the result has the same length as the original ExtractSliceOp. |
226 | // As an optimization, if the original high padding is zero, then the new |
227 | // high padding must also be zero. |
228 | OpFoldResult newHigh = |
229 | hasHighPad ? sub(sub(length, newLength), newLow) : zero; |
230 | newHighs.push_back(Elt: newHigh); |
231 | } |
232 | |
233 | // The shape of the result can be obtained from the sizes passed in. |
234 | SmallVector<Value> dynDims; |
235 | SmallVector<int64_t> shape; |
236 | dispatchIndexOpFoldResults(ofrs: sizes, dynamicVec&: dynDims, staticVec&: shape); |
237 | RankedTensorType resultType = |
238 | RankedTensorType::get(shape, padOp.getResultType().getElementType()); |
239 | |
240 | // Insert cast to ensure that types match. (May be folded away.) |
241 | auto castResult = [&](Value val) -> Value { |
242 | if (resultType == val.getType()) |
243 | return val; |
244 | return b.create<tensor::CastOp>(loc, resultType, val); |
245 | }; |
246 | |
247 | // In cases where the original data source is unused: Emit a GenerateOp and |
248 | // do not generate a SliceOp. (The result shape of the SliceOp would |
249 | // have a dimension of size 0, the semantics of which is unclear.) |
250 | auto createGenerateOp = [&]() { |
251 | // Create GenerateOp. |
252 | auto generateOp = b.create<tensor::GenerateOp>( |
253 | loc, resultType, dynDims, |
254 | [&](OpBuilder &builder, Location gLoc, ValueRange indices) { |
255 | builder.create<tensor::YieldOp>(gLoc, padValue); |
256 | }); |
257 | return generateOp; |
258 | }; |
259 | |
260 | // Emit a SliceOp and a PadOp. Should not be used in cases where |
261 | // the result shape of the new SliceOp has a zero dimension. |
262 | auto = [&]() { |
263 | // Create pad(extract_slice(x)). |
264 | auto newSliceOp = b.create<tensor::ExtractSliceOp>( |
265 | loc, padOp.getSource(), newOffsets, newLengths, newStrides); |
266 | auto newPadOp = b.create<PadOp>( |
267 | loc, Type(), newSliceOp, newLows, newHighs, |
268 | /*nofold=*/padOp.getNofold(), |
269 | getPrunedAttributeList(padOp, PadOp::getAttributeNames())); |
270 | |
271 | // Copy region to new PadOp. |
272 | IRMapping bvm; |
273 | padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm); |
274 | |
275 | // Cast result and return. |
276 | return std::make_tuple(newPadOp, newSliceOp); |
277 | }; |
278 | |
279 | // Rewrite extract_slice(pad(x)) into a GenerateOp it is statically known that |
280 | // the original data source x is not used. |
281 | if (hasZeroLen) { |
282 | Operation *generateOp = createGenerateOp(); |
283 | return TilingResult{.tiledOps: {generateOp}, |
284 | .tiledValues: {castResult(generateOp->getResult(idx: 0))}, |
285 | /*generatedSlices=*/{}}; |
286 | } |
287 | |
288 | // If there are dynamic dimensions: Generate an scf.if check to avoid |
289 | // creating SliceOps with result dimensions of size 0 at runtime. |
290 | if (generateZeroSliceGuard && dynHasZeroLenCond) { |
291 | Operation *thenOp; |
292 | Operation *elseOp; |
293 | Operation *sliceOp; |
294 | auto result = b.create<scf::IfOp>( |
295 | loc, dynHasZeroLenCond, |
296 | /*thenBuilder=*/ |
297 | [&](OpBuilder &b, Location loc) { |
298 | thenOp = createGenerateOp(); |
299 | b.create<scf::YieldOp>(loc, castResult(thenOp->getResult(0))); |
300 | }, |
301 | /*elseBuilder=*/ |
302 | [&](OpBuilder &b, Location loc) { |
303 | std::tie(elseOp, sliceOp) = createPadOfExtractSlice(); |
304 | b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0))); |
305 | }); |
306 | return TilingResult{ |
307 | .tiledOps: {elseOp}, .tiledValues: SmallVector<Value>(result->getResults()), .generatedSlices: {sliceOp}}; |
308 | } |
309 | |
310 | auto [newPadOp, sliceOp] = createPadOfExtractSlice(); |
311 | return TilingResult{ |
312 | {newPadOp}, {castResult(newPadOp->getResult(0))}, {sliceOp}}; |
313 | } |
314 | |
315 | void mlir::tensor::registerTilingInterfaceExternalModels( |
316 | DialectRegistry ®istry) { |
317 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, TensorDialect *dialect) { |
318 | tensor::PadOp::attachInterface<PadOpTiling>(*ctx); |
319 | }); |
320 | } |
321 | |