| 1 | //===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements functions concerned with optimizing transfer_read and |
| 10 | // transfer_write ops. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 16 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 17 | #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" |
| 18 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 19 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
| 20 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 21 | #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" |
| 22 | #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" |
| 23 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h" |
| 24 | #include "mlir/IR/Dominance.h" |
| 25 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
| 26 | #include "llvm/ADT/STLExtras.h" |
| 27 | #include "llvm/ADT/StringRef.h" |
| 28 | #include "llvm/Support/Debug.h" |
| 29 | |
| 30 | #define DEBUG_TYPE "vector-transfer-opt" |
| 31 | |
| 32 | #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") |
| 33 | |
| 34 | using namespace mlir; |
| 35 | |
| 36 | /// Return the ancestor op in the region or nullptr if the region is not |
| 37 | /// an ancestor of the op. |
| 38 | static Operation *findAncestorOpInRegion(Region *region, Operation *op) { |
| 39 | for (; op != nullptr && op->getParentRegion() != region; |
| 40 | op = op->getParentOp()) |
| 41 | ; |
| 42 | return op; |
| 43 | } |
| 44 | |
namespace {

/// Drives two memory optimizations over vector transfer ops rooted at `op`:
/// dead-store elimination (deadStoreOp) and store-to-load forwarding
/// (storeToLoadForwarding). Candidates are queued and erased in bulk by
/// removeDeadOp().
class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  /// Queue `write` for erasure if it is proven fully overwritten before any
  /// aliasing read can observe it.
  void deadStoreOp(vector::TransferWriteOp);
  /// Replace `read`'s uses with the vector of a dominating matching write,
  /// and queue the read for erasure.
  void storeToLoadForwarding(vector::TransferReadOp);
  /// Erase all ops queued by the analyses above.
  void removeDeadOp() {
    for (Operation *op : opToErase)
      rewriter.eraseOp(op);
    opToErase.clear();
  }

private:
  RewriterBase &rewriter;
  /// Return true if there is a control-flow path from `start` to `dest`
  /// (both must be in the same region).
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  // Ops scheduled for deletion; erased only in removeDeadOp() so the
  // analyses never iterate over freed IR.
  std::vector<Operation *> opToErase;
};

} // namespace
| 68 | /// Return true if there is a path from start operation to dest operation, |
| 69 | /// otherwise return false. The operations have to be in the same region. |
| 70 | bool TransferOptimization::isReachable(Operation *start, Operation *dest) { |
| 71 | assert(start->getParentRegion() == dest->getParentRegion() && |
| 72 | "This function only works for ops i the same region" ); |
| 73 | // Simple case where the start op dominate the destination. |
| 74 | if (dominators.dominates(a: start, b: dest)) |
| 75 | return true; |
| 76 | return start->getBlock()->isReachable(other: dest->getBlock()); |
| 77 | } |
| 78 | |
/// For transfer_write to overwrite fully another transfer_write must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they are all post-dominating the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates as
/// this will be the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we found such an overwriting transfer_write we know that the original
/// transfer_write is dead if all reads that can be reached from the potentially
/// dead transfer_write are dominated by the overwriting transfer_write.
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  // Walk through view-like ops to the underlying allocation so that accesses
  // through different views of the same memory are all considered.
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getBase()));
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed skip.
    if (!processed.insert(user).second)
      continue;
    if (isa<ViewLikeOpInterface>(user)) {
      // Views do not access memory themselves; chase their users instead.
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
    // Ops without memory effects can neither overwrite nor observe the store.
    if (isMemoryEffectFree(user))
      continue;
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check candidate that can override the store.
      if (memref::isSameViewOrTrivialAlias(
              cast<MemrefValue>(nextWrite.getBase()),
              cast<MemrefValue>(write.getBase())) &&
          checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        // Keep the candidate that executes first, i.e. the one that is
        // post-dominated by all the other candidates.
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      // Don't need to consider disjoint accesses.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(transferOp.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
    }
    blockingAccesses.push_back(user);
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  // The store is dead only if every blocking access reachable from it is
  // dominated by the overwriting store (so it sees the new value, not ours).
  for (Operation *access : blockingAccesses) {
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    // TODO: if the access and write have the same ancestor we could recurse in
    // the region to know if the access is reachable with more precision.
    if (accessAncestor == nullptr ||
        !isReachable(writeAncestor, accessAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}
| 162 | |
/// A transfer_write candidate to storeToLoad forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
/// transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the others
/// since they are all dominating the same transfer_read. We only consider the
/// transfer_write dominated by all the other candidates as this will be the
/// last transfer_write executed before the transfer_read.
/// If we found such a candidate we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are post-dominated
/// by the transfer_write.
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  // Out-of-bounds reads would also observe the padding value, which a
  // forwarded vector cannot reproduce.
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  // Walk through view-like ops to the underlying allocation so that accesses
  // through different views of the same memory are all considered.
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getBase()));
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed skip.
    if (!processed.insert(user).second)
      continue;
    if (isa<ViewLikeOpInterface>(user)) {
      // Views do not access memory themselves; chase their users instead.
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
    // Other reads and effect-free ops cannot clobber the forwarded value.
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint we can ignore
      // the write.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
      if (memref::isSameViewOrTrivialAlias(
              cast<MemrefValue>(read.getBase()),
              cast<MemrefValue>(write.getBase())) &&
          dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        // Keep the write that executes last before the read, i.e. the one
        // dominated by all other candidates.
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  // Forwarding is safe only if no blocking write can execute between the
  // candidate write and the read; post-dominance by the candidate rules that
  // out for each reachable blocker.
  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}
| 244 | |
| 245 | /// Converts OpFoldResults to int64_t shape without unit dims. |
| 246 | static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) { |
| 247 | SmallVector<int64_t> reducedShape; |
| 248 | for (const auto size : mixedSizes) { |
| 249 | if (llvm::dyn_cast_if_present<Value>(Val: size)) { |
| 250 | reducedShape.push_back(Elt: ShapedType::kDynamic); |
| 251 | continue; |
| 252 | } |
| 253 | |
| 254 | auto value = cast<IntegerAttr>(Val: cast<Attribute>(Val: size)).getValue(); |
| 255 | if (value == 1) |
| 256 | continue; |
| 257 | reducedShape.push_back(Elt: value.getSExtValue()); |
| 258 | } |
| 259 | return reducedShape; |
| 260 | } |
| 261 | |
| 262 | /// Drops unit dimensions from the input MemRefType. |
| 263 | static MemRefType dropUnitDims(MemRefType inputType, |
| 264 | ArrayRef<OpFoldResult> offsets, |
| 265 | ArrayRef<OpFoldResult> sizes, |
| 266 | ArrayRef<OpFoldResult> strides) { |
| 267 | auto targetShape = getReducedShape(mixedSizes: sizes); |
| 268 | MemRefType rankReducedType = memref::SubViewOp::inferRankReducedResultType( |
| 269 | resultShape: targetShape, sourceMemRefType: inputType, staticOffsets: offsets, staticSizes: sizes, staticStrides: strides); |
| 270 | return rankReducedType.canonicalizeStridedLayout(); |
| 271 | } |
| 272 | |
| 273 | /// Creates a rank-reducing memref.subview op that drops unit dims from its |
| 274 | /// input. Or just returns the input if it was already without unit dims. |
| 275 | static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter, |
| 276 | mlir::Location loc, |
| 277 | Value input) { |
| 278 | MemRefType inputType = cast<MemRefType>(Val: input.getType()); |
| 279 | SmallVector<OpFoldResult> offsets(inputType.getRank(), |
| 280 | rewriter.getIndexAttr(value: 0)); |
| 281 | SmallVector<OpFoldResult> sizes = memref::getMixedSizes(builder&: rewriter, loc, value: input); |
| 282 | SmallVector<OpFoldResult> strides(inputType.getRank(), |
| 283 | rewriter.getIndexAttr(value: 1)); |
| 284 | MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides); |
| 285 | |
| 286 | if (resultType.canonicalizeStridedLayout() == |
| 287 | inputType.canonicalizeStridedLayout()) |
| 288 | return input; |
| 289 | return rewriter.create<memref::SubViewOp>(location: loc, args&: resultType, args&: input, args&: offsets, |
| 290 | args&: sizes, args&: strides); |
| 291 | } |
| 292 | |
| 293 | /// Returns the number of dims that aren't unit dims. |
| 294 | static int getReducedRank(ArrayRef<int64_t> shape) { |
| 295 | return llvm::count_if(Range&: shape, P: [](int64_t dimSize) { return dimSize != 1; }); |
| 296 | } |
| 297 | |
| 298 | /// Trims non-scalable one dimensions from `oldType` and returns the result |
| 299 | /// type. |
| 300 | static VectorType trimNonScalableUnitDims(VectorType oldType) { |
| 301 | SmallVector<int64_t> newShape; |
| 302 | SmallVector<bool> newScalableDims; |
| 303 | for (auto [dimIdx, dimSize] : llvm::enumerate(First: oldType.getShape())) { |
| 304 | if (dimSize == 1 && !oldType.getScalableDims()[dimIdx]) |
| 305 | continue; |
| 306 | newShape.push_back(Elt: dimSize); |
| 307 | newScalableDims.push_back(Elt: oldType.getScalableDims()[dimIdx]); |
| 308 | } |
| 309 | return VectorType::get(shape: newShape, elementType: oldType.getElementType(), scalableDims: newScalableDims); |
| 310 | } |
| 311 | |
// Rewrites vector.create_mask 'op' to drop non-scalable one dimensions.
// Fails when there is no unit dim to drop, or when a dropped unit dim's mask
// operand is not the constant 1 (dropping it would change mask semantics).
static FailureOr<Value>
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
                                  vector::CreateMaskOp op) {
  auto type = op.getType();
  VectorType reducedType = trimNonScalableUnitDims(type);
  // Same rank means there was no non-scalable unit dim to remove.
  if (reducedType.getRank() == type.getRank())
    return failure();

  SmallVector<Value> reducedOperands;
  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
           type.getShape(), type.getScalableDims(), op.getOperands())) {
    if (dim == 1 && !dimIsScalable) {
      // If the mask for the unit dim is not a constant of 1, do nothing.
      auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
      if (!constant || (constant.value() != 1))
        return failure();
      continue;
    }
    reducedOperands.push_back(operand);
  }
  return rewriter
      .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
      .getResult();
}
| 337 | |
| 338 | namespace { |
| 339 | |
/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims. The vector shapes are
/// also reduced accordingly.
class TransferReadDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the source shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
    // out.
    if (reducedRank == 0 && maskingOp)
      return failure();
    // Check if the reduced vector shape matches the reduced source shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    // All indices must be statically zero: the rewrite reads from offset 0 of
    // the rank-reduced subview.
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferReadOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
                            "currently supported");
      // The explicit mask must be rank-reduced in lockstep with the vector.
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }

    // Build the rank-reduced read: subview source, zero indices, identity
    // permutation map, all dims in-bounds.
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
        transferReadOp.getPadding(), maskOp,
        rewriter.getBoolArrayAttr(inBounds));

    if (maskingOp) {
      // Shape-cast the enclosing vector.mask's mask to the reduced shape and
      // re-mask the new read.
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newTransferReadOp = mlir::vector::maskOperation(
          rewriter, newTransferReadOp, shapeCastMask);
    }

    // Restore the original vector type for users of the old read.
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp->getResults()[0]);

    return shapeCast;
  }
};
| 421 | |
/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
/// vector shapes are also reduced accordingly.
class TransferWriteDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor type.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the destination shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
    // out.
    if (reducedRank == 0 && maskingOp)
      return failure();
    // Check if the reduced vector shape matches the reduced destination shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    // All indices must be statically zero: the rewrite writes to offset 0 of
    // the rank-reduced subview.
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferWriteOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferWriteOp,
            "unsupported mask op, only 'vector.create_mask' is "
            "currently supported");
      // The explicit mask must be rank-reduced in lockstep with the vector.
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }
    // Build the rank-reduced write: shape-cast the vector down, write into
    // the subview at offset 0 with an identity map, all dims in-bounds.
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
        loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
        maskOp, rewriter.getBoolArrayAttr(inBounds));

    if (maskingOp) {
      // Shape-cast the enclosing vector.mask's mask to the reduced shape and
      // re-mask the new write.
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newXferWrite =
          mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
    }

    if (transferWriteOp.hasPureTensorSemantics())
      return newXferWrite->getResults()[0];

    // With Memref semantics, there's no return value. Use empty value to signal
    // success.
    return Value();
  }
};
| 506 | |
| 507 | } // namespace |
| 508 | |
| 509 | /// Creates a memref.collapse_shape collapsing all inner dimensions of the |
| 510 | /// input starting at `firstDimToCollapse`. |
| 511 | static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc, |
| 512 | Value input, int64_t firstDimToCollapse) { |
| 513 | ShapedType inputType = cast<ShapedType>(Val: input.getType()); |
| 514 | if (inputType.getRank() == 1) |
| 515 | return input; |
| 516 | SmallVector<ReassociationIndices> reassociation; |
| 517 | for (int64_t i = 0; i < firstDimToCollapse; ++i) |
| 518 | reassociation.push_back(Elt: ReassociationIndices{i}); |
| 519 | ReassociationIndices collapsedIndices; |
| 520 | for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i) |
| 521 | collapsedIndices.push_back(Elt: i); |
| 522 | reassociation.push_back(Elt: collapsedIndices); |
| 523 | return rewriter.create<memref::CollapseShapeOp>(location: loc, args&: input, args&: reassociation); |
| 524 | } |
| 525 | |
/// Returns the new indices that collapses the inner dimensions starting from
/// the `firstDimToCollapse` dimension.
static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
                                              Location loc,
                                              ArrayRef<int64_t> shape,
                                              ValueRange indices,
                                              int64_t firstDimToCollapse) {
  assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));

  // If all the collapsed indices are zero then no extra logic is needed.
  // Otherwise, a new offset/index has to be computed.
  SmallVector<Value> indicesAfterCollapsing(
      indices.begin(), indices.begin() + firstDimToCollapse);
  SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
                                       indices.end());
  if (llvm::all_of(indicesToCollapse, isZeroInteger)) {
    // Reuse one of the existing zero values as the collapsed-dim index.
    indicesAfterCollapsing.push_back(indicesToCollapse[0]);
    return indicesAfterCollapsing;
  }

  // Compute the remaining trailing index/offset required for reading from
  // the collapsed memref:
  //
  //    offset = 0
  //    for (i = firstDimToCollapse; i < outputRank; ++i)
  //      offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
  //
  // For this example:
  //   %2 = vector.transfer_read/write %arg4[%c0, %arg0, %c0] (...) :
  //      memref<1x43x2xi32>, vector<1x2xi32>
  // which would be collapsed to:
  //   %1 = vector.transfer_read/write %collapse_shape[%c0, %offset] (...) :
  //      memref<1x86xi32>, vector<2xi32>
  // one would get the following offset:
  //    %offset = %arg0 * 43
  OpFoldResult collapsedOffset =
      rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();

  // Row-major strides of the collapsed dims (suffix products of their sizes).
  auto collapsedStrides = computeSuffixProduct(
      ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));

  // Compute the collapsed offset.
  auto &&[collapsedExpr, collapsedVals] =
      computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
  collapsedOffset = affine::makeComposedFoldedAffineApply(
      rewriter, loc, collapsedExpr, collapsedVals);

  // The folded offset may be an SSA value or a constant attribute; in the
  // latter case materialize it as a constant index op.
  if (auto value = dyn_cast<Value>(collapsedOffset)) {
    indicesAfterCollapsing.push_back(value);
  } else {
    indicesAfterCollapsing.push_back(rewriter.create<arith::ConstantIndexOp>(
        loc, *getConstantIntValue(collapsedOffset)));
  }

  return indicesAfterCollapsing;
}
| 582 | |
| 583 | namespace { |
/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
///
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the trailing dimension of the vector read is smaller than the provided
/// bitwidth.
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
                                               PatternBenefit benefit)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    auto source = transferReadOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions
    // The contiguity check below requires a memref source; bail on tensors.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    // Only flatten when the trailing vector dim is narrower than the target
    // bitwidth; wider reads would not benefit from flattening.
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();

    // Determine the first memref dimension to collapse - just enough so we can
    // read a flattened vector.
    int64_t firstDimToCollapse =
        sourceType.getRank() -
        vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();

    // 1. Collapse the source memref
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for a new vector.transfer_read that will read
    // from the collapsed memref.
    // 2.1. New dim exprs + affine map
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2 New indices
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferReadOp.getIndices(), firstDimToCollapse);

    // 3. Create new vector.transfer_read that reads from the collapsed memref
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices,
        transferReadOp.getPadding(), collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Replace the old transfer_read with the new one reading from the
    // collapsed shape
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};
| 679 | |
| 680 | /// Rewrites contiguous row-major vector.transfer_write ops by inserting |
| 681 | /// memref.collapse_shape on the source so that the resulting |
| 682 | /// vector.transfer_write has a 1D source. Requires the source shape to be |
| 683 | /// already reduced i.e. without unit dims. |
| 684 | /// |
| 685 | /// If `targetVectorBitwidth` is provided, the flattening will only happen if |
| 686 | /// the trailing dimension of the vector read is smaller than the provided |
| 687 | /// bitwidth. |
| 688 | class FlattenContiguousRowMajorTransferWritePattern |
| 689 | : public OpRewritePattern<vector::TransferWriteOp> { |
| 690 | public: |
| 691 | FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context, |
| 692 | unsigned vectorBitwidth, |
| 693 | PatternBenefit benefit) |
| 694 | : OpRewritePattern<vector::TransferWriteOp>(context, benefit), |
| 695 | targetVectorBitwidth(vectorBitwidth) {} |
| 696 | |
| 697 | LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp, |
| 698 | PatternRewriter &rewriter) const override { |
| 699 | auto loc = transferWriteOp.getLoc(); |
| 700 | Value vector = transferWriteOp.getVector(); |
| 701 | VectorType vectorType = cast<VectorType>(Val: vector.getType()); |
| 702 | Value source = transferWriteOp.getBase(); |
| 703 | MemRefType sourceType = dyn_cast<MemRefType>(Val: source.getType()); |
| 704 | |
| 705 | // 0. Check pre-conditions |
| 706 | // Contiguity check is valid on tensors only. |
| 707 | if (!sourceType) |
| 708 | return failure(); |
| 709 | // If this is already 0D/1D, there's nothing to do. |
| 710 | if (vectorType.getRank() <= 1) |
| 711 | // Already 0D/1D, nothing to do. |
| 712 | return failure(); |
| 713 | if (!vectorType.getElementType().isSignlessIntOrFloat()) |
| 714 | return failure(); |
| 715 | unsigned trailingVectorDimBitwidth = |
| 716 | vectorType.getShape().back() * vectorType.getElementTypeBitWidth(); |
| 717 | if (trailingVectorDimBitwidth >= targetVectorBitwidth) |
| 718 | return failure(); |
| 719 | if (!vector::isContiguousSlice(memrefType: sourceType, vectorType)) |
| 720 | return failure(); |
| 721 | // TODO: generalize this pattern, relax the requirements here. |
| 722 | if (transferWriteOp.hasOutOfBoundsDim()) |
| 723 | return failure(); |
| 724 | if (!transferWriteOp.getPermutationMap().isMinorIdentity()) |
| 725 | return failure(); |
| 726 | if (transferWriteOp.getMask()) |
| 727 | return failure(); |
| 728 | |
| 729 | // Determine the first memref dimension to collapse - just enough so we can |
| 730 | // read a flattened vector. |
| 731 | int64_t firstDimToCollapse = |
| 732 | sourceType.getRank() - |
| 733 | vectorType.getShape().drop_while(Pred: [](auto v) { return v == 1; }).size(); |
| 734 | |
| 735 | // 1. Collapse the source memref |
| 736 | Value collapsedSource = |
| 737 | collapseInnerDims(rewriter, loc, input: source, firstDimToCollapse); |
| 738 | MemRefType collapsedSourceType = |
| 739 | cast<MemRefType>(Val: collapsedSource.getType()); |
| 740 | int64_t collapsedRank = collapsedSourceType.getRank(); |
| 741 | assert(collapsedRank == firstDimToCollapse + 1); |
| 742 | |
| 743 | // 2. Generate input args for a new vector.transfer_read that will read |
| 744 | // from the collapsed memref. |
| 745 | // 2.1. New dim exprs + affine map |
| 746 | SmallVector<AffineExpr, 1> dimExprs{ |
| 747 | getAffineDimExpr(position: firstDimToCollapse, context: rewriter.getContext())}; |
| 748 | auto collapsedMap = |
| 749 | AffineMap::get(dimCount: collapsedRank, symbolCount: 0, results: dimExprs, context: rewriter.getContext()); |
| 750 | |
| 751 | // 2.2 New indices |
| 752 | SmallVector<Value> collapsedIndices = |
| 753 | getCollapsedIndices(rewriter, loc, shape: sourceType.getShape(), |
| 754 | indices: transferWriteOp.getIndices(), firstDimToCollapse); |
| 755 | |
| 756 | // 3. Create new vector.transfer_write that writes to the collapsed memref |
| 757 | VectorType flatVectorType = VectorType::get(shape: {vectorType.getNumElements()}, |
| 758 | elementType: vectorType.getElementType()); |
| 759 | Value flatVector = |
| 760 | rewriter.create<vector::ShapeCastOp>(location: loc, args&: flatVectorType, args&: vector); |
| 761 | vector::TransferWriteOp flatWrite = |
| 762 | rewriter.create<vector::TransferWriteOp>( |
| 763 | location: loc, args&: flatVector, args&: collapsedSource, args&: collapsedIndices, args&: collapsedMap); |
| 764 | flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr(values: {true})); |
| 765 | |
| 766 | // 4. Replace the old transfer_write with the new one writing the |
| 767 | // collapsed shape |
| 768 | rewriter.eraseOp(op: transferWriteOp); |
| 769 | return success(); |
| 770 | } |
| 771 | |
| 772 | private: |
| 773 | // Minimum bitwidth that the trailing vector dimension should have after |
| 774 | // flattening. |
| 775 | unsigned targetVectorBitwidth; |
| 776 | }; |
| 777 | |
| 778 | /// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`. |
| 779 | /// |
| 780 | /// All the users of the transfer op must be `vector.extract` ops. If |
| 781 | /// `allowMultipleUses` is set to true, rewrite transfer ops with any number of |
| 782 | /// users. Otherwise, rewrite only if the extract op is the single user of the |
| 783 | /// transfer op. Rewriting a single vector load with multiple scalar loads may |
| 784 | /// negatively affect performance. |
| 785 | class |
| 786 | : public OpRewritePattern<vector::ExtractOp> { |
| 787 | public: |
| 788 | (MLIRContext *context, |
| 789 | PatternBenefit benefit, |
| 790 | bool allowMultipleUses) |
| 791 | : OpRewritePattern(context, benefit), |
| 792 | allowMultipleUses(allowMultipleUses) {} |
| 793 | |
| 794 | LogicalResult matchAndRewrite(vector::ExtractOp , |
| 795 | PatternRewriter &rewriter) const override { |
| 796 | // Match phase. |
| 797 | auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>(); |
| 798 | if (!xferOp) |
| 799 | return failure(); |
| 800 | // Check that we are extracting a scalar and not a sub-vector. |
| 801 | if (isa<VectorType>(Val: extractOp.getResult().getType())) |
| 802 | return failure(); |
| 803 | // If multiple uses are not allowed, check if xfer has a single use. |
| 804 | if (!allowMultipleUses && !xferOp.getResult().hasOneUse()) |
| 805 | return failure(); |
| 806 | // If multiple uses are allowed, check if all the xfer uses are extract ops. |
| 807 | if (allowMultipleUses && |
| 808 | !llvm::all_of(Range: xferOp->getUses(), P: [](OpOperand &use) { |
| 809 | return isa<vector::ExtractOp>(Val: use.getOwner()); |
| 810 | })) |
| 811 | return failure(); |
| 812 | // Mask not supported. |
| 813 | if (xferOp.getMask()) |
| 814 | return failure(); |
| 815 | // Map not supported. |
| 816 | if (!xferOp.getPermutationMap().isMinorIdentity()) |
| 817 | return failure(); |
| 818 | // Cannot rewrite if the indices may be out of bounds. |
| 819 | if (xferOp.hasOutOfBoundsDim()) |
| 820 | return failure(); |
| 821 | |
| 822 | // Rewrite phase: construct scalar load. |
| 823 | SmallVector<Value> newIndices(xferOp.getIndices().begin(), |
| 824 | xferOp.getIndices().end()); |
| 825 | for (auto [i, pos] : llvm::enumerate(First: extractOp.getMixedPosition())) { |
| 826 | int64_t idx = newIndices.size() - extractOp.getNumIndices() + i; |
| 827 | |
| 828 | // Compute affine expression `newIndices[idx] + pos` where `pos` can be |
| 829 | // either a constant or a value. |
| 830 | OpFoldResult composedIdx; |
| 831 | if (auto attr = dyn_cast<Attribute>(Val&: pos)) { |
| 832 | int64_t offset = cast<IntegerAttr>(Val&: attr).getInt(); |
| 833 | composedIdx = affine::makeComposedFoldedAffineApply( |
| 834 | b&: rewriter, loc: extractOp.getLoc(), |
| 835 | expr: rewriter.getAffineSymbolExpr(position: 0) + offset, operands: {newIndices[idx]}); |
| 836 | } else { |
| 837 | Value dynamicOffset = cast<Value>(Val&: pos); |
| 838 | AffineExpr sym0, sym1; |
| 839 | bindSymbols(ctx: rewriter.getContext(), exprs&: sym0, exprs&: sym1); |
| 840 | composedIdx = affine::makeComposedFoldedAffineApply( |
| 841 | b&: rewriter, loc: extractOp.getLoc(), expr: sym0 + sym1, |
| 842 | operands: {newIndices[idx], dynamicOffset}); |
| 843 | } |
| 844 | |
| 845 | // Update the corresponding index with the folded result. |
| 846 | if (auto value = dyn_cast<Value>(Val&: composedIdx)) { |
| 847 | newIndices[idx] = value; |
| 848 | } else { |
| 849 | newIndices[idx] = rewriter.create<arith::ConstantIndexOp>( |
| 850 | location: extractOp.getLoc(), args: *getConstantIntValue(ofr: composedIdx)); |
| 851 | } |
| 852 | } |
| 853 | if (isa<MemRefType>(Val: xferOp.getBase().getType())) { |
| 854 | rewriter.replaceOpWithNewOp<memref::LoadOp>(op: extractOp, args: xferOp.getBase(), |
| 855 | args&: newIndices); |
| 856 | } else { |
| 857 | rewriter.replaceOpWithNewOp<tensor::ExtractOp>( |
| 858 | op: extractOp, args: xferOp.getBase(), args&: newIndices); |
| 859 | } |
| 860 | |
| 861 | return success(); |
| 862 | } |
| 863 | |
| 864 | private: |
| 865 | bool ; |
| 866 | }; |
| 867 | |
| 868 | /// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>) |
| 869 | /// to memref.store. |
| 870 | class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> { |
| 871 | using OpRewritePattern::OpRewritePattern; |
| 872 | |
| 873 | LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, |
| 874 | PatternRewriter &rewriter) const override { |
| 875 | // Must be a scalar write. |
| 876 | auto vecType = xferOp.getVectorType(); |
| 877 | if (!llvm::all_of(Range: vecType.getShape(), P: [](int64_t sz) { return sz == 1; })) |
| 878 | return failure(); |
| 879 | // Mask not supported. |
| 880 | if (xferOp.getMask()) |
| 881 | return failure(); |
| 882 | // Map not supported. |
| 883 | if (!xferOp.getPermutationMap().isMinorIdentity()) |
| 884 | return failure(); |
| 885 | // Only float and integer element types are supported. |
| 886 | Value scalar = |
| 887 | rewriter.create<vector::ExtractOp>(location: xferOp.getLoc(), args: xferOp.getVector()); |
| 888 | // Construct a scalar store. |
| 889 | if (isa<MemRefType>(Val: xferOp.getBase().getType())) { |
| 890 | rewriter.replaceOpWithNewOp<memref::StoreOp>( |
| 891 | op: xferOp, args&: scalar, args: xferOp.getBase(), args: xferOp.getIndices()); |
| 892 | } else { |
| 893 | rewriter.replaceOpWithNewOp<tensor::InsertOp>( |
| 894 | op: xferOp, args&: scalar, args: xferOp.getBase(), args: xferOp.getIndices()); |
| 895 | } |
| 896 | return success(); |
| 897 | } |
| 898 | }; |
| 899 | |
| 900 | } // namespace |
| 901 | |
| 902 | void mlir::vector::transferOpflowOpt(RewriterBase &rewriter, |
| 903 | Operation *rootOp) { |
| 904 | TransferOptimization opt(rewriter, rootOp); |
| 905 | // Run store to load forwarding first since it can expose more dead store |
| 906 | // opportunity. |
| 907 | rootOp->walk(callback: [&](vector::TransferReadOp read) { |
| 908 | if (isa<MemRefType>(Val: read.getShapedType())) |
| 909 | opt.storeToLoadForwarding(read); |
| 910 | }); |
| 911 | opt.removeDeadOp(); |
| 912 | rootOp->walk(callback: [&](vector::TransferWriteOp write) { |
| 913 | if (isa<MemRefType>(Val: write.getShapedType())) |
| 914 | opt.deadStoreOp(write); |
| 915 | }); |
| 916 | opt.removeDeadOp(); |
| 917 | } |
| 918 | |
| 919 | void mlir::vector::populateScalarVectorTransferLoweringPatterns( |
| 920 | RewritePatternSet &patterns, PatternBenefit benefit, |
| 921 | bool allowMultipleUses) { |
| 922 | patterns.add<RewriteScalarExtractOfTransferRead>(arg: patterns.getContext(), |
| 923 | args&: benefit, args&: allowMultipleUses); |
| 924 | patterns.add<RewriteScalarWrite>(arg: patterns.getContext(), args&: benefit); |
| 925 | } |
| 926 | |
| 927 | void mlir::vector::populateVectorTransferDropUnitDimsPatterns( |
| 928 | RewritePatternSet &patterns, PatternBenefit benefit) { |
| 929 | patterns |
| 930 | .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>( |
| 931 | arg: patterns.getContext(), args&: benefit); |
| 932 | } |
| 933 | |
| 934 | void mlir::vector::populateFlattenVectorTransferPatterns( |
| 935 | RewritePatternSet &patterns, unsigned targetVectorBitwidth, |
| 936 | PatternBenefit benefit) { |
| 937 | patterns.add<FlattenContiguousRowMajorTransferReadPattern, |
| 938 | FlattenContiguousRowMajorTransferWritePattern>( |
| 939 | arg: patterns.getContext(), args&: targetVectorBitwidth, args&: benefit); |
| 940 | populateDropUnitDimWithShapeCastPatterns(patterns, benefit); |
| 941 | } |
| 942 | |