//===- TensorTilingInterface.cpp - Tiling Interface models *- C++ ------*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
| 8 | |
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
| 18 | |
| 19 | using namespace mlir; |
| 20 | using namespace mlir::tensor; |
| 21 | |
| 22 | namespace { |
| 23 | |
| 24 | struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> { |
| 25 | |
| 26 | SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { |
| 27 | auto padOp = cast<PadOp>(Val: op); |
| 28 | SmallVector<utils::IteratorType> iteratorTypes( |
| 29 | padOp.getResultType().getRank(), utils::IteratorType::parallel); |
| 30 | return iteratorTypes; |
| 31 | } |
| 32 | |
| 33 | SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const { |
| 34 | ReifiedRankedShapedTypeDims reifiedShapes; |
| 35 | (void)reifyResultShapes(b, op, reifiedReturnShapes&: reifiedShapes); |
| 36 | OpFoldResult zero = b.getIndexAttr(value: 0); |
| 37 | OpFoldResult one = b.getIndexAttr(value: 1); |
| 38 | // Initialize all the ranges to {zero, one, one}. All the `ub`s are |
| 39 | // overwritten. |
| 40 | SmallVector<Range> loopRanges(reifiedShapes[0].size(), {.offset: zero, .size: one, .stride: one}); |
| 41 | for (const auto &ub : enumerate(First&: reifiedShapes[0])) |
| 42 | loopRanges[ub.index()].size = ub.value(); |
| 43 | return loopRanges; |
| 44 | } |
| 45 | |
| 46 | FailureOr<TilingResult> |
| 47 | getTiledImplementation(Operation *op, OpBuilder &b, |
| 48 | ArrayRef<OpFoldResult> offsets, |
| 49 | ArrayRef<OpFoldResult> sizes) const { |
| 50 | FailureOr<TilingResult> result = |
| 51 | tensor::bubbleUpPadSlice(b, padOp: cast<PadOp>(Val: op), offsets, sizes); |
| 52 | if (failed(Result: result)) |
| 53 | return failure(); |
| 54 | return result.value(); |
| 55 | } |
| 56 | |
| 57 | LogicalResult |
| 58 | getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber, |
| 59 | ArrayRef<OpFoldResult> offsets, |
| 60 | ArrayRef<OpFoldResult> sizes, |
| 61 | SmallVector<OpFoldResult> &resultOffsets, |
| 62 | SmallVector<OpFoldResult> &resultSizes) const { |
| 63 | resultOffsets.assign(in_start: offsets.begin(), in_end: offsets.end()); |
| 64 | resultSizes.assign(in_start: sizes.begin(), in_end: sizes.end()); |
| 65 | return success(); |
| 66 | } |
| 67 | |
| 68 | LogicalResult getIterationDomainTileFromResultTile( |
| 69 | Operation *op, OpBuilder &b, unsigned resultNumber, |
| 70 | ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, |
| 71 | SmallVectorImpl<OpFoldResult> &iterDomainOffsets, |
| 72 | SmallVectorImpl<OpFoldResult> &iterDomainSizes) const { |
| 73 | iterDomainOffsets.assign(in_start: offsets.begin(), in_end: offsets.end()); |
| 74 | iterDomainSizes.assign(in_start: sizes.begin(), in_end: sizes.end()); |
| 75 | return success(); |
| 76 | } |
| 77 | |
| 78 | FailureOr<TilingResult> |
| 79 | generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber, |
| 80 | ArrayRef<OpFoldResult> offsets, |
| 81 | ArrayRef<OpFoldResult> sizes) const { |
| 82 | return getTiledImplementation(op, b, offsets, sizes); |
| 83 | } |
| 84 | }; |
| 85 | |
| 86 | } // namespace |
| 87 | |
| 88 | FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b, |
| 89 | tensor::PadOp padOp, |
| 90 | ArrayRef<OpFoldResult> offsets, |
| 91 | ArrayRef<OpFoldResult> sizes, |
| 92 | bool generateZeroSliceGuard) { |
| 93 | // Only constant padding value supported. |
| 94 | Value padValue = padOp.getConstantPaddingValue(); |
| 95 | if (!padValue) |
| 96 | return failure(); |
| 97 | |
| 98 | // Helper variables and functions for various arithmetic operations. These |
| 99 | // are used extensively for computing new offset/length and padding values. |
| 100 | Location loc = padOp->getLoc(); |
| 101 | AffineExpr dim0, dim1; |
| 102 | bindDims(ctx: b.getContext(), exprs&: dim0, exprs&: dim1); |
| 103 | // Subtract two integers. |
| 104 | auto subMap = AffineMap::get(dimCount: 2, symbolCount: 0, result: {dim0 - dim1}); |
| 105 | auto sub = [&](OpFoldResult v1, OpFoldResult v2) { |
| 106 | return affine::makeComposedFoldedAffineApply(b, loc, map: subMap, operands: {v1, v2}); |
| 107 | }; |
| 108 | // Take the minimum of two integers. |
| 109 | auto idMap = AffineMap::getMultiDimIdentityMap(numDims: 2, context: b.getContext()); |
| 110 | auto min = [&](OpFoldResult v1, OpFoldResult v2) { |
| 111 | return affine::makeComposedFoldedAffineMin(b, loc, map: idMap, operands: {v1, v2}); |
| 112 | }; |
| 113 | // Take the maximum of two integers. |
| 114 | auto max = [&](OpFoldResult v1, OpFoldResult v2) { |
| 115 | return affine::makeComposedFoldedAffineMax(b, loc, map: idMap, operands: {v1, v2}); |
| 116 | }; |
| 117 | // Zero index-typed integer. |
| 118 | OpFoldResult zero = b.getIndexAttr(value: 0); |
| 119 | |
| 120 | // Compute new offsets, lengths, low padding, high padding. |
| 121 | SmallVector<OpFoldResult> newOffsets, newLengths; |
| 122 | SmallVector<OpFoldResult> newLows, newHighs; |
| 123 | // Set to true if the original data source is not read at all. |
| 124 | bool hasZeroLen = false; |
| 125 | // Same as hasZeroLen, but for dynamic dimension sizes. This condition |
| 126 | // is true if the original data source turns out to be unused at runtime. |
| 127 | Value dynHasZeroLenCond; |
| 128 | |
| 129 | int64_t rank = padOp.getSourceType().getRank(); |
| 130 | // Only unit stride supported. |
| 131 | SmallVector<OpFoldResult> newStrides(rank, b.getIndexAttr(value: 1)); |
| 132 | for (unsigned dim = 0; dim < rank; ++dim) { |
| 133 | auto low = padOp.getMixedLowPad()[dim]; |
| 134 | bool hasLowPad = !isZeroInteger(v: low); |
| 135 | auto high = padOp.getMixedHighPad()[dim]; |
| 136 | bool hasHighPad = !isZeroInteger(v: high); |
| 137 | auto offset = offsets[dim]; |
| 138 | auto length = sizes[dim]; |
| 139 | // If the dim has no padding, we dont need to calculate new values for that |
| 140 | // dim as the exisiting ones are correct even after the pattern. |
| 141 | if (!hasLowPad && !hasHighPad) { |
| 142 | newOffsets.push_back(Elt: offset); |
| 143 | newLengths.push_back(Elt: length); |
| 144 | newLows.push_back(Elt: low); |
| 145 | newHighs.push_back(Elt: high); |
| 146 | continue; |
| 147 | } |
| 148 | |
| 149 | auto srcSize = tensor::getMixedSize(builder&: b, loc, value: padOp.getSource(), dim); |
| 150 | |
| 151 | // The new amount of low padding is `low - offset`. Except for the case |
| 152 | // where none of the low padding is read. In that case, the new amount of |
| 153 | // low padding is zero. |
| 154 | // |
| 155 | // Optimization: If low = 0, then newLow = 0. |
| 156 | OpFoldResult newLow = hasLowPad ? max(zero, sub(low, offset)) : zero; |
| 157 | newLows.push_back(Elt: newLow); |
| 158 | |
| 159 | // Start reading the data from position `offset - low`. Since the original |
| 160 | // read may have started in the low padding zone, this value could be |
| 161 | // negative. Therefore, start reading from: |
| 162 | // |
| 163 | // max(offset - low, 0) |
| 164 | // |
| 165 | // The original read could also have started in the high padding zone. |
| 166 | // In that case, set the offset to the end of source tensor. The new |
| 167 | // ExtractSliceOp length will be zero in that case. (Effectively reading |
| 168 | // no data from the source.) |
| 169 | // |
| 170 | // Optimization: If low = 0, then the formula can be simplified. |
| 171 | OpFoldResult newOffset = hasLowPad |
| 172 | ? min(max(sub(offset, low), zero), srcSize) |
| 173 | : min(offset, srcSize); |
| 174 | newOffsets.push_back(Elt: newOffset); |
| 175 | |
| 176 | // The original ExtractSliceOp was reading until position `offset + |
| 177 | // length`. Therefore, the corresponding position within the source tensor |
| 178 | // is: |
| 179 | // |
| 180 | // offset + length - low |
| 181 | // |
| 182 | // In case the original ExtractSliceOp stopped reading within the low |
| 183 | // padding zone, this value can be negative. In that case, the end |
| 184 | // position of the read should be zero. (Similar to newOffset.) |
| 185 | // |
| 186 | // The original read could also have stopped in the high padding zone. |
| 187 | // In that case, set the end positition of the read should be the end of |
| 188 | // the source tensor. (Similar to newOffset.) |
| 189 | // srcSize - newOffset represents how much length we have available |
| 190 | // and length - newLow represents how much length we want at most. |
| 191 | // Note that there are many ways to order this indexing math to compute |
| 192 | // newLength, but we want to make sure that the final affine.min ops in the |
| 193 | // sequence are bounding the index to as small a value as possible. If |
| 194 | // ValueBoundsOpInterface is used, this calculation will get upper bounds |
| 195 | // from the affine.min ops, so we want to use the smallest known value to |
| 196 | // set the bound at the end of the computation sequence. In this case, the |
| 197 | // index will be upper bounded by length - newLow. |
| 198 | OpFoldResult newLength = min(sub(srcSize, newOffset), sub(length, newLow)); |
| 199 | // Optimization: If low = 0, then newLow = 0. then newLength >= 0 assuming |
| 200 | // length >= 0. |
| 201 | if (hasLowPad) |
| 202 | newLength = max(newLength, zero); |
| 203 | newLengths.push_back(Elt: newLength); |
| 204 | |
| 205 | // Check if newLength is zero. In that case, no SubTensorOp should be |
| 206 | // executed. |
| 207 | if (isZeroInteger(v: newLength)) { |
| 208 | hasZeroLen = true; |
| 209 | } else if (!hasZeroLen) { |
| 210 | Value check = b.create<arith::CmpIOp>( |
| 211 | location: loc, args: arith::CmpIPredicate::eq, |
| 212 | args: getValueOrCreateConstantIndexOp(b, loc, ofr: newLength), |
| 213 | args: getValueOrCreateConstantIndexOp(b, loc, ofr: zero)); |
| 214 | dynHasZeroLenCond = |
| 215 | dynHasZeroLenCond |
| 216 | ? b.create<arith::OrIOp>(location: loc, args&: check, args&: dynHasZeroLenCond) |
| 217 | : check; |
| 218 | } |
| 219 | |
| 220 | // The amount of high padding is simply the number of elements remaining, |
| 221 | // so that the result has the same length as the original ExtractSliceOp. |
| 222 | // As an optimization, if the original high padding is zero, then the new |
| 223 | // high padding must also be zero. |
| 224 | OpFoldResult newHigh = |
| 225 | hasHighPad ? sub(sub(length, newLength), newLow) : zero; |
| 226 | newHighs.push_back(Elt: newHigh); |
| 227 | } |
| 228 | |
| 229 | // The shape of the result can be obtained from the sizes passed in. |
| 230 | SmallVector<Value> dynDims; |
| 231 | SmallVector<int64_t> shape; |
| 232 | dispatchIndexOpFoldResults(ofrs: sizes, dynamicVec&: dynDims, staticVec&: shape); |
| 233 | RankedTensorType resultType = |
| 234 | RankedTensorType::get(shape, elementType: padOp.getResultType().getElementType()); |
| 235 | |
| 236 | // Insert cast to ensure that types match. (May be folded away.) |
| 237 | auto castResult = [&](Value val) -> Value { |
| 238 | if (resultType == val.getType()) |
| 239 | return val; |
| 240 | return b.create<tensor::CastOp>(location: loc, args&: resultType, args&: val); |
| 241 | }; |
| 242 | |
| 243 | // In cases where the original data source is unused: Emit a GenerateOp and |
| 244 | // do not generate a SliceOp. (The result shape of the SliceOp would |
| 245 | // have a dimension of size 0, the semantics of which is unclear.) |
| 246 | auto createGenerateOp = [&]() { |
| 247 | // Create GenerateOp. |
| 248 | auto generateOp = b.create<tensor::GenerateOp>( |
| 249 | location: loc, args&: resultType, args&: dynDims, |
| 250 | args: [&](OpBuilder &builder, Location gLoc, ValueRange indices) { |
| 251 | builder.create<tensor::YieldOp>(location: gLoc, args&: padValue); |
| 252 | }); |
| 253 | return generateOp; |
| 254 | }; |
| 255 | |
| 256 | // Emit a SliceOp and a PadOp. Should not be used in cases where |
| 257 | // the result shape of the new SliceOp has a zero dimension. |
| 258 | auto = [&]() { |
| 259 | // Create pad(extract_slice(x)). |
| 260 | auto newSliceOp = b.create<tensor::ExtractSliceOp>( |
| 261 | location: loc, args: padOp.getSource(), args&: newOffsets, args&: newLengths, args&: newStrides); |
| 262 | auto newPadOp = b.create<PadOp>( |
| 263 | location: loc, args: Type(), args&: newSliceOp, args&: newLows, args&: newHighs, |
| 264 | /*nofold=*/args: padOp.getNofold(), |
| 265 | args: getPrunedAttributeList(op: padOp, elidedAttrs: PadOp::getAttributeNames())); |
| 266 | |
| 267 | // Copy region to new PadOp. |
| 268 | IRMapping bvm; |
| 269 | padOp.getRegion().cloneInto(dest: &newPadOp.getRegion(), mapper&: bvm); |
| 270 | |
| 271 | // Cast result and return. |
| 272 | return std::make_tuple(args&: newPadOp, args&: newSliceOp); |
| 273 | }; |
| 274 | |
| 275 | // Rewrite extract_slice(pad(x)) into a GenerateOp it is statically known that |
| 276 | // the original data source x is not used. |
| 277 | if (hasZeroLen) { |
| 278 | Operation *generateOp = createGenerateOp(); |
| 279 | return TilingResult{.tiledOps: {generateOp}, |
| 280 | .tiledValues: {castResult(generateOp->getResult(idx: 0))}, |
| 281 | /*generatedSlices=*/{}}; |
| 282 | } |
| 283 | |
| 284 | // If there are dynamic dimensions: Generate an scf.if check to avoid |
| 285 | // creating SliceOps with result dimensions of size 0 at runtime. |
| 286 | if (generateZeroSliceGuard && dynHasZeroLenCond) { |
| 287 | Operation *thenOp; |
| 288 | Operation *elseOp; |
| 289 | Operation *sliceOp; |
| 290 | auto result = b.create<scf::IfOp>( |
| 291 | location: loc, args&: dynHasZeroLenCond, |
| 292 | /*thenBuilder=*/ |
| 293 | args: [&](OpBuilder &b, Location loc) { |
| 294 | thenOp = createGenerateOp(); |
| 295 | b.create<scf::YieldOp>(location: loc, args: castResult(thenOp->getResult(idx: 0))); |
| 296 | }, |
| 297 | /*elseBuilder=*/ |
| 298 | args: [&](OpBuilder &b, Location loc) { |
| 299 | std::tie(args&: elseOp, args&: sliceOp) = createPadOfExtractSlice(); |
| 300 | b.create<scf::YieldOp>(location: loc, args: castResult(elseOp->getResult(idx: 0))); |
| 301 | }); |
| 302 | return TilingResult{ |
| 303 | .tiledOps: {elseOp}, .tiledValues: SmallVector<Value>(result->getResults()), .generatedSlices: {sliceOp}}; |
| 304 | } |
| 305 | |
| 306 | auto [newPadOp, sliceOp] = createPadOfExtractSlice(); |
| 307 | return TilingResult{ |
| 308 | .tiledOps: {newPadOp}, .tiledValues: {castResult(newPadOp->getResult(idx: 0))}, .generatedSlices: {sliceOp}}; |
| 309 | } |
| 310 | |
| 311 | void mlir::tensor::registerTilingInterfaceExternalModels( |
| 312 | DialectRegistry ®istry) { |
| 313 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, TensorDialect *dialect) { |
| 314 | tensor::PadOp::attachInterface<PadOpTiling>(context&: *ctx); |
| 315 | }); |
| 316 | } |
| 317 | |