| 1 | //===- PaddingTilingInterface.cpp - Padding of TilingInterface ops --------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
| 10 | |
| 11 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 12 | #include "mlir/Dialect/Complex/IR/Complex.h" |
| 13 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 14 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
| 15 | #include "mlir/IR/AffineExpr.h" |
| 16 | #include "mlir/IR/BuiltinAttributes.h" |
| 17 | #include "mlir/IR/BuiltinTypeInterfaces.h" |
| 18 | #include "mlir/IR/BuiltinTypes.h" |
| 19 | #include "mlir/IR/OpDefinition.h" |
| 20 | #include "mlir/IR/Value.h" |
| 21 | #include "mlir/Interfaces/TilingInterface.h" |
| 22 | #include "llvm/ADT/STLExtras.h" |
| 23 | #include "llvm/Support/Casting.h" |
| 24 | |
| 25 | #define DEBUG_TYPE "pad-tiling-interface" |
| 26 | |
| 27 | using namespace mlir; |
| 28 | using namespace mlir::linalg; |
| 29 | using namespace mlir::tensor; |
| 30 | |
| 31 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") |
| 32 | #define DBGSNL() (llvm::dbgs() << "\n") |
| 33 | |
| 34 | /// Form a "full-rank" padding specification so that the application is easy. |
| 35 | static SmallVector<OpFoldResult> |
| 36 | getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes, |
| 37 | const PadTilingInterfaceOptions &options) { |
| 38 | SmallVector<OpFoldResult> paddingSizes; |
| 39 | // Complete the padding specification to specify all dimensions. |
| 40 | for (size_t idx = 0, e = indexingSizes.size(); idx != e; ++idx) { |
| 41 | // Complete to zero if needed. |
| 42 | paddingSizes.push_back(Elt: options.paddingSizes.size() > idx |
| 43 | ? options.paddingSizes[idx] |
| 44 | : b.getIndexAttr(value: 0)); |
| 45 | // If a dimension is zero (either specified or completed), replace by: |
| 46 | // - 1 if we are padding to the next multiple of. |
| 47 | // - indexingSizes[idx] otherwise |
| 48 | if (isZeroInteger(v: paddingSizes[idx])) { |
| 49 | paddingSizes[idx] = |
| 50 | options.padToMultipleOf ? b.getIndexAttr(value: 1) : indexingSizes[idx]; |
| 51 | } |
| 52 | LLVM_DEBUG(DBGS() << "----idx: " << idx << " : " << paddingSizes[idx] |
| 53 | << "\n" ); |
| 54 | } |
| 55 | return paddingSizes; |
| 56 | } |
| 57 | |
/// Compute the padded shape of the given value `v` of `RankedTensorType` given
///   - `indexingSizes` a list of OpFoldResult.
///   - an `indexingMap` that encodes how the shape of `v` varies with
///     increases in `indexingSizes`.
/// The `indexingMap` encodes how the shape of `v` varies with `indexingSizes`.
/// The `indexingMap` + `indexingSizes` encoding suits StructuredOps.
/// The implementation below iteratively combines increases from contributing
/// dimensions using affine.apply operations.
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
SmallVector<OpFoldResult> linalg::computePaddedShape(
    RewriterBase &rewriter, TypedValue<RankedTensorType> v,
    AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
    const PadTilingInterfaceOptions &options) {
  Location loc = v.getLoc();
  SmallVector<OpFoldResult> paddedShape;
  auto tensorType = cast<RankedTensorType>(Val: v.getType());
  // One entry per tensor dimension; every entry is overwritten below.
  paddedShape.resize_for_overwrite(N: tensorType.getRank());
  assert(tensorType.getRank() == indexingMap.getNumResults() &&
         "expect the number of results of the affine map to match the tensor "
         "rank" );

  // "Full-rank" padding specification: one (non-zero) padding size per
  // iteration dimension, defaulted from `options` (see
  // getFullRankPaddingSizes).
  SmallVector<OpFoldResult> paddingSizes =
      getFullRankPaddingSizes(b&: rewriter, indexingSizes, options);

  // For each dimension in the operand's shape, iterate over indexingSizes and
  // add the various term contributions.
  for (const auto &enResults : enumerate(First: indexingMap.getResults())) {
    int64_t resultIndex = enResults.index();
    // Single-result sub-map describing how this one tensor dimension depends
    // on the iteration dimensions.
    AffineMap partialIndexingMap = indexingMap.getSubMap(
        resultPos: ArrayRef<unsigned>{static_cast<unsigned>(resultIndex)});

    LLVM_DEBUG(DBGS() << "----resultIndex: " << resultIndex
                      << " with partialIndexingMap: " << partialIndexingMap
                      << "\n" );

    // Find all padding dimensions that contribute to this operand dimension
    // and compute the padded term contribution to the final padded shape.
    SmallVector<OpFoldResult> terms;
    for (size_t paddingDim = 0, e = paddingSizes.size(); paddingDim != e;
         ++paddingDim) {
      OpFoldResult paddingSize = paddingSizes[paddingDim];
      LLVM_DEBUG(DBGS() << "------try apply padding of dim: " << paddingDim
                        << " to: " << paddingSize << "\n" );
      // Skip iteration dimensions that do not appear in this result
      // expression: they contribute nothing to this tensor dimension.
      if (!enResults.value().isFunctionOfDim(position: paddingDim))
        continue;

      LLVM_DEBUG(DBGS() << "------apply padding of dim: " << paddingDim
                        << " to: " << paddingSize << "\n" );

      // Project non-'paddingDim' dimensions and compress the result.
      llvm::SmallBitVector projectedDims(partialIndexingMap.getNumDims(), true);
      projectedDims.flip(Idx: paddingDim);
      AffineMap projectedMap =
          mlir::projectDims(map: partialIndexingMap, projectedDimensions: projectedDims,
                            /*compressDims=*/compressDimsFlag: true);

      // If we are padding to the next multiple of, compose with ceil(sz) * sz.
      // I.e. evaluate the projected map at ceilDiv(size, multiple) * multiple,
      // which rounds the contributing extent up to the requested multiple.
      if (options.padToMultipleOf) {
        AffineExpr d0, s0;
        bindDims(ctx: rewriter.getContext(), exprs&: d0);
        bindSymbols(ctx: rewriter.getContext(), exprs&: s0);
        AffineMap ceilMap = AffineMap::get(dimCount: 1, symbolCount: 1, result: d0.ceilDiv(other: s0) * s0);
        AffineMap composedMap = projectedMap.compose(map: ceilMap);
        OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
            b&: rewriter, loc, map: composedMap,
            operands: {indexingSizes[paddingDim], paddingSize},
            /*composeAffineMin=*/true);
        terms.push_back(Elt: paddingDimOfr);
      } else {
        // Otherwise just set to paddingSize.
        OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
            b&: rewriter, loc, map: projectedMap, operands: paddingSize);
        terms.push_back(Elt: paddingDimOfr);
      }

      LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n" );
    }

    // If there are no terms, just return the dim.
    // This happens when no iteration dimension feeds this tensor dimension;
    // the padded extent is then the original (possibly dynamic) extent.
    if (terms.empty()) {
      paddedShape[resultIndex] =
          createFoldedDimOp(b&: rewriter, loc, val: v, dim: resultIndex);
      continue;
    }

    // Sum individual terms' contributions via a single folded affine.apply
    // of d0 + d1 + ... + d{n-1} over the collected terms.
    SmallVector<AffineExpr> dims(terms.size());
    bindDimsList(ctx: rewriter.getContext(), exprs: MutableArrayRef{dims});
    AffineExpr sumExpr = dims.front();
    for (unsigned i = 1; i < dims.size(); ++i)
      sumExpr = sumExpr + dims[i];
    OpFoldResult paddedDimOfr =
        affine::makeComposedFoldedAffineApply(b&: rewriter, loc, expr: sumExpr, operands: terms);
    paddedShape[resultIndex] = paddedDimOfr;
  }

  return paddedShape;
}
| 158 | |
| 159 | FailureOr<SmallVector<OpFoldResult>> |
| 160 | linalg::computeIndexingMapOpInterfacePaddedShape( |
| 161 | RewriterBase &rewriter, OpOperand &operandToPad, |
| 162 | ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options) { |
| 163 | auto transferOp = |
| 164 | llvm::dyn_cast<IndexingMapOpInterface>(Val: operandToPad.getOwner()); |
| 165 | if (!transferOp) |
| 166 | return failure(); |
| 167 | |
| 168 | // clang-format off |
| 169 | assert(llvm::all_of(iterationDomain, [&rewriter](Range r) { |
| 170 | return r.offset == OpFoldResult(rewriter.getIndexAttr(0)) && |
| 171 | r.stride == OpFoldResult(rewriter.getIndexAttr(1)); |
| 172 | }) && "expected 0-offset 1-stride loop ranges" ); |
| 173 | // clang-format on |
| 174 | SmallVector<OpFoldResult> loopUpperBounds; |
| 175 | loopUpperBounds.reserve(N: iterationDomain.size()); |
| 176 | for (const Range &range : iterationDomain) |
| 177 | loopUpperBounds.push_back(Elt: range.size); |
| 178 | |
| 179 | AffineMap indexingMap = transferOp.getMatchingIndexingMap(opOperand: &operandToPad); |
| 180 | return computePaddedShape( |
| 181 | rewriter, v: cast<TypedValue<RankedTensorType>>(Val: operandToPad.get()), |
| 182 | indexingMap, indexingSizes: loopUpperBounds, options); |
| 183 | } |
| 184 | |
| 185 | /// Pad a single operand to `paddedShape` using `paddingValueAttr` as padding |
| 186 | /// Value. |
| 187 | static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad, |
| 188 | TypedValue<RankedTensorType> v, |
| 189 | ArrayRef<OpFoldResult> paddedShape, |
| 190 | Attribute paddingValueAttr) { |
| 191 | Value paddingValue; |
| 192 | if (auto complexTy = |
| 193 | dyn_cast<ComplexType>(Val: getElementTypeOrSelf(type: v.getType()))) { |
| 194 | auto complexAttr = cast<ArrayAttr>(Val&: paddingValueAttr); |
| 195 | paddingValue = rewriter.create<complex::ConstantOp>(location: opToPad.getLoc(), |
| 196 | args&: complexTy, args&: complexAttr); |
| 197 | } else { |
| 198 | paddingValue = rewriter.create<arith::ConstantOp>( |
| 199 | location: opToPad.getLoc(), args: cast<TypedAttr>(Val&: paddingValueAttr)); |
| 200 | } |
| 201 | |
| 202 | // Pad the operand to the bounding box defined by `paddedShape`. |
| 203 | SmallVector<int64_t> tensorShape; |
| 204 | SmallVector<Value> dynDims; |
| 205 | for (OpFoldResult ofr : paddedShape) { |
| 206 | std::optional<int64_t> cst = getConstantIntValue(ofr); |
| 207 | tensorShape.push_back(Elt: cst.has_value() ? *cst : ShapedType::kDynamic); |
| 208 | if (!cst.has_value()) |
| 209 | dynDims.push_back(Elt: ofr.dyn_cast<Value>()); |
| 210 | } |
| 211 | // TODO: use dispatchIndexOpFoldResults(paddedShape, dynDims, paddedShape); |
| 212 | |
| 213 | auto paddedTensorType = |
| 214 | RankedTensorType::get(shape: tensorShape, elementType: getElementTypeOrSelf(val: v)); |
| 215 | LLVM_DEBUG(DBGS() << "--SUCCESS, makeComposedPadHighOp with type: " |
| 216 | << paddedTensorType); |
| 217 | return makeComposedPadHighOp(b&: rewriter, loc: opToPad.getLoc(), type: paddedTensorType, source: v, |
| 218 | padding: paddingValue, /*nofold=*/false, typeDynDims: dynDims); |
| 219 | } |
| 220 | |
/// Rewrite `opToPad` (a tensor-semantics TilingInterface op) so that every
/// ranked-tensor operand is padded to the shape computed by
/// `computePaddingSizeFun`, then clone the op on the padded operands and
/// extract slices of the original result shapes. Created tensor.pad ops are
/// appended to `padOps`. Returns the cloned (padded) op or a match failure.
FailureOr<TilingInterface>
linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
                          const PadTilingInterfaceOptions &constOptions,
                          SmallVector<tensor::PadOp> &padOps,
                          PadSizeComputationFunction computePaddingSizeFun) {
  LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n" );

  Location loc = opToPad.getLoc();
  // Work on a mutable copy so defaults can be filled in below.
  PadTilingInterfaceOptions options(constOptions);
  // Allow inference of pad values if they are not explicitly specified.
  // Defaults to the zero attribute of each operand/result element type.
  // TODO: be mindful about the value depending on the actual operation.
  if (options.paddingValues.empty()) {
    SmallVector<Type> types(opToPad->getOperandTypes());
    llvm::append_range(C&: types, R: opToPad->getResultTypes());
    for (Type t : types) {
      options.paddingValues.push_back(
          Elt: rewriter.getZeroAttr(type: getElementTypeOrSelf(type: t)));
    }
  }

  // Only tensor-semantics ops are supported; bail out on memref operands.
  if (llvm::any_of(Range: opToPad->getOperands(),
                   P: [](Value v) { return isa<MemRefType>(Val: v.getType()); })) {
    return rewriter.notifyMatchFailure(arg&: opToPad,
                                       msg: "expected operation on tensors" );
  }

  OpBuilder::InsertionGuard g(rewriter);
  // Set IP after opToPad because we also take the dims of opToPad's output.
  rewriter.setInsertionPointAfter(opToPad);

  // 1. Get the loopUpperBounds from the TilingInterface.
  SmallVector<Range> iterationDomain = opToPad.getIterationDomain(b&: rewriter);

  // 2. For each operand.
  SmallVector<Value> newOperands;
  newOperands.reserve(N: opToPad->getNumOperands());
  for (OpOperand &opOperand : opToPad->getOpOperands()) {
    Value operand = opOperand.get();
    LLVM_DEBUG(DBGS() << "--start padding oprd: " << operand << "\n" );

    // 2.a. Skip scalar-like operands: they are passed through unchanged.
    Type operandType = operand.getType();
    if (!isa<RankedTensorType>(Val: operandType)) {
      assert((!isa<ShapedType>(operandType) || isa<VectorType>(operandType)) &&
             "Unexpected non-vector ShapedType" );
      newOperands.push_back(Elt: operand);
      continue;
    }
    // 2.b. Compute the padded shape for this operand.
    FailureOr<SmallVector<OpFoldResult>> maybePaddedShape =
        computePaddingSizeFun(rewriter, opOperand, iterationDomain, options);
    if (failed(Result: maybePaddedShape)) {
      return rewriter.notifyMatchFailure(arg&: opToPad, msg: "could not pad op" );
    }

    // 2.c. Expect proper `paddingValues`.
    // TODO: we may want to allow garbage padding in the future, in which case
    // we would just not assert.
    if (opOperand.getOperandNumber() >= options.paddingValues.size()) {
      return rewriter.notifyMatchFailure(arg&: opToPad,
                                         msg: "--no padding value specified" );
    }
    Attribute paddingValueAttr =
        options.paddingValues[opOperand.getOperandNumber()];

    // 2.d. Perform actual padding.
    Value paddedOperand = padOperand(
        rewriter, opToPad, v: cast<TypedValue<RankedTensorType>>(Val&: operand),
        paddedShape: *maybePaddedShape, paddingValueAttr);
    LLVM_DEBUG(DBGS() << "--done padding operand: " << paddedOperand << "\n" );

    // 2.e. Record the new operand and the created pad op (if padding did not
    // fold away).
    newOperands.push_back(Elt: paddedOperand);
    if (auto padOp = paddedOperand.getDefiningOp<tensor::PadOp>())
      padOps.push_back(Elt: padOp);
  }

  // 3. Form the resulting tensor::ExtractSliceOp.
  // Reify the original result shapes first; they define the slice sizes.
  ReifiedRankedShapedTypeDims reifiedResultShapes;
  if (failed(Result: reifyResultShapes(b&: rewriter, op: opToPad, reifiedReturnShapes&: reifiedResultShapes))) {
    LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n" );
    return rewriter.notifyMatchFailure(arg&: opToPad,
                                       msg: "failed to reify result shapes" );
  }
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results" );

  // Clone `opToPad` to operate on the statically padded shapes.
  // Result types are taken from the padded output operands (DPS convention:
  // results correspond to the trailing operands).
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(n: opToPad->getNumResults()).getTypes();
  // clone **should** properly notify the rewriter.
  TilingInterface paddedOp =
      clone(b&: rewriter, op: opToPad, newResultTypes: resultTensorTypes, newOperands);
  LLVM_DEBUG(DBGS() << "--cloned padded op: " << paddedOp << "\n" );

  // Recover the slice out of the new static results. This keeps the original
  // opToPad around because it uses the dims of the original results.
  SmallVector<Value> paddedSubtensorResults;
  paddedSubtensorResults.reserve(N: opToPad->getNumResults());
  for (const auto &en : llvm::enumerate(First: paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = cast<RankedTensorType>(Val: paddedResult.getType()).getRank();
    // Zero offsets / unit strides: the original result is the leading
    // sub-tensor of the padded result.
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(value: 0));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(value: 1));
    paddedSubtensorResults.push_back(Elt: rewriter.create<tensor::ExtractSliceOp>(
        location: loc, args&: paddedResult, args&: offsets, args&: reifiedResultShapes[resultNumber],
        args&: strides));
  }

  rewriter.replaceOp(op: opToPad, newValues: paddedSubtensorResults);

  return paddedOp;
}
| 335 | |