//===- TensorTilingInterface.cpp - Tiling Interface models -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {

struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto padOp = cast<PadOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        padOp.getResultType().getRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

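  // For example, a pad producing tensor<10x20xf32> has the iteration domain
  // [{0, 10, 1}, {0, 20, 1}], i.e. one (offset, size, stride) range per
  // result dimension.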
  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    ReifiedRankedShapedTypeDims reifiedShapes;
    (void)reifyResultShapes(b, op, reifiedShapes);
    OpFoldResult zero = b.getIndexAttr(0);
    OpFoldResult one = b.getIndexAttr(1);
    // Initialize all the ranges to {zero, one, one}. All the sizes (`ub`s)
    // are overwritten below with the reified result dimensions.
    SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one});
    for (const auto &ub : enumerate(reifiedShapes[0]))
      loopRanges[ub.index()].size = ub.value();
    return loopRanges;
  }

  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> result =
        tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes);
    if (failed(result))
      return failure();
    return result.value();
  }

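  // For tensor.pad the iteration domain coincides with the result tensor, so
  // a tile of the result maps 1:1 to the same offsets/sizes in the iteration
  // space (and vice versa in the method after this one).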
  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultSizes.assign(sizes.begin(), sizes.end());
    return success();
  }

  LogicalResult getIterationDomainTileFromResultTile(
      Operation *op, OpBuilder &b, unsigned resultNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
    iterDomainOffsets.assign(offsets.begin(), offsets.end());
    iterDomainSizes.assign(sizes.begin(), sizes.end());
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    return getTiledImplementation(op, b, offsets, sizes);
  }
};

} // namespace

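// Worked example for the per-dimension arithmetic below: take a source of
// type tensor<5xf32> padded with low = 2 and high = 3 (result tensor<10xf32>),
// and a requested tile at offset = 1 with size = 4. Then:
//   newLow    = max(0, 2 - 1)         = 1
//   newOffset = min(max(1 - 2, 0), 5) = 0
//   newLength = min(5 - 0, 4 - 1)     = 3
//   newHigh   = (4 - 3) - 1           = 0
// so the rewrite produces IR of the form
//   %s = tensor.extract_slice %src[0] [3] [1]
//       : tensor<5xf32> to tensor<3xf32>
//   %p = tensor.pad %s low[1] high[0] { ... } : tensor<3xf32> to tensor<4xf32>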
FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
                                                 tensor::PadOp padOp,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes,
                                                 bool generateZeroSliceGuard) {
  // Only constant padding values are supported.
  Value padValue = padOp.getConstantPaddingValue();
  if (!padValue)
    return failure();

  // Helper variables and functions for various arithmetic operations. These
  // are used extensively for computing new offset/length and padding values.
  Location loc = padOp->getLoc();
  AffineExpr dim0, dim1;
  bindDims(b.getContext(), dim0, dim1);
  // Subtract two integers.
  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineApply(b, loc, subMap, {v1, v2});
  };
  // Take the minimum of two integers.
  auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
  auto min = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMin(b, loc, idMap, {v1, v2});
  };
  // Take the maximum of two integers.
  auto max = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineMax(b, loc, idMap, {v1, v2});
  };
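  // Note: the makeComposedFolded* helpers above constant-fold when all
  // operands are static, e.g. sub(5, 2) yields the index attribute 3 rather
  // than materializing an affine.apply op.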
  // Zero index-typed integer.
  OpFoldResult zero = b.getIndexAttr(0);

  // Compute new offsets, lengths, low padding, high padding.
  SmallVector<OpFoldResult> newOffsets, newLengths;
  SmallVector<OpFoldResult> newLows, newHighs;
  // Set to true if the original data source is not read at all.
  bool hasZeroLen = false;
  // Same as hasZeroLen, but for dynamic dimension sizes. This condition is
  // true if the original data source turns out to be unused at runtime.
  Value dynHasZeroLenCond;

  int64_t rank = padOp.getSourceType().getRank();
  // Only unit strides are supported.
  SmallVector<OpFoldResult> newStrides(rank, b.getIndexAttr(1));
  for (unsigned dim = 0; dim < rank; ++dim) {
    auto low = padOp.getMixedLowPad()[dim];
    bool hasLowPad = !isZeroInteger(low);
    auto high = padOp.getMixedHighPad()[dim];
    bool hasHighPad = !isZeroInteger(high);
    auto offset = offsets[dim];
    auto length = sizes[dim];
    // If the dim has no padding, we don't need to compute new values for it;
    // the existing ones remain correct even after the rewrite.
    if (!hasLowPad && !hasHighPad) {
      newOffsets.push_back(offset);
      newLengths.push_back(length);
      newLows.push_back(low);
      newHighs.push_back(high);
      continue;
    }

    auto srcSize = tensor::getMixedSize(b, loc, padOp.getSource(), dim);

    // The new amount of low padding is `low - offset`, except for the case
    // where none of the low padding is read. In that case, the new amount of
    // low padding is zero.
    //
    // Optimization: If low = 0, then newLow = 0.
    OpFoldResult newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
    newLows.push_back(newLow);

    // Start reading the data from position `offset - low`. Since the original
    // read may have started in the low padding zone, this value could be
    // negative. Therefore, start reading from:
    //
    //   max(offset - low, 0)
    //
    // The original read could also have started in the high padding zone.
    // In that case, set the offset to the end of the source tensor. The new
    // ExtractSliceOp length will be zero in that case. (Effectively reading
    // no data from the source.)
    //
    // Optimization: If low = 0, then the formula can be simplified.
    OpFoldResult newOffset = hasLowPad
                                 ? min(max(sub(offset, low), zero), srcSize)
                                 : min(offset, srcSize);
    newOffsets.push_back(newOffset);

    // The original ExtractSliceOp was reading until position `offset +
    // length`. Therefore, the corresponding position within the source tensor
    // is:
    //
    //   offset + length - low
    //
    // In case the original ExtractSliceOp stopped reading within the low
    // padding zone, this value can be negative. In that case, the end
    // position of the read should be zero. (Similar to newOffset.)
    //
    // The original read could also have stopped in the high padding zone. In
    // that case, the end position of the read should be the end of the source
    // tensor. (Similar to newOffset.)
    //
    // `srcSize - newOffset` represents how much length is available, and
    // `length - newLow` represents how much length is wanted at most.
    //
    // Note that there are many ways to order this indexing math to compute
    // newLength, but we want to make sure that the final affine.min ops in
    // the sequence are bounding the index to as small a value as possible. If
    // ValueBoundsOpInterface is used, this calculation will get upper bounds
    // from the affine.min ops, so we want to use the smallest known value to
    // set the bound at the end of the computation sequence. In this case, the
    // index will be upper-bounded by `length - newLow`.
    OpFoldResult newLength = min(sub(srcSize, newOffset), sub(length, newLow));
    // Optimization: If low = 0, then newLow = 0, and newLength >= 0 assuming
    // length >= 0.
    if (hasLowPad)
      newLength = max(newLength, zero);
    newLengths.push_back(newLength);

    // Check if newLength is zero. In that case, no ExtractSliceOp should be
    // executed.
    if (isZeroInteger(newLength)) {
      hasZeroLen = true;
    } else if (!hasZeroLen) {
      Value check = b.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq,
          getValueOrCreateConstantIndexOp(b, loc, newLength),
          getValueOrCreateConstantIndexOp(b, loc, zero));
      dynHasZeroLenCond =
          dynHasZeroLenCond
              ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
              : check;
    }

    // The amount of high padding is simply the number of elements remaining,
    // so that the result has the same length as the original ExtractSliceOp.
    // As an optimization, if the original high padding is zero, then the new
    // high padding must also be zero.
    OpFoldResult newHigh =
        hasHighPad ? sub(sub(length, newLength), newLow) : zero;
    newHighs.push_back(newHigh);
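    // At this point newLow + newLength + newHigh == length in every case, so
    // the rewritten pad produces a tile of exactly the requested size.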
  }

  // The shape of the result can be obtained from the sizes passed in.
  SmallVector<Value> dynDims;
  SmallVector<int64_t> shape;
  dispatchIndexOpFoldResults(sizes, dynDims, shape);
  RankedTensorType resultType =
      RankedTensorType::get(shape, padOp.getResultType().getElementType());

  // Insert cast to ensure that types match. (May be folded away.)
  auto castResult = [&](Value val) -> Value {
    if (resultType == val.getType())
      return val;
    return b.create<tensor::CastOp>(loc, resultType, val);
  };

  // In cases where the original data source is unused: Emit a GenerateOp and
  // do not generate a SliceOp. (The result shape of the SliceOp would
  // have a dimension of size 0, the semantics of which is unclear.)
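  // For a rank-2 result, the generated op has roughly the following form:
  //   %0 = tensor.generate %dynDim0, %dynDim1 {
  //   ^bb0(%i: index, %j: index):
  //     tensor.yield %padValue : f32
  //   } : tensor<?x?xf32>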
  auto createGenerateOp = [&]() {
    // Create GenerateOp.
    auto generateOp = b.create<tensor::GenerateOp>(
        loc, resultType, dynDims,
        [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
          builder.create<tensor::YieldOp>(gLoc, padValue);
        });
    return generateOp;
  };

  // Emit a SliceOp and a PadOp. Should not be used in cases where
  // the result shape of the new SliceOp has a zero dimension.
  auto createPadOfExtractSlice = [&]() {
    // Create pad(extract_slice(x)).
    auto newSliceOp = b.create<tensor::ExtractSliceOp>(
        loc, padOp.getSource(), newOffsets, newLengths, newStrides);
    auto newPadOp = b.create<PadOp>(
        loc, Type(), newSliceOp, newLows, newHighs,
        /*nofold=*/padOp.getNofold(),
        getPrunedAttributeList(padOp, PadOp::getAttributeNames()));

    // Copy region to new PadOp.
    IRMapping bvm;
    padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);

    // Return both ops; the callers cast the PadOp result as needed.
    return std::make_tuple(newPadOp, newSliceOp);
  };

  // Rewrite extract_slice(pad(x)) into a GenerateOp if it is statically known
  // that the original data source x is not used.
  if (hasZeroLen) {
    Operation *generateOp = createGenerateOp();
    return TilingResult{{generateOp},
                        {castResult(generateOp->getResult(0))},
                        /*generatedSlices=*/{}};
  }

  // If there are dynamic dimensions: Generate an scf.if check to avoid
  // creating SliceOps with result dimensions of size 0 at runtime.
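  // The guard has roughly the following form (casts elided):
  //   %r = scf.if %hasZeroLen -> (tensor<...xf32>) {
  //     scf.yield %generated : tensor<...xf32>
  //   } else {
  //     scf.yield %padOfSlice : tensor<...xf32>
  //   }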
  if (generateZeroSliceGuard && dynHasZeroLenCond) {
    Operation *thenOp;
    Operation *elseOp;
    Operation *sliceOp;
    auto result = b.create<scf::IfOp>(
        loc, dynHasZeroLenCond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          thenOp = createGenerateOp();
          b.create<scf::YieldOp>(loc, castResult(thenOp->getResult(0)));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          std::tie(elseOp, sliceOp) = createPadOfExtractSlice();
          b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0)));
        });
    return TilingResult{
        {elseOp}, SmallVector<Value>(result->getResults()), {sliceOp}};
  }

  auto [newPadOp, sliceOp] = createPadOfExtractSlice();
  return TilingResult{
      {newPadOp}, {castResult(newPadOp->getResult(0))}, {sliceOp}};
}

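// Note: clients must attach these external models before the interface is
// queried, typically during context setup. An illustrative sketch:
//
//   DialectRegistry registry;
//   registry.insert<tensor::TensorDialect>();
//   tensor::registerTilingInterfaceExternalModels(registry);
//   MLIRContext context(registry);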
void mlir::tensor::registerTilingInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
    tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
  });
}