//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;

/// Return the ancestor op in the region or nullptr if the region is not
/// an ancestor of the op.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  for (; op != nullptr && op->getParentRegion() != region;
       op = op->getParentOp())
    ;
  return op;
}

namespace {

class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      rewriter.eraseOp(op);
    opToErase.clear();
  }

private:
  RewriterBase &rewriter;
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

} // namespace
/// Return true if there is a path from start operation to dest operation,
/// otherwise return false. The operations have to be in the same region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  return start->getBlock()->isReachable(dest->getBlock());
}

/// For a transfer_write to fully overwrite another transfer_write, it must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they all post-dominate the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates, as
/// this will be the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we find such an overwriting transfer_write, the original transfer_write
/// is dead if all reads reachable from the potentially dead transfer_write are
/// dominated by the overwriting transfer_write.
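///
/// Illustrative sketch (hypothetical IR, not taken from a test case):
///   vector.transfer_write %v0, %A[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x4xf32>, memref<4x4xf32>  // Dead: fully overwritten below,
///                                           // no aliasing read in between.
///   vector.transfer_write %v1, %A[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x4xf32>, memref<4x4xf32>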
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getBase()));
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed skip.
    if (!processed.insert(user).second)
      continue;
    if (isa<ViewLikeOpInterface>(user)) {
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user))
      continue;
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check candidate that can override the store.
      if (memref::isSameViewOrTrivialAlias(
              cast<MemrefValue>(nextWrite.getBase()),
              cast<MemrefValue>(write.getBase())) &&
          checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      // Don't need to consider disjoint accesses.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(transferOp.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
    }
    blockingAccesses.push_back(user);
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *access : blockingAccesses) {
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    // TODO: if the access and write have the same ancestor we could recurse in
    // the region to know if the access is reachable with more precision.
    if (accessAncestor == nullptr ||
        !isReachable(writeAncestor, accessAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}

/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
///    transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the others
/// since they all dominate the same transfer_read. We only consider the
/// transfer_write dominated by all the other candidates, as this will be the
/// last transfer_write executed before the transfer_read.
/// If we find such a candidate, we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are post-dominated
/// by the transfer_write.
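///
/// Illustrative sketch (hypothetical IR, not taken from a test case):
///   vector.transfer_write %v, %A[%c0, %c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<4x4xf32>
///   %r = vector.transfer_read %A[%c0, %c0], %pad {in_bounds = [true]}
///       : memref<4x4xf32>, vector<4xf32>
/// All uses of %r are replaced with %v and the transfer_read is erased.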
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getBase()));
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed skip.
    if (!processed.insert(user).second)
      continue;
    if (isa<ViewLikeOpInterface>(user)) {
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint we can ignore
      // the write.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
      if (memref::isSameViewOrTrivialAlias(
              cast<MemrefValue>(read.getBase()),
              cast<MemrefValue>(write.getBase())) &&
          dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}

/// Converts OpFoldResults to int64_t shape without unit dims.
static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<int64_t> reducedShape;
  for (const auto size : mixedSizes) {
    if (llvm::dyn_cast_if_present<Value>(size)) {
      reducedShape.push_back(ShapedType::kDynamic);
      continue;
    }

    auto value = cast<IntegerAttr>(cast<Attribute>(size)).getValue();
    if (value == 1)
      continue;
    reducedShape.push_back(value.getSExtValue());
  }
  return reducedShape;
}

/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes,
                               ArrayRef<OpFoldResult> strides) {
  auto targetShape = getReducedShape(sizes);
  MemRefType rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return rankReducedType.canonicalizeStridedLayout();
}

/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input. Or just returns the input if it was already without unit dims.
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = cast<MemRefType>(input.getType());
  SmallVector<OpFoldResult> offsets(inputType.getRank(),
                                    rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
  SmallVector<OpFoldResult> strides(inputType.getRank(),
                                    rewriter.getIndexAttr(1));
  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);

  if (resultType.canonicalizeStridedLayout() ==
      inputType.canonicalizeStridedLayout())
    return input;
  return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
                                            sizes, strides);
}

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Trims non-scalable unit dimensions from `oldType` and returns the result
/// type.
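/// E.g. (illustrative): vector<1x[4]x1xf32> is trimmed to vector<[4]xf32>,
/// while a scalable unit dim such as the [1] in vector<[1]x4xf32> is kept.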
static VectorType trimNonScalableUnitDims(VectorType oldType) {
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
    if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
      continue;
    newShape.push_back(dimSize);
    newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}

// Rewrites vector.create_mask 'op' to drop non-scalable unit dimensions.
static FailureOr<Value>
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
                                  vector::CreateMaskOp op) {
  auto type = op.getType();
  VectorType reducedType = trimNonScalableUnitDims(type);
  if (reducedType.getRank() == type.getRank())
    return failure();

  SmallVector<Value> reducedOperands;
  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
           type.getShape(), type.getScalableDims(), op.getOperands())) {
    if (dim == 1 && !dimIsScalable) {
      // If the mask for the unit dim is not a constant of 1, do nothing.
      auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
      if (!constant || (constant.value() != 1))
        return failure();
      continue;
    }
    reducedOperands.push_back(operand);
  }
  return rewriter
      .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
      .getResult();
}

namespace {

/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims. The vector shapes are
/// also reduced accordingly.
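///
/// Illustrative sketch (hypothetical types, assuming an identity layout):
///   %v = vector.transfer_read %A[%c0, %c0, %c0], %pad
///       {in_bounds = [true, true, true]}
///       : memref<1x8x1xf32>, vector<1x8x1xf32>
/// is rewritten to roughly:
///   %sv = memref.subview %A[0, 0, 0] [1, 8, 1] [1, 1, 1]
///       : memref<1x8x1xf32> to memref<8xf32>
///   %r = vector.transfer_read %sv[%c0], %pad {in_bounds = [true]}
///       : memref<8xf32>, vector<8xf32>
///   %v = vector.shape_cast %r : vector<8xf32> to vector<1x8x1xf32>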
class TransferReadDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the source shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
    // out.
    if (reducedRank == 0 && maskingOp)
      return failure();
    // Check if the reduced vector shape matches the reduced source shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferReadOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
                            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }

    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
        transferReadOp.getPadding(), maskOp,
        rewriter.getBoolArrayAttr(inBounds));

    if (maskingOp) {
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newTransferReadOp = mlir::vector::maskOperation(
          rewriter, newTransferReadOp, shapeCastMask);
    }

    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp->getResults()[0]);

    return shapeCast;
  }
};

/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
/// vector shapes are also reduced accordingly.
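///
/// Illustrative sketch (hypothetical types): the write-side analogue of the
/// pattern above, e.g.
///   vector.transfer_write %v, %A[%c0, %c0, %c0]
///       : vector<1x8x1xf32>, memref<1x8x1xf32>
/// becomes a vector.shape_cast of %v to vector<8xf32> followed by a
/// transfer_write into a rank-reduced memref.subview of %A.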
class TransferWriteDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor type.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the destination shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
    // out.
    if (reducedRank == 0 && maskingOp)
      return failure();
    // Check if the reduced vector shape matches the reduced destination shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferWriteOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferWriteOp,
            "unsupported mask op, only 'vector.create_mask' is "
            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
        loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
        maskOp, rewriter.getBoolArrayAttr(inBounds));

    if (maskingOp) {
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newXferWrite =
          mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
    }

    if (transferWriteOp.hasPureTensorSemantics())
      return newXferWrite->getResults()[0];

    // With Memref semantics, there's no return value. Use empty value to signal
    // success.
    return Value();
  }
};

} // namespace

/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = cast<ShapedType>(input.getType());
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}

/// Returns the new indices that collapses the inner dimensions starting from
/// the `firstDimToCollapse` dimension.
static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
                                              Location loc,
                                              ArrayRef<int64_t> shape,
                                              ValueRange indices,
                                              int64_t firstDimToCollapse) {
  assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));

  // If all the collapsed indices are zero then no extra logic is needed.
  // Otherwise, a new offset/index has to be computed.
  SmallVector<Value> indicesAfterCollapsing(
      indices.begin(), indices.begin() + firstDimToCollapse);
  SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
                                       indices.end());
  if (llvm::all_of(indicesToCollapse, isZeroInteger)) {
    indicesAfterCollapsing.push_back(indicesToCollapse[0]);
    return indicesAfterCollapsing;
  }

  // Compute the remaining trailing index/offset required for reading from
  // the collapsed memref:
  //
  //    offset = 0
  //    for (i = firstDimToCollapse; i < outputRank; ++i)
  //      offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
  //
  // For this example:
  //   %2 = vector.transfer_read/write %arg4[%c0, %arg0, %c0] (...) :
  //     memref<1x43x2xi32>, vector<1x2xi32>
  // which would be collapsed to:
  //   %1 = vector.transfer_read/write %collapse_shape[%c0, %offset] (...) :
  //     memref<1x86xi32>, vector<2xi32>
  // one would get the following offset:
  //    %offset = %arg0 * 43
  OpFoldResult collapsedOffset =
      rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();

  auto collapsedStrides = computeSuffixProduct(
      ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));

  // Compute the collapsed offset.
  auto &&[collapsedExpr, collapsedVals] =
      computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
  collapsedOffset = affine::makeComposedFoldedAffineApply(
      rewriter, loc, collapsedExpr, collapsedVals);

  if (auto value = dyn_cast<Value>(collapsedOffset)) {
    indicesAfterCollapsing.push_back(value);
  } else {
    indicesAfterCollapsing.push_back(rewriter.create<arith::ConstantIndexOp>(
        loc, *getConstantIntValue(collapsedOffset)));
  }

  return indicesAfterCollapsing;
}

namespace {

/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
///
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the trailing dimension of the vector read is smaller than the provided
/// bitwidth.
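///
/// Illustrative sketch (hypothetical types; assumes the bitwidth and
/// contiguity preconditions hold):
///   %v = vector.transfer_read %A[%c0, %c0], %pad {in_bounds = [true, true]}
///       : memref<4x8xf32>, vector<2x8xf32>
/// is rewritten to roughly:
///   %cs = memref.collapse_shape %A [[0, 1]]
///       : memref<4x8xf32> into memref<32xf32>
///   %r = vector.transfer_read %cs[%c0], %pad {in_bounds = [true]}
///       : memref<32xf32>, vector<16xf32>
///   %v = vector.shape_cast %r : vector<16xf32> to vector<2x8xf32>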
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
                                               PatternBenefit benefit)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    auto source = transferReadOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions
    // The contiguity check below only applies to memrefs, so bail out on
    // tensor sources.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();

    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

    // 1. Collapse the source memref
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for a new vector.transfer_read that will read
    // from the collapsed memref.
    // 2.1. New dim exprs + affine map
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2 New indices
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferReadOp.getIndices(), firstDimToCollapse);

    // 3. Create new vector.transfer_read that reads from the collapsed memref
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Replace the old transfer_read with the new one reading from the
    // collapsed shape
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};

/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
///
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the trailing dimension of the vector write is smaller than the provided
/// bitwidth.
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
public:
  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
                                                unsigned vectorBitwidth,
                                                PatternBenefit benefit)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions
    // The contiguity check below only applies to memrefs, so bail out on
    // tensor destinations.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();

    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

    // 1. Collapse the source memref
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for a new vector.transfer_write that will write
    // to the collapsed memref.
    // 2.1. New dim exprs + affine map
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2 New indices
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferWriteOp.getIndices(), firstDimToCollapse);

    // 3. Create new vector.transfer_write that writes to the collapsed memref
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Replace the old transfer_write with the new one writing the
    // collapsed shape
    rewriter.eraseOp(transferWriteOp);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};

/// Base class for `vector.extract/vector.extractelement(vector.transfer_read)`
/// to `memref.load` patterns. The `match` method is shared for both
/// `vector.extract` and `vector.extractelement`.
template <class VectorExtractOp>
class RewriteScalarExtractOfTransferReadBase
    : public OpRewritePattern<VectorExtractOp> {
  using Base = OpRewritePattern<VectorExtractOp>;

public:
  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
                                         PatternBenefit benefit,
                                         bool allowMultipleUses)
      : Base(context, benefit), allowMultipleUses(allowMultipleUses) {}

  LogicalResult match(VectorExtractOp extractOp) const {
    auto xferOp =
        extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Check that we are extracting a scalar and not a sub-vector.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();
    // If multiple uses are not allowed, check if xfer has a single use.
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    // If multiple uses are allowed, check if all the xfer uses are extract ops.
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp, vector::ExtractElementOp>(
              use.getOwner());
        }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Cannot rewrite if the indices may be out of bounds.
    if (xferOp.hasOutOfBoundsDim())
      return failure();
    return success();
  }

private:
  bool allowMultipleUses;
};

/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
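///
/// Illustrative sketch (hypothetical IR):
///   %v = vector.transfer_read %A[%i], %pad : memref<?xf32>, vector<4xf32>
///   %e = vector.extractelement %v[%j : index] : vector<4xf32>
/// becomes, roughly:
///   %idx = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%i, %j]
///   %e = memref.load %A[%idx] : memref<?xf32>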
class RewriteScalarExtractElementOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp,
                                PatternRewriter &rewriter) const override {
    if (failed(match(extractOp)))
      return failure();

    // Construct scalar load.
    auto loc = extractOp.getLoc();
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    if (extractOp.getPosition()) {
      AffineExpr sym0, sym1;
      bindSymbols(extractOp.getContext(), sym0, sym1);
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, sym0 + sym1,
          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
      if (auto value = dyn_cast<Value>(ofr)) {
        newIndices[newIndices.size() - 1] = value;
      } else {
        newIndices[newIndices.size() - 1] =
            rewriter.create<arith::ConstantIndexOp>(loc,
                                                    *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getBase().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getBase(), newIndices);
    }

    return success();
  }
};

/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
class RewriteScalarExtractOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    if (failed(match(extractOp)))
      return failure();

    // Construct scalar load.
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      assert(isa<Attribute>(pos) && "Unexpected non-constant index");
      int64_t offset = cast<IntegerAttr>(cast<Attribute>(pos)).getInt();
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(),
          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      if (auto value = dyn_cast<Value>(ofr)) {
        newIndices[idx] = value;
      } else {
        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
            extractOp.getLoc(), *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getBase().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getBase(), newIndices);
    }

    return success();
  }
};

/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
/// to memref.store.
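///
/// Illustrative sketch (hypothetical IR):
///   vector.transfer_write %v, %A[%i, %j] : vector<1x1xf32>, memref<4x4xf32>
/// becomes, roughly:
///   %s = vector.extract %v[0, 0] : f32 from vector<1x1xf32>
///   memref.store %s, %A[%i, %j] : memref<4x4xf32>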
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Must be a scalar write.
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Only float and integer element types are supported.
    Value scalar =
        rewriter.create<vector::ExtractOp>(xferOp.getLoc(), xferOp.getVector());
    // Construct a scalar store.
    if (isa<MemRefType>(xferOp.getBase().getType())) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
    }
    return success();
  }
};

} // namespace

void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  TransferOptimization opt(rewriter, rootOp);
  // Run store to load forwarding first since it can expose more dead store
  // opportunity.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (isa<MemRefType>(read.getShapedType()))
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (isa<MemRefType>(write.getShapedType()))
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}

void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit,
    bool allowMultipleUses) {
  patterns.add<RewriteScalarExtractElementOfTransferRead,
               RewriteScalarExtractOfTransferRead>(patterns.getContext(),
                                                   benefit, allowMultipleUses);
  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
    PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), targetVectorBitwidth, benefit);
  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
}