//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;

/// Return the ancestor op in the region or nullptr if the region is not
/// an ancestor of the op.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  for (; op != nullptr && op->getParentRegion() != region;
       op = op->getParentOp())
    ;
  return op;
}

namespace {

class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      rewriter.eraseOp(op);
    opToErase.clear();
  }

private:
  RewriterBase &rewriter;
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

} // namespace
/// Return true if there is a path from the start operation to the dest
/// operation, otherwise return false. The operations have to be in the same
/// region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  Block *startBlock = start->getBlock();
  Block *destBlock = dest->getBlock();
  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
                                    startBlock->succ_end());
  SmallPtrSet<Block *, 32> visited;
  while (!worklist.empty()) {
    Block *bb = worklist.pop_back_val();
    if (!visited.insert(bb).second)
      continue;
    if (dominators.dominates(bb, destBlock))
      return true;
    worklist.append(bb->succ_begin(), bb->succ_end());
  }
  return false;
}

/// For one transfer_write to fully overwrite another, the overwriting
/// transfer_write must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they all post-dominate the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates, as
/// this will be the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we find such an overwriting transfer_write, we know that the original
/// transfer_write is dead if all reads that can be reached from the
/// potentially dead transfer_write are dominated by the overwriting
/// transfer_write.
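///
/// Illustrative sketch (the IR below is made up for exposition, not taken
/// from a test):
///
///   vector.transfer_write %v0, %m[%c0, %c0] : vector<4xf32>, memref<4x4xf32>
///   // ... no aliasing read of %m on any path in between ...
///   vector.transfer_write %v1, %m[%c0, %c0] : vector<4xf32>, memref<4x4xf32>
///
/// The first transfer_write is dead: every element it stores is overwritten
/// by the second one before any read can observe it, so it can be erased.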
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source = write.getSource();
  // Skip subview ops.
  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
    source = subView.getSource();
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
      users.append(subView->getUsers().begin(), subView->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user))
      continue;
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check if this candidate can overwrite the store.
      if (write.getSource() == nextWrite.getSource() &&
          checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      // Don't need to consider disjoint accesses.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(transferOp.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
    }
    blockingAccesses.push_back(user);
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *access : blockingAccesses) {
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    // TODO: if the access and write have the same ancestor we could recurse in
    // the region to know if the access is reachable with more precision.
    if (accessAncestor == nullptr ||
        !isReachable(writeAncestor, accessAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}

/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
///    transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the
/// others since they all dominate the same transfer_read. We only consider
/// the transfer_write dominated by all the other candidates, as this will be
/// the last transfer_write executed before the transfer_read.
/// If we find such a candidate, we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are
/// post-dominated by the transfer_write.
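///
/// Illustrative sketch (the IR below is made up for exposition):
///
///   vector.transfer_write %v, %m[%c0, %c0] : vector<4xf32>, memref<4x4xf32>
///   // ... no aliasing write to %m on any path in between ...
///   %r = vector.transfer_read %m[%c0, %c0], %pad
///       : memref<4x4xf32>, vector<4xf32>
///
/// Here %r can be replaced by %v and the transfer_read erased.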
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source = read.getSource();
  // Skip subview ops.
  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
    source = subView.getSource();
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
      users.append(subView->getUsers().begin(), subView->getUsers().end());
      continue;
    }
    if (auto collapsed = dyn_cast<memref::CollapseShapeOp>(user)) {
      users.append(collapsed->getUsers().begin(), collapsed->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint, we can
      // ignore the write.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
      if (write.getSource() == read.getSource() &&
          dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}

/// Converts OpFoldResults to int64_t shape without unit dims.
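/// For example, mixed sizes (2, 1, %d, 1, 8) reduce to {2, kDynamic, 8}:
/// static unit sizes are dropped, while dynamic sizes are conservatively kept
/// as kDynamic even if they may be 1 at runtime.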
static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<int64_t> reducedShape;
  for (const auto size : mixedSizes) {
    if (llvm::dyn_cast_if_present<Value>(size)) {
      reducedShape.push_back(ShapedType::kDynamic);
      continue;
    }

    auto value = cast<IntegerAttr>(size.get<Attribute>()).getValue();
    if (value == 1)
      continue;
    reducedShape.push_back(value.getSExtValue());
  }
  return reducedShape;
}

/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes,
                               ArrayRef<OpFoldResult> strides) {
  auto targetShape = getReducedShape(sizes);
  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
}

/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input, or just returns the input if it has no unit dims.
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = cast<MemRefType>(input.getType());
  SmallVector<OpFoldResult> offsets(inputType.getRank(),
                                    rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
  SmallVector<OpFoldResult> strides(inputType.getRank(),
                                    rewriter.getIndexAttr(1));
  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);

  if (canonicalizeStridedLayout(resultType) ==
      canonicalizeStridedLayout(inputType))
    return input;
  return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
                                            sizes, strides);
}

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Trims non-scalable one dimensions from `oldType` and returns the result
/// type.
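/// For example, vector<1x[1]x4x1xf32> becomes vector<[1]x4xf32>: static unit
/// dims are dropped, while the scalable unit dim ([1]) is preserved.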
static VectorType trimNonScalableUnitDims(VectorType oldType) {
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
    if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
      continue;
    newShape.push_back(dimSize);
    newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}

// Rewrites vector.create_mask 'op' to drop non-scalable one dimensions.
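// For example (illustrative sketch; %c1 is an arith.constant 1 : index):
//   %m = vector.create_mask %c1, %dimA : vector<1x[4]xi1>
// is rewritten to:
//   %m = vector.create_mask %dimA : vector<[4]xi1>
// The rewrite bails out if the operand of a dropped unit dim is not the
// constant 1, since that mask dimension could then be all-false.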
static FailureOr<Value>
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
                                  vector::CreateMaskOp op) {
  auto type = op.getType();
  VectorType reducedType = trimNonScalableUnitDims(type);
  if (reducedType.getRank() == type.getRank())
    return failure();

  SmallVector<Value> reducedOperands;
  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
           type.getShape(), type.getScalableDims(), op.getOperands())) {
    if (dim == 1 && !dimIsScalable) {
      // If the mask for the unit dim is not a constant of 1, do nothing.
      auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
      if (!constant || (constant.value() != 1))
        return failure();
      continue;
    }
    reducedOperands.push_back(operand);
  }
  return rewriter
      .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
      .getResult();
}

namespace {

/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims. The vector shapes are
/// also reduced accordingly.
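/// For example (sketch; types, layouts and attributes are illustrative only):
///
///   %v = vector.transfer_read %m[%c0, %c0, %c0], %pad
///       : memref<1x1x8xf32>, vector<1x8xf32>
///
/// is rewritten to roughly:
///
///   %sv = memref.subview %m[0, 0, 0] [1, 1, 8] [1, 1, 1]
///       : memref<1x1x8xf32> to memref<8xf32>
///   %r = vector.transfer_read %sv[%c0], %pad : memref<8xf32>, vector<8xf32>
///   %v = vector.shape_cast %r : vector<8xf32> to vector<1x8xf32>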
class TransferReadDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the source shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // Check if the reduced vector shape matches the reduced source shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferReadOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
                            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }

    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
        transferReadOp.getPadding(), maskOp,
        rewriter.getBoolArrayAttr(inBounds));
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp);
    rewriter.replaceOp(transferReadOp, shapeCast);

    return success();
  }
};

/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
/// vector shapes are also reduced accordingly.
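/// For example (sketch; types and layouts are illustrative only):
///
///   vector.transfer_write %v, %m[%c0, %c0, %c0]
///       : vector<1x8xf32>, memref<1x1x8xf32>
///
/// is rewritten to roughly:
///
///   %sv = memref.subview %m[0, 0, 0] [1, 1, 8] [1, 1, 1]
///       : memref<1x1x8xf32> to memref<8xf32>
///   %vc = vector.shape_cast %v : vector<1x8xf32> to vector<8xf32>
///   vector.transfer_write %vc, %sv[%c0] : vector<8xf32>, memref<8xf32>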
class TransferWriteDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor type.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the destination shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // Check if the reduced vector shape matches the reduced destination shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferWriteOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferWriteOp,
            "unsupported mask op, only 'vector.create_mask' is "
            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
        identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));

    return success();
  }
};

} // namespace

/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
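/// For example, with `firstDimToCollapse` == 1 (illustrative):
///
///   memref.collapse_shape %m [[0], [1, 2, 3]]
///       : memref<2x3x4x5xf32> into memref<2x60xf32>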
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = cast<ShapedType>(input.getType());
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}

/// Checks that the indices corresponding to dimensions starting at
/// `firstDimToCollapse` are constant 0, and writes to `outIndices`
/// the truncated indices where `firstDimToCollapse` is now the innermost dim.
/// TODO: Extract the logic that writes to outIndices so that this method
/// simply checks one pre-condition.
static LogicalResult
checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
                                 SmallVector<Value> &outIndices) {
  int64_t rank = indices.size();
  if (firstDimToCollapse >= rank)
    return failure();
  for (int64_t i = firstDimToCollapse; i < rank; ++i) {
    std::optional<int64_t> cst = getConstantIntValue(indices[i]);
    if (!cst || cst.value() != 0)
      return failure();
  }
  outIndices = indices;
  outIndices.resize(firstDimToCollapse + 1);
  return success();
}

namespace {

/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the trailing dimension of the vector read is smaller than the provided
/// bitwidth.
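///
/// For example (sketch; types, indices and attributes are illustrative only):
///
///   %v = vector.transfer_read %m[%c0, %c0, %c0], %pad
///       : memref<2x3x8xi32>, vector<3x8xi32>
///
/// is rewritten to roughly:
///
///   %cs = memref.collapse_shape %m [[0], [1, 2]]
///       : memref<2x3x8xi32> into memref<2x24xi32>
///   %r = vector.transfer_read %cs[%c0, %c0], %pad
///       : memref<2x24xi32>, vector<24xi32>
///   %v = vector.shape_cast %r : vector<24xi32> to vector<3x8xi32>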
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
                                               PatternBenefit benefit)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    auto source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions
    // The contiguity check below only supports memrefs; bail on tensors.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();

    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

    // 1. Collapse the source memref
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        dyn_cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for a new vector.transfer_read that will read
    // from the collapsed memref.
    // 2.1. New dim exprs + affine map
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2 New indices
    // If all the collapsed indices are zero then no extra logic is needed.
    // Otherwise, a new offset/index has to be computed.
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
                                                firstDimToCollapse,
                                                collapsedIndices))) {
      // Copy all the leading indices.
      SmallVector<Value> indices = transferReadOp.getIndices();
      collapsedIndices.append(indices.begin(),
                              indices.begin() + firstDimToCollapse);

      // Compute the remaining trailing index/offset required for reading from
      // the collapsed memref:
      //
      //    offset = 0
      //    for (i = firstDimToCollapse; i < sourceRank; ++i)
      //      offset += transferReadOp.indices[i] * suffixProduct[i]
      //
      // where suffixProduct[i] is the product of the source dim sizes
      // following dim i, i.e. the row-major stride of dim i within the
      // collapsed group. For this example:
      //   %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
      //      memref<1x43x2xi32>, vector<1x2xi32>
      // which would be collapsed to:
      //   %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
      //      memref<1x86xi32>, vector<2xi32>
      // one would get the following offset:
      //      %offset = %arg0 * 2
      OpFoldResult collapsedOffset =
          rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();

      auto sourceShape = sourceType.getShape();
      auto collapsedStrides = computeSuffixProduct(ArrayRef<int64_t>(
          sourceShape.begin() + firstDimToCollapse, sourceShape.end()));

      // Compute the collapsed offset.
      ArrayRef<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
                                        indices.end());
      auto &&[collapsedExpr, collapsedVals] = computeLinearIndex(
          collapsedOffset, collapsedStrides, indicesToCollapse);
      collapsedOffset = affine::makeComposedFoldedAffineApply(
          rewriter, loc, collapsedExpr, collapsedVals);

      if (collapsedOffset.is<Value>()) {
        collapsedIndices.push_back(collapsedOffset.get<Value>());
      } else {
        collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
            loc, *getConstantIntValue(collapsedOffset)));
      }
    }

    // 3. Create new vector.transfer_read that reads from the collapsed memref
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Replace the old transfer_read with the new one reading from the
    // collapsed shape
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};

/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
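///
/// For example (sketch; types are illustrative only):
///
///   vector.transfer_write %v, %m[%c0, %c0, %c0]
///       : vector<3x8xi32>, memref<2x3x8xi32>
///
/// is rewritten to roughly:
///
///   %cs = memref.collapse_shape %m [[0], [1, 2]]
///       : memref<2x3x8xi32> into memref<2x24xi32>
///   %fv = vector.shape_cast %v : vector<3x8xi32> to vector<24xi32>
///   vector.transfer_write %fv, %cs[%c0, %c0]
///       : vector<24xi32>, memref<2x24xi32>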
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
public:
  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
                                                unsigned vectorBitwidth,
                                                PatternBenefit benefit)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // The contiguity check below only supports memrefs; bail on tensors.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    int64_t firstContiguousInnerDim =
        sourceType.getRank() - vectorType.getRank();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();

    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.eraseOp(transferWriteOp);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};

/// Base class for `vector.extract/vector.extractelement(vector.transfer_read)`
/// to `memref.load` patterns. The `match` method is shared for both
/// `vector.extract` and `vector.extractelement`.
template <class VectorExtractOp>
class RewriteScalarExtractOfTransferReadBase
    : public OpRewritePattern<VectorExtractOp> {
  using Base = OpRewritePattern<VectorExtractOp>;

public:
  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
                                         PatternBenefit benefit,
                                         bool allowMultipleUses)
      : Base::OpRewritePattern(context, benefit),
        allowMultipleUses(allowMultipleUses) {}

  LogicalResult match(VectorExtractOp extractOp) const override {
    auto xferOp =
        extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Check that we are extracting a scalar and not a sub-vector.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();
    // If multiple uses are not allowed, check if xfer has a single use.
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    // If multiple uses are allowed, check if all the xfer uses are extract ops.
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp, vector::ExtractElementOp>(
              use.getOwner());
        }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Cannot rewrite if the indices may be out of bounds.
    if (xferOp.hasOutOfBoundsDim())
      return failure();
    return success();
  }

private:
  bool allowMultipleUses;
};

/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
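///
/// Illustrative sketch (types and indices are made up for exposition):
///
///   %v = vector.transfer_read %m[%i], %pad : memref<?xf32>, vector<4xf32>
///   %e = vector.extractelement %v[%j : index] : vector<4xf32>
///
/// is rewritten to roughly:
///
///   %idx = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%i, %j]
///   %e = memref.load %m[%idx] : memref<?xf32>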
class RewriteScalarExtractElementOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractElementOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct scalar load.
    auto loc = extractOp.getLoc();
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    if (extractOp.getPosition()) {
      AffineExpr sym0, sym1;
      bindSymbols(extractOp.getContext(), sym0, sym1);
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, sym0 + sym1,
          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
      if (ofr.is<Value>()) {
        newIndices[newIndices.size() - 1] = ofr.get<Value>();
      } else {
        newIndices[newIndices.size() - 1] =
            rewriter.create<arith::ConstantIndexOp>(loc,
                                                    *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};

/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
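///
/// Illustrative sketch (types and indices are made up for exposition; the
/// exact `vector.extract` assembly depends on the MLIR version):
///
///   %v = vector.transfer_read %m[%i, %j], %pad
///       : memref<?x8xf32>, vector<8xf32>
///   %e = vector.extract %v[2] : f32 from vector<8xf32>
///
/// is rewritten to roughly:
///
///   %idx = affine.apply affine_map<()[s0] -> (s0 + 2)>()[%j]
///   %e = memref.load %m[%i, %idx] : memref<?x8xf32>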
class RewriteScalarExtractOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct scalar load.
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      assert(pos.is<Attribute>() && "Unexpected non-constant index");
      int64_t offset = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(),
          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      if (ofr.is<Value>()) {
        newIndices[idx] = ofr.get<Value>();
      } else {
        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
            extractOp.getLoc(), *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};

/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
/// to memref.store.
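///
/// Illustrative sketch (types are made up for exposition):
///
///   vector.transfer_write %v, %m[%i, %j] : vector<1x1xf32>, memref<?x?xf32>
///
/// is rewritten to roughly:
///
///   %s = vector.extract %v[0, 0] : f32 from vector<1x1xf32>
///   memref.store %s, %m[%i, %j] : memref<?x?xf32>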
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Must be a scalar write.
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Only float and integer element types are supported.
    Value scalar;
    if (vecType.getRank() == 0) {
      // vector.extract does not support vector<f32> etc., so use
      // vector.extractelement instead.
      scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
                                                         xferOp.getVector());
    } else {
      SmallVector<int64_t> pos(vecType.getRank(), 0);
      scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
                                                  xferOp.getVector(), pos);
    }
    // Construct a scalar store.
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    }
    return success();
  }
};

} // namespace

void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  TransferOptimization opt(rewriter, rootOp);
  // Run store to load forwarding first since it can expose more dead store
  // opportunities.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (isa<MemRefType>(read.getShapedType()))
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (isa<MemRefType>(write.getShapedType()))
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}

void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit,
    bool allowMultipleUses) {
  patterns.add<RewriteScalarExtractElementOfTransferRead,
               RewriteScalarExtractOfTransferRead>(patterns.getContext(),
                                                   benefit, allowMultipleUses);
  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
    PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), targetVectorBitwidth, benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
}