| 1 | //===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include <numeric> |
| 10 | |
| 11 | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
| 12 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 13 | #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" |
| 14 | #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" |
| 15 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h" |
| 16 | #include "mlir/IR/Builders.h" |
| 17 | #include "mlir/IR/TypeUtilities.h" |
| 18 | |
| 19 | #define DEBUG_TYPE "vector-drop-unit-dim" |
| 20 | |
| 21 | using namespace mlir; |
| 22 | using namespace mlir::vector; |
| 23 | |
| 24 | // Trims leading one dimensions from `oldType` and returns the result type. |
| 25 | // Returns `vector<1xT>` if `oldType` only has one element. |
| 26 | static VectorType trimLeadingOneDims(VectorType oldType) { |
| 27 | ArrayRef<int64_t> oldShape = oldType.getShape(); |
| 28 | ArrayRef<int64_t> newShape = oldShape; |
| 29 | |
| 30 | ArrayRef<bool> oldScalableDims = oldType.getScalableDims(); |
| 31 | ArrayRef<bool> newScalableDims = oldScalableDims; |
| 32 | |
| 33 | while (!newShape.empty() && newShape.front() == 1 && |
| 34 | !newScalableDims.front()) { |
| 35 | newShape = newShape.drop_front(N: 1); |
| 36 | newScalableDims = newScalableDims.drop_front(N: 1); |
| 37 | } |
| 38 | |
| 39 | // Make sure we have at least 1 dimension per vector type requirements. |
| 40 | if (newShape.empty()) { |
| 41 | newShape = oldShape.take_back(); |
| 42 | newScalableDims = oldType.getScalableDims().take_back(); |
| 43 | } |
| 44 | return VectorType::get(newShape, oldType.getElementType(), newScalableDims); |
| 45 | } |
| 46 | |
| 47 | /// Return a smallVector of size `rank` containing all zeros. |
| 48 | static SmallVector<int64_t> splatZero(int64_t rank) { |
| 49 | return SmallVector<int64_t>(rank, 0); |
| 50 | } |
| 51 | namespace { |
| 52 | |
| 53 | // Casts away leading one dimensions in vector.extract_strided_slice's vector |
| 54 | // input by inserting vector.broadcast. |
| 55 | struct |
| 56 | : public OpRewritePattern<vector::ExtractStridedSliceOp> { |
| 57 | using OpRewritePattern::OpRewritePattern; |
| 58 | |
| 59 | LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp , |
| 60 | PatternRewriter &rewriter) const override { |
| 61 | // vector.extract_strided_slice requires the input and output vector to have |
| 62 | // the same rank. Here we drop leading one dimensions from the input vector |
| 63 | // type to make sure we don't cause mismatch. |
| 64 | VectorType oldSrcType = extractOp.getSourceVectorType(); |
| 65 | VectorType newSrcType = trimLeadingOneDims(oldSrcType); |
| 66 | |
| 67 | if (newSrcType.getRank() == oldSrcType.getRank()) |
| 68 | return failure(); |
| 69 | |
| 70 | int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank(); |
| 71 | |
| 72 | VectorType oldDstType = extractOp.getType(); |
| 73 | VectorType newDstType = |
| 74 | VectorType::get(oldDstType.getShape().drop_front(dropCount), |
| 75 | oldDstType.getElementType(), |
| 76 | oldDstType.getScalableDims().drop_front(dropCount)); |
| 77 | |
| 78 | Location loc = extractOp.getLoc(); |
| 79 | |
| 80 | Value newSrcVector = rewriter.create<vector::ExtractOp>( |
| 81 | loc, extractOp.getVector(), splatZero(dropCount)); |
| 82 | |
| 83 | // The offsets/sizes/strides attribute can have a less number of elements |
| 84 | // than the input vector's rank: it is meant for the leading dimensions. |
| 85 | auto newOffsets = rewriter.getArrayAttr( |
| 86 | value: extractOp.getOffsets().getValue().drop_front(dropCount)); |
| 87 | auto newSizes = rewriter.getArrayAttr( |
| 88 | value: extractOp.getSizes().getValue().drop_front(dropCount)); |
| 89 | auto newStrides = rewriter.getArrayAttr( |
| 90 | value: extractOp.getStrides().getValue().drop_front(dropCount)); |
| 91 | |
| 92 | auto = rewriter.create<vector::ExtractStridedSliceOp>( |
| 93 | loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides); |
| 94 | |
| 95 | rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType, |
| 96 | newExtractOp); |
| 97 | |
| 98 | return success(); |
| 99 | } |
| 100 | }; |
| 101 | |
| 102 | // Casts away leading one dimensions in vector.insert_strided_slice's vector |
| 103 | // inputs by inserting vector.broadcast. |
| 104 | struct CastAwayInsertStridedSliceLeadingOneDim |
| 105 | : public OpRewritePattern<vector::InsertStridedSliceOp> { |
| 106 | using OpRewritePattern::OpRewritePattern; |
| 107 | |
| 108 | LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp, |
| 109 | PatternRewriter &rewriter) const override { |
| 110 | VectorType oldSrcType = insertOp.getSourceVectorType(); |
| 111 | VectorType newSrcType = trimLeadingOneDims(oldSrcType); |
| 112 | VectorType oldDstType = insertOp.getDestVectorType(); |
| 113 | VectorType newDstType = trimLeadingOneDims(oldDstType); |
| 114 | |
| 115 | int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank(); |
| 116 | int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank(); |
| 117 | if (srcDropCount == 0 && dstDropCount == 0) |
| 118 | return failure(); |
| 119 | |
| 120 | // Trim leading one dimensions from both operands. |
| 121 | Location loc = insertOp.getLoc(); |
| 122 | |
| 123 | Value newSrcVector = rewriter.create<vector::ExtractOp>( |
| 124 | loc, insertOp.getValueToStore(), splatZero(srcDropCount)); |
| 125 | Value newDstVector = rewriter.create<vector::ExtractOp>( |
| 126 | loc, insertOp.getDest(), splatZero(dstDropCount)); |
| 127 | |
| 128 | auto newOffsets = rewriter.getArrayAttr( |
| 129 | value: insertOp.getOffsets().getValue().take_back(newDstType.getRank())); |
| 130 | auto newStrides = rewriter.getArrayAttr( |
| 131 | value: insertOp.getStrides().getValue().take_back(newSrcType.getRank())); |
| 132 | |
| 133 | auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>( |
| 134 | loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides); |
| 135 | |
| 136 | rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType, |
| 137 | newInsertOp); |
| 138 | |
| 139 | return success(); |
| 140 | } |
| 141 | }; |
| 142 | |
| 143 | // Casts away leading one dimensions in vector.insert's vector inputs by |
| 144 | // inserting vector.broadcast. |
| 145 | struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> { |
| 146 | using OpRewritePattern::OpRewritePattern; |
| 147 | |
| 148 | LogicalResult matchAndRewrite(vector::InsertOp insertOp, |
| 149 | PatternRewriter &rewriter) const override { |
| 150 | Type oldSrcType = insertOp.getValueToStoreType(); |
| 151 | Type newSrcType = oldSrcType; |
| 152 | int64_t oldSrcRank = 0, newSrcRank = 0; |
| 153 | if (auto type = dyn_cast<VectorType>(oldSrcType)) { |
| 154 | newSrcType = trimLeadingOneDims(type); |
| 155 | oldSrcRank = type.getRank(); |
| 156 | newSrcRank = cast<VectorType>(newSrcType).getRank(); |
| 157 | } |
| 158 | |
| 159 | VectorType oldDstType = insertOp.getDestVectorType(); |
| 160 | VectorType newDstType = trimLeadingOneDims(oldDstType); |
| 161 | |
| 162 | int64_t srcDropCount = oldSrcRank - newSrcRank; |
| 163 | int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank(); |
| 164 | if (srcDropCount == 0 && dstDropCount == 0) |
| 165 | return failure(); |
| 166 | |
| 167 | // Trim leading one dimensions from both operands. |
| 168 | Location loc = insertOp.getLoc(); |
| 169 | |
| 170 | Value newSrcVector = insertOp.getValueToStore(); |
| 171 | if (oldSrcRank != 0) { |
| 172 | newSrcVector = rewriter.create<vector::ExtractOp>( |
| 173 | loc, insertOp.getValueToStore(), splatZero(srcDropCount)); |
| 174 | } |
| 175 | Value newDstVector = rewriter.create<vector::ExtractOp>( |
| 176 | loc, insertOp.getDest(), splatZero(dstDropCount)); |
| 177 | |
| 178 | // New position rank needs to be computed in two steps: (1) if destination |
| 179 | // type has leading unit dims, we also trim the position array accordingly, |
| 180 | // then (2) if source type also has leading unit dims, we need to append |
| 181 | // zeroes to the position array accordingly. |
| 182 | unsigned oldPosRank = insertOp.getNumIndices(); |
| 183 | unsigned newPosRank = std::max<int64_t>(a: 0, b: oldPosRank - dstDropCount); |
| 184 | SmallVector<OpFoldResult> oldPosition = insertOp.getMixedPosition(); |
| 185 | SmallVector<OpFoldResult> newPosition = |
| 186 | llvm::to_vector(Range: ArrayRef(oldPosition).take_back(N: newPosRank)); |
| 187 | newPosition.resize(newDstType.getRank() - newSrcRank, |
| 188 | rewriter.getI64IntegerAttr(0)); |
| 189 | |
| 190 | auto newInsertOp = rewriter.create<vector::InsertOp>( |
| 191 | loc, newSrcVector, newDstVector, newPosition); |
| 192 | |
| 193 | rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType, |
| 194 | newInsertOp); |
| 195 | |
| 196 | return success(); |
| 197 | } |
| 198 | }; |
| 199 | |
| 200 | static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask, |
| 201 | VectorType newType, AffineMap newMap, |
| 202 | VectorType oldMaskType) { |
| 203 | // Infer the type of the new mask from the new map. |
| 204 | VectorType newMaskType = inferTransferOpMaskType(newType, newMap); |
| 205 | |
| 206 | // If the new mask is broadcastable to the old result type, we can safely |
| 207 | // use a `vector.extract` to get the new mask. Otherwise the best we can |
| 208 | // do is shape cast. |
| 209 | if (vector::isBroadcastableTo(srcType: newMaskType, dstVectorType: oldMaskType) == |
| 210 | BroadcastableToResult::Success) { |
| 211 | int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank(); |
| 212 | return b.create<vector::ExtractOp>(loc, mask, splatZero(dropDim)); |
| 213 | } |
| 214 | return b.create<vector::ShapeCastOp>(loc, newMaskType, mask); |
| 215 | } |
| 216 | |
| 217 | // Turns vector.transfer_read on vector with leading 1 dimensions into |
| 218 | // vector.shape_cast followed by vector.transfer_read on vector without leading |
| 219 | // 1 dimensions. |
| 220 | struct CastAwayTransferReadLeadingOneDim |
| 221 | : public OpRewritePattern<vector::TransferReadOp> { |
| 222 | using OpRewritePattern::OpRewritePattern; |
| 223 | |
| 224 | LogicalResult matchAndRewrite(vector::TransferReadOp read, |
| 225 | PatternRewriter &rewriter) const override { |
| 226 | // TODO(#78787): Not supported masked op yet. |
| 227 | if (cast<MaskableOpInterface>(read.getOperation()).isMasked()) |
| 228 | return failure(); |
| 229 | // TODO: support 0-d corner case. |
| 230 | if (read.getTransferRank() == 0) |
| 231 | return failure(); |
| 232 | |
| 233 | auto shapedType = cast<ShapedType>(read.getBase().getType()); |
| 234 | if (shapedType.getElementType() != read.getVectorType().getElementType()) |
| 235 | return failure(); |
| 236 | |
| 237 | VectorType oldType = read.getVectorType(); |
| 238 | VectorType newType = trimLeadingOneDims(oldType); |
| 239 | |
| 240 | if (newType == oldType) |
| 241 | return failure(); |
| 242 | |
| 243 | AffineMap oldMap = read.getPermutationMap(); |
| 244 | ArrayRef<AffineExpr> newResults = |
| 245 | oldMap.getResults().take_back(N: newType.getRank()); |
| 246 | AffineMap newMap = |
| 247 | AffineMap::get(dimCount: oldMap.getNumDims(), symbolCount: oldMap.getNumSymbols(), results: newResults, |
| 248 | context: rewriter.getContext()); |
| 249 | |
| 250 | ArrayAttr inBoundsAttr; |
| 251 | if (read.getInBounds()) |
| 252 | inBoundsAttr = rewriter.getArrayAttr( |
| 253 | value: read.getInBoundsAttr().getValue().take_back(newType.getRank())); |
| 254 | |
| 255 | Value mask = Value(); |
| 256 | if (read.getMask()) { |
| 257 | VectorType maskType = read.getMaskType(); |
| 258 | mask = dropUnitDimsFromMask(rewriter, read.getLoc(), read.getMask(), |
| 259 | newType, newMap, maskType); |
| 260 | } |
| 261 | |
| 262 | auto newRead = rewriter.create<vector::TransferReadOp>( |
| 263 | read.getLoc(), newType, read.getBase(), read.getIndices(), |
| 264 | AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr); |
| 265 | rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead); |
| 266 | |
| 267 | return success(); |
| 268 | } |
| 269 | }; |
| 270 | |
| 271 | // Turns vector.transfer_write on vector with leading 1 dimensions into |
| 272 | // vector.shape_cast followed by vector.transfer_write on vector without leading |
| 273 | // 1 dimensions. |
| 274 | struct CastAwayTransferWriteLeadingOneDim |
| 275 | : public OpRewritePattern<vector::TransferWriteOp> { |
| 276 | using OpRewritePattern::OpRewritePattern; |
| 277 | |
| 278 | LogicalResult matchAndRewrite(vector::TransferWriteOp write, |
| 279 | PatternRewriter &rewriter) const override { |
| 280 | // TODO(#78787): Not supported masked op yet. |
| 281 | if (cast<MaskableOpInterface>(write.getOperation()).isMasked()) |
| 282 | return failure(); |
| 283 | // TODO: support 0-d corner case. |
| 284 | if (write.getTransferRank() == 0) |
| 285 | return failure(); |
| 286 | |
| 287 | auto shapedType = dyn_cast<ShapedType>(write.getBase().getType()); |
| 288 | if (shapedType.getElementType() != write.getVectorType().getElementType()) |
| 289 | return failure(); |
| 290 | |
| 291 | VectorType oldType = write.getVectorType(); |
| 292 | VectorType newType = trimLeadingOneDims(oldType); |
| 293 | if (newType == oldType) |
| 294 | return failure(); |
| 295 | int64_t dropDim = oldType.getRank() - newType.getRank(); |
| 296 | |
| 297 | AffineMap oldMap = write.getPermutationMap(); |
| 298 | ArrayRef<AffineExpr> newResults = |
| 299 | oldMap.getResults().take_back(N: newType.getRank()); |
| 300 | AffineMap newMap = |
| 301 | AffineMap::get(dimCount: oldMap.getNumDims(), symbolCount: oldMap.getNumSymbols(), results: newResults, |
| 302 | context: rewriter.getContext()); |
| 303 | |
| 304 | ArrayAttr inBoundsAttr; |
| 305 | if (write.getInBounds()) |
| 306 | inBoundsAttr = rewriter.getArrayAttr( |
| 307 | value: write.getInBoundsAttr().getValue().take_back(newType.getRank())); |
| 308 | |
| 309 | auto newVector = rewriter.create<vector::ExtractOp>( |
| 310 | write.getLoc(), write.getVector(), splatZero(dropDim)); |
| 311 | |
| 312 | if (write.getMask()) { |
| 313 | VectorType maskType = write.getMaskType(); |
| 314 | Value newMask = dropUnitDimsFromMask( |
| 315 | rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType); |
| 316 | rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( |
| 317 | write, newVector, write.getBase(), write.getIndices(), |
| 318 | AffineMapAttr::get(newMap), newMask, inBoundsAttr); |
| 319 | return success(); |
| 320 | } |
| 321 | |
| 322 | rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( |
| 323 | write, newVector, write.getBase(), write.getIndices(), |
| 324 | AffineMapAttr::get(newMap), inBoundsAttr); |
| 325 | return success(); |
| 326 | } |
| 327 | }; |
| 328 | |
| 329 | } // namespace |
| 330 | |
| 331 | FailureOr<Value> |
| 332 | mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, |
| 333 | MaskingOpInterface maskingOp, |
| 334 | RewriterBase &rewriter) { |
| 335 | VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType()); |
| 336 | if (oldAccType == nullptr) |
| 337 | return failure(); |
| 338 | if (oldAccType.getRank() < 2) |
| 339 | return failure(); |
| 340 | if (oldAccType.getShape()[0] != 1) |
| 341 | return failure(); |
| 342 | // currently we support only dropping one dim but the pattern can be applied |
| 343 | // greedily to drop more. |
| 344 | int64_t dropDim = 1; |
| 345 | |
| 346 | auto oldIndexingMaps = contractOp.getIndexingMapsArray(); |
| 347 | SmallVector<AffineMap> newIndexingMaps; |
| 348 | |
| 349 | auto oldIteratorTypes = contractOp.getIteratorTypes(); |
| 350 | SmallVector<Attribute> newIteratorTypes; |
| 351 | |
| 352 | int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0); |
| 353 | |
| 354 | if (!isParallelIterator(oldIteratorTypes[dimToDrop])) |
| 355 | // only parallel type iterators can be dropped. |
| 356 | return failure(); |
| 357 | |
| 358 | for (const auto &it : llvm::enumerate(oldIteratorTypes)) { |
| 359 | int64_t currDim = it.index(); |
| 360 | if (currDim == dimToDrop) |
| 361 | continue; |
| 362 | newIteratorTypes.push_back(it.value()); |
| 363 | } |
| 364 | |
| 365 | SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(), |
| 366 | contractOp.getAcc()}; |
| 367 | SmallVector<Value> newOperands; |
| 368 | auto loc = contractOp.getLoc(); |
| 369 | |
| 370 | for (const auto &it : llvm::enumerate(oldIndexingMaps)) { |
| 371 | // Check if the dim to be dropped exists as a leading dim in the operand |
| 372 | // if it does then we use vector.extract to drop it. |
| 373 | bool validExtract = false; |
| 374 | SmallVector<AffineExpr> results; |
| 375 | auto map = it.value(); |
| 376 | int64_t orginalZeroDim = it.value().getDimPosition(0); |
| 377 | if (orginalZeroDim != dimToDrop) { |
| 378 | // There are two reasons to be in this path, 1. We need to |
| 379 | // transpose the operand to make the dim to be dropped |
| 380 | // leading. 2. The dim to be dropped does not exist and in |
| 381 | // that case we dont want to add a unit transpose but we must |
| 382 | // check all the indices to make sure this is the case. |
| 383 | bool transposeNeeded = false; |
| 384 | SmallVector<int64_t> perm; |
| 385 | SmallVector<AffineExpr> transposeResults; |
| 386 | |
| 387 | for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { |
| 388 | int64_t currDim = map.getDimPosition(i); |
| 389 | if (currDim == dimToDrop) { |
| 390 | transposeNeeded = true; |
| 391 | perm.insert(perm.begin(), i); |
| 392 | auto targetExpr = rewriter.getAffineDimExpr(currDim); |
| 393 | transposeResults.insert(transposeResults.begin(), targetExpr); |
| 394 | } else { |
| 395 | perm.push_back(i); |
| 396 | auto targetExpr = rewriter.getAffineDimExpr(currDim); |
| 397 | transposeResults.push_back(targetExpr); |
| 398 | } |
| 399 | } |
| 400 | |
| 401 | // Checks if only the outer, unit dimensions (of size 1) are permuted. |
| 402 | // Such transposes do not materially effect the underlying vector and can |
| 403 | // be omitted. EG: perm [1, 0, 2] applied to vector<1x1x8xi32> |
| 404 | bool transposeNonOuterUnitDims = false; |
| 405 | auto operandShape = cast<ShapedType>(operands[it.index()].getType()); |
| 406 | for (auto [index, dim] : |
| 407 | llvm::enumerate(ArrayRef<int64_t>(perm).drop_back(1))) { |
| 408 | if (dim != static_cast<int64_t>(index) && |
| 409 | operandShape.getDimSize(index) != 1) { |
| 410 | transposeNonOuterUnitDims = true; |
| 411 | break; |
| 412 | } |
| 413 | } |
| 414 | |
| 415 | // Do the transpose now if needed so that we can drop the |
| 416 | // correct dim using extract later. |
| 417 | if (transposeNeeded) { |
| 418 | map = AffineMap::get(map.getNumDims(), 0, transposeResults, |
| 419 | contractOp.getContext()); |
| 420 | if (transposeNonOuterUnitDims) { |
| 421 | operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>( |
| 422 | loc, operands[it.index()], perm); |
| 423 | } |
| 424 | } |
| 425 | } |
| 426 | // We have taken care to have the dim to be dropped be |
| 427 | // the leading dim. If its still not leading that means it |
| 428 | // does not exist in this operand and hence we do not need |
| 429 | // an extract. |
| 430 | if (map.getDimPosition(0) == dimToDrop) |
| 431 | validExtract = true; |
| 432 | |
| 433 | for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { |
| 434 | int64_t currDim = map.getDimPosition(i); |
| 435 | if (currDim == dimToDrop) |
| 436 | // This is the dim we are dropping. |
| 437 | continue; |
| 438 | auto targetExpr = rewriter.getAffineDimExpr( |
| 439 | currDim < dimToDrop ? currDim : currDim - 1); |
| 440 | results.push_back(targetExpr); |
| 441 | } |
| 442 | newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results, |
| 443 | contractOp.getContext())); |
| 444 | // Extract if its a valid extraction, otherwise use the operand |
| 445 | // without extraction. |
| 446 | newOperands.push_back( |
| 447 | validExtract ? rewriter.create<vector::ExtractOp>( |
| 448 | loc, operands[it.index()], splatZero(dropDim)) |
| 449 | : operands[it.index()]); |
| 450 | } |
| 451 | |
| 452 | // Depending on whether this vector.contract is masked, the replacing Op |
| 453 | // should either be a new vector.contract Op or vector.mask Op. |
| 454 | Operation *newOp = rewriter.create<vector::ContractionOp>( |
| 455 | loc, newOperands[0], newOperands[1], newOperands[2], |
| 456 | rewriter.getAffineMapArrayAttr(newIndexingMaps), |
| 457 | rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind()); |
| 458 | |
| 459 | if (maskingOp) { |
| 460 | auto newMask = rewriter.create<vector::ExtractOp>(loc, maskingOp.getMask(), |
| 461 | splatZero(dropDim)); |
| 462 | |
| 463 | newOp = mlir::vector::maskOperation(builder&: rewriter, maskableOp: newOp, mask: newMask); |
| 464 | } |
| 465 | |
| 466 | return rewriter |
| 467 | .create<vector::BroadcastOp>(loc, contractOp->getResultTypes()[0], |
| 468 | newOp->getResults()[0]) |
| 469 | .getResult(); |
| 470 | } |
| 471 | |
| 472 | namespace { |
| 473 | |
| 474 | /// Turns vector.contract on vector with leading 1 dimensions into |
| 475 | /// vector.extract followed by vector.contract on vector without leading |
| 476 | /// 1 dimensions. Also performs transpose of lhs and rhs operands if required |
| 477 | /// prior to extract. |
| 478 | struct CastAwayContractionLeadingOneDim |
| 479 | : public MaskableOpRewritePattern<vector::ContractionOp> { |
| 480 | using MaskableOpRewritePattern::MaskableOpRewritePattern; |
| 481 | |
| 482 | FailureOr<Value> |
| 483 | matchAndRewriteMaskableOp(vector::ContractionOp contractOp, |
| 484 | MaskingOpInterface maskingOp, |
| 485 | PatternRewriter &rewriter) const override { |
| 486 | return castAwayContractionLeadingOneDim(contractOp, maskingOp, rewriter); |
| 487 | } |
| 488 | }; |
| 489 | |
| 490 | /// Looks at elementwise operations on vectors with at least one leading |
| 491 | /// dimension equal 1, e.g. vector<1x[4]x1xf32> (but not vector<2x[4]x1xf32>), |
| 492 | /// and cast aways the leading one dimensions (_plural_) and then broadcasts |
| 493 | /// the results. |
| 494 | /// |
| 495 | /// Example before: |
| 496 | /// %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32> |
| 497 | /// Example after: |
| 498 | /// %2 = arith.mulf %0, %1 : vector<4x1xf32> |
| 499 | /// %3 = vector.broadcast %2 : vector<4x1xf32> to vector<1x4x1xf32> |
| 500 | /// |
| 501 | /// Does support scalable vectors. |
| 502 | class CastAwayElementwiseLeadingOneDim : public RewritePattern { |
| 503 | public: |
| 504 | CastAwayElementwiseLeadingOneDim(MLIRContext *context, |
| 505 | PatternBenefit benefit = 1) |
| 506 | : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {} |
| 507 | |
| 508 | LogicalResult matchAndRewrite(Operation *op, |
| 509 | PatternRewriter &rewriter) const override { |
| 510 | if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) |
| 511 | return failure(); |
| 512 | auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0]); |
| 513 | if (!vecType) |
| 514 | return failure(); |
| 515 | VectorType newVecType = trimLeadingOneDims(vecType); |
| 516 | if (newVecType == vecType) |
| 517 | return failure(); |
| 518 | int64_t dropDim = vecType.getRank() - newVecType.getRank(); |
| 519 | SmallVector<Value, 4> newOperands; |
| 520 | for (Value operand : op->getOperands()) { |
| 521 | if (auto opVecType = dyn_cast<VectorType>(operand.getType())) { |
| 522 | newOperands.push_back(rewriter.create<vector::ExtractOp>( |
| 523 | op->getLoc(), operand, splatZero(dropDim))); |
| 524 | } else { |
| 525 | newOperands.push_back(Elt: operand); |
| 526 | } |
| 527 | } |
| 528 | Operation *newOp = |
| 529 | rewriter.create(op->getLoc(), op->getName().getIdentifier(), |
| 530 | newOperands, newVecType, op->getAttrs()); |
| 531 | rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType, |
| 532 | newOp->getResult(0)); |
| 533 | return success(); |
| 534 | } |
| 535 | }; |
| 536 | |
| 537 | // Drops leading 1 dimensions from vector.constant_mask and inserts a |
| 538 | // vector.broadcast back to the original shape. |
| 539 | struct CastAwayConstantMaskLeadingOneDim |
| 540 | : public OpRewritePattern<vector::ConstantMaskOp> { |
| 541 | using OpRewritePattern::OpRewritePattern; |
| 542 | |
| 543 | LogicalResult matchAndRewrite(vector::ConstantMaskOp mask, |
| 544 | PatternRewriter &rewriter) const override { |
| 545 | VectorType oldType = mask.getType(); |
| 546 | VectorType newType = trimLeadingOneDims(oldType); |
| 547 | |
| 548 | if (newType == oldType) |
| 549 | return failure(); |
| 550 | |
| 551 | int64_t dropDim = oldType.getRank() - newType.getRank(); |
| 552 | ArrayRef<int64_t> dimSizes = mask.getMaskDimSizes(); |
| 553 | |
| 554 | // If any of the dropped unit dims has a size of `0`, the entire mask is a |
| 555 | // zero mask, else the unit dim has no effect on the mask. |
| 556 | int64_t flatLeadingSize = |
| 557 | std::accumulate(first: dimSizes.begin(), last: dimSizes.begin() + dropDim + 1, |
| 558 | init: static_cast<int64_t>(1), binary_op: std::multiplies<int64_t>()); |
| 559 | SmallVector<int64_t> newDimSizes = {flatLeadingSize}; |
| 560 | newDimSizes.append(in_start: dimSizes.begin() + dropDim + 1, in_end: dimSizes.end()); |
| 561 | |
| 562 | auto newMask = rewriter.create<vector::ConstantMaskOp>( |
| 563 | mask.getLoc(), newType, newDimSizes); |
| 564 | rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask); |
| 565 | return success(); |
| 566 | } |
| 567 | }; |
| 568 | |
| 569 | } // namespace |
| 570 | |
| 571 | void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns( |
| 572 | RewritePatternSet &patterns, PatternBenefit benefit) { |
| 573 | patterns |
| 574 | .add<CastAwayExtractStridedSliceLeadingOneDim, |
| 575 | CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim, |
| 576 | CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim, |
| 577 | CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim, |
| 578 | CastAwayContractionLeadingOneDim>(arg: patterns.getContext(), args&: benefit); |
| 579 | } |
| 580 | |