1//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements functions concerned with optimizing transfer_read and
10// transfer_write ops.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Affine/IR/AffineOps.h"
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
18#include "mlir/Dialect/Tensor/IR/Tensor.h"
19#include "mlir/Dialect/Utils/IndexingUtils.h"
20#include "mlir/Dialect/Vector/IR/VectorOps.h"
21#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
22#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
23#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
24#include "mlir/IR/Dominance.h"
25#include "mlir/Interfaces/SideEffectInterfaces.h"
26#include "llvm/ADT/STLExtras.h"
27#include "llvm/ADT/StringRef.h"
28#include "llvm/Support/Debug.h"
29
30#define DEBUG_TYPE "vector-transfer-opt"
31
32#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
33
34using namespace mlir;
35
36/// Return the ancestor op in the region or nullptr if the region is not
37/// an ancestor of the op.
38static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
39 for (; op != nullptr && op->getParentRegion() != region;
40 op = op->getParentOp())
41 ;
42 return op;
43}
44
45namespace {
46
/// Helper class driving dead-store elimination and store-to-load forwarding
/// on vector.transfer_write / vector.transfer_read ops. Dominance and
/// post-dominance info are computed once for the root op at construction.
class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  // Try to prove `write` dead (fully overwritten before being read) and, if
  // so, queue it for erasure.
  void deadStoreOp(vector::TransferWriteOp);
  // Try to replace `read` with the vector stored by a dominating
  // transfer_write and queue the read for erasure.
  void storeToLoadForwarding(vector::TransferReadOp);
  // Erase all ops queued by the two analyses above. Erasure is deferred so
  // that the dominance info and user lists stay valid during analysis.
  void removeDeadOp() {
    for (Operation *op : opToErase)
      rewriter.eraseOp(op);
    opToErase.clear();
  }

private:
  RewriterBase &rewriter;
  // Return true if there may be a control-flow path from `start` to `dest`
  // (both must be in the same region).
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  // Ops proven removable; erased in bulk by removeDeadOp().
  std::vector<Operation *> opToErase;
};
66
67} // namespace
68/// Return true if there is a path from start operation to dest operation,
69/// otherwise return false. The operations have to be in the same region.
70bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
71 assert(start->getParentRegion() == dest->getParentRegion() &&
72 "This function only works for ops i the same region");
73 // Simple case where the start op dominate the destination.
74 if (dominators.dominates(a: start, b: dest))
75 return true;
76 return start->getBlock()->isReachable(other: dest->getBlock());
77}
78
/// For transfer_write to overwrite fully another transfer_write must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they are all post-dominating the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates as
/// this will be the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we found such an overwriting transfer_write we know that the original
/// transfer_write is dead if all reads that can be reached from the potentially
/// dead transfer_write are dominated by the overwriting transfer_write.
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  // Look through view-like ops to the underlying memref so that accesses
  // through aliasing views are analyzed too.
  Value source = memref::skipViewLikeOps(source: cast<MemrefValue>(Val: write.getBase()));
  // Worklist of users of the underlying memref (and, transitively, of any
  // view-like op derived from it).
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed skip.
    if (!processed.insert(V: user).second)
      continue;
    // View-like ops do not access memory themselves; analyze their users.
    if (isa<ViewLikeOpInterface>(Val: user)) {
      users.append(in_start: user->getUsers().begin(), in_end: user->getUsers().end());
      continue;
    }
    // Ops without memory effects can neither overwrite nor observe the store.
    if (isMemoryEffectFree(op: user))
      continue;
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(Val: user)) {
      // Check candidate that can override the store.
      if (memref::isSameViewOrTrivialAlias(
              a: cast<MemrefValue>(Val: nextWrite.getBase()),
              b: cast<MemrefValue>(Val: write.getBase())) &&
          checkSameValueWAW(write: nextWrite, priorWrite: write) &&
          postDominators.postDominates(a: nextWrite, b: write)) {
        // Keep the candidate post-dominated by all others, i.e. the first
        // overwrite executed after `write`.
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(a: firstOverwriteCandidate, b: nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(Val: user)) {
      // Don't need to consider disjoint accesses.
      if (vector::isDisjointTransferSet(
              transferA: cast<VectorTransferOpInterface>(Val: write.getOperation()),
              transferB: cast<VectorTransferOpInterface>(Val: transferOp.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
    }
    // Any other memory-effecting user may observe the store.
    blockingAccesses.push_back(Elt: user);
  }
  // No overwriting write found: the store cannot be proven dead.
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(region: topRegion, op: write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  // The store is dead only if every blocking access reachable from it is
  // dominated by the overwriting write (i.e. sees the new value).
  for (Operation *access : blockingAccesses) {
    Operation *accessAncestor = findAncestorOpInRegion(region: topRegion, op: access);
    // TODO: if the access and write have the same ancestor we could recurse in
    // the region to know if the access is reachable with more precision.
    if (accessAncestor == nullptr ||
        !isReachable(start: writeAncestor, dest: accessAncestor))
      continue;
    if (!dominators.dominates(a: firstOverwriteCandidate, b: accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  // Defer erasure to removeDeadOp() so analysis state stays valid.
  opToErase.push_back(x: write.getOperation());
}
162
/// A transfer_write candidate to storeToLoad forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
/// transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the others
/// since they are all dominating the same transfer_read. We only consider the
/// transfer_write dominated by all the other candidates as this will be the
/// last transfer_write executed before the transfer_read.
/// If we found such a candidate we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are post-dominated
/// by the transfer_write.
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  // Out-of-bounds reads mix padding with memory contents; forwarding the
  // stored vector would drop the padding, so bail out.
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  // Look through view-like ops to the underlying memref so that accesses
  // through aliasing views are analyzed too.
  Value source = memref::skipViewLikeOps(source: cast<MemrefValue>(Val: read.getBase()));
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed skip.
    if (!processed.insert(V: user).second)
      continue;
    // View-like ops do not access memory themselves; analyze their users.
    if (isa<ViewLikeOpInterface>(Val: user)) {
      users.append(in_start: user->getUsers().begin(), in_end: user->getUsers().end());
      continue;
    }
    // Side-effect-free ops and other reads cannot clobber the stored value.
    if (isMemoryEffectFree(op: user) || isa<vector::TransferReadOp>(Val: user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(Val: user)) {
      // If there is a write, but we can prove that it is disjoint we can ignore
      // the write.
      if (vector::isDisjointTransferSet(
              transferA: cast<VectorTransferOpInterface>(Val: write.getOperation()),
              transferB: cast<VectorTransferOpInterface>(Val: read.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
      // Forwarding candidate: same view, dominates the read, and writes the
      // value/shape the read expects.
      if (memref::isSameViewOrTrivialAlias(
              a: cast<MemrefValue>(Val: read.getBase()),
              b: cast<MemrefValue>(Val: write.getBase())) &&
          dominators.dominates(a: write, b: read) && checkSameValueRAW(defWrite: write, read)) {
        // Keep the candidate dominated by all others, i.e. the last write
        // executed before the read.
        if (lastwrite == nullptr || dominators.dominates(a: lastwrite, b: write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    // Any other memory-effecting user may clobber the value before the read.
    blockingWrites.push_back(Elt: user);
  }

  // No dominating candidate write: nothing to forward.
  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(region: topRegion, op: read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  // Forwarding is safe only if every potentially clobbering write that can
  // reach the read is post-dominated by the forwarded write (i.e. executes
  // before it).
  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(region: topRegion, op: write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(start: writeAncestor, dest: readAncestor))
      continue;
    if (!postDominators.postDominates(a: lastwrite, b: write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(newValue: lastwrite.getVector());
  // Defer erasure to removeDeadOp() so analysis state stays valid.
  opToErase.push_back(x: read.getOperation());
}
244
245/// Converts OpFoldResults to int64_t shape without unit dims.
246static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
247 SmallVector<int64_t> reducedShape;
248 for (const auto size : mixedSizes) {
249 if (llvm::dyn_cast_if_present<Value>(Val: size)) {
250 reducedShape.push_back(Elt: ShapedType::kDynamic);
251 continue;
252 }
253
254 auto value = cast<IntegerAttr>(Val: cast<Attribute>(Val: size)).getValue();
255 if (value == 1)
256 continue;
257 reducedShape.push_back(Elt: value.getSExtValue());
258 }
259 return reducedShape;
260}
261
262/// Drops unit dimensions from the input MemRefType.
263static MemRefType dropUnitDims(MemRefType inputType,
264 ArrayRef<OpFoldResult> offsets,
265 ArrayRef<OpFoldResult> sizes,
266 ArrayRef<OpFoldResult> strides) {
267 auto targetShape = getReducedShape(mixedSizes: sizes);
268 MemRefType rankReducedType = memref::SubViewOp::inferRankReducedResultType(
269 resultShape: targetShape, sourceMemRefType: inputType, staticOffsets: offsets, staticSizes: sizes, staticStrides: strides);
270 return rankReducedType.canonicalizeStridedLayout();
271}
272
273/// Creates a rank-reducing memref.subview op that drops unit dims from its
274/// input. Or just returns the input if it was already without unit dims.
275static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
276 mlir::Location loc,
277 Value input) {
278 MemRefType inputType = cast<MemRefType>(Val: input.getType());
279 SmallVector<OpFoldResult> offsets(inputType.getRank(),
280 rewriter.getIndexAttr(value: 0));
281 SmallVector<OpFoldResult> sizes = memref::getMixedSizes(builder&: rewriter, loc, value: input);
282 SmallVector<OpFoldResult> strides(inputType.getRank(),
283 rewriter.getIndexAttr(value: 1));
284 MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
285
286 if (resultType.canonicalizeStridedLayout() ==
287 inputType.canonicalizeStridedLayout())
288 return input;
289 return rewriter.create<memref::SubViewOp>(location: loc, args&: resultType, args&: input, args&: offsets,
290 args&: sizes, args&: strides);
291}
292
293/// Returns the number of dims that aren't unit dims.
294static int getReducedRank(ArrayRef<int64_t> shape) {
295 return llvm::count_if(Range&: shape, P: [](int64_t dimSize) { return dimSize != 1; });
296}
297
298/// Trims non-scalable one dimensions from `oldType` and returns the result
299/// type.
300static VectorType trimNonScalableUnitDims(VectorType oldType) {
301 SmallVector<int64_t> newShape;
302 SmallVector<bool> newScalableDims;
303 for (auto [dimIdx, dimSize] : llvm::enumerate(First: oldType.getShape())) {
304 if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
305 continue;
306 newShape.push_back(Elt: dimSize);
307 newScalableDims.push_back(Elt: oldType.getScalableDims()[dimIdx]);
308 }
309 return VectorType::get(shape: newShape, elementType: oldType.getElementType(), scalableDims: newScalableDims);
310}
311
312// Rewrites vector.create_mask 'op' to drop non-scalable one dimensions.
313static FailureOr<Value>
314createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
315 vector::CreateMaskOp op) {
316 auto type = op.getType();
317 VectorType reducedType = trimNonScalableUnitDims(oldType: type);
318 if (reducedType.getRank() == type.getRank())
319 return failure();
320
321 SmallVector<Value> reducedOperands;
322 for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
323 t: type.getShape(), u: type.getScalableDims(), args: op.getOperands())) {
324 if (dim == 1 && !dimIsScalable) {
325 // If the mask for the unit dim is not a constant of 1, do nothing.
326 auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
327 if (!constant || (constant.value() != 1))
328 return failure();
329 continue;
330 }
331 reducedOperands.push_back(Elt: operand);
332 }
333 return rewriter
334 .create<vector::CreateMaskOp>(location: loc, args&: reducedType, args&: reducedOperands)
335 .getResult();
336}
337
338namespace {
339
/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims. The vector shapes are
/// also reduced accordingly.
class TransferReadDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(Val: vector.getType());
    Value source = transferReadOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(Val: source.getType());
    // TODO: support tensor types.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the source shape can be further reduced.
    int reducedRank = getReducedRank(shape: sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
    // out.
    if (reducedRank == 0 && maskingOp)
      return failure();
    // Check if the reduced vector shape matches the reduced source shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(oldType: vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    // Only all-zero indices are supported: the rank-reduced subview created
    // below always reads from offset zero.
    if (llvm::any_of(Range: transferReadOp.getIndices(), P: [](Value v) {
          return getConstantIntValue(ofr: v) != static_cast<int64_t>(0);
        }))
      return failure();

    // Rank-reduce the mask operand (if present) to match the reduced vector
    // type. Only vector.create_mask-defined masks are supported.
    Value maskOp = transferReadOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            arg&: transferReadOp, msg: "unsupported mask op, only 'vector.create_mask' is "
                            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, op: createMaskOp);
      if (failed(Result: rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }

    // Build the rank-reduced read: a subview without unit dims, read at
    // all-zero indices with a minor-identity (full identity) map.
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, input: source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(rank: reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
        location: loc, args&: reducedVectorType, args&: reducedShapeSource, args&: zeros, args&: identityMap,
        args: transferReadOp.getPadding(), args&: maskOp,
        args: rewriter.getBoolArrayAttr(values: inBounds));

    // If the read was wrapped in a vector.mask, re-wrap the new read with the
    // mask shape-cast to the reduced type.
    if (maskingOp) {
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          location: loc, args: reducedVectorType.cloneWith(shape: std::nullopt, elementType: rewriter.getI1Type()),
          args: maskingOp.getMask());
      newTransferReadOp = mlir::vector::maskOperation(
          builder&: rewriter, maskableOp: newTransferReadOp, mask: shapeCastMask);
    }

    // Cast the reduced result back to the original vector type for users.
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        location: loc, args&: vectorType, args: newTransferReadOp->getResults()[0]);

    return shapeCast;
  }
};
421
/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
/// vector shapes are also reduced accordingly.
class TransferWriteDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(Val: vector.getType());
    Value source = transferWriteOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(Val: source.getType());
    // TODO: support tensor type.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the destination shape can be further reduced.
    int reducedRank = getReducedRank(shape: sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
    // out.
    if (reducedRank == 0 && maskingOp)
      return failure();
    // Check if the reduced vector shape matches the reduced destination shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(oldType: vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    // Only all-zero indices are supported: the rank-reduced subview created
    // below always writes to offset zero.
    if (llvm::any_of(Range: transferWriteOp.getIndices(), P: [](Value v) {
          return getConstantIntValue(ofr: v) != static_cast<int64_t>(0);
        }))
      return failure();

    // Rank-reduce the mask operand (if present) to match the reduced vector
    // type. Only vector.create_mask-defined masks are supported.
    Value maskOp = transferWriteOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            arg&: transferWriteOp,
            msg: "unsupported mask op, only 'vector.create_mask' is "
                "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, op: createMaskOp);
      if (failed(Result: rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }
    // Build the rank-reduced write: shape-cast the vector to the reduced type
    // and write it to a subview without unit dims at all-zero indices.
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, input: source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(rank: reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
        location: loc, args&: reducedVectorType, args&: vector);
    Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
        location: loc, args: Type(), args&: shapeCastSrc, args&: reducedShapeSource, args&: zeros, args&: identityMap,
        args&: maskOp, args: rewriter.getBoolArrayAttr(values: inBounds));

    // If the write was wrapped in a vector.mask, re-wrap the new write with
    // the mask shape-cast to the reduced type.
    if (maskingOp) {
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          location: loc, args: reducedVectorType.cloneWith(shape: std::nullopt, elementType: rewriter.getI1Type()),
          args: maskingOp.getMask());
      newXferWrite =
          mlir::vector::maskOperation(builder&: rewriter, maskableOp: newXferWrite, mask: shapeCastMask);
    }

    if (transferWriteOp.hasPureTensorSemantics())
      return newXferWrite->getResults()[0];

    // With Memref semantics, there's no return value. Use empty value to signal
    // success.
    return Value();
  }
};
506
507} // namespace
508
509/// Creates a memref.collapse_shape collapsing all inner dimensions of the
510/// input starting at `firstDimToCollapse`.
511static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
512 Value input, int64_t firstDimToCollapse) {
513 ShapedType inputType = cast<ShapedType>(Val: input.getType());
514 if (inputType.getRank() == 1)
515 return input;
516 SmallVector<ReassociationIndices> reassociation;
517 for (int64_t i = 0; i < firstDimToCollapse; ++i)
518 reassociation.push_back(Elt: ReassociationIndices{i});
519 ReassociationIndices collapsedIndices;
520 for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
521 collapsedIndices.push_back(Elt: i);
522 reassociation.push_back(Elt: collapsedIndices);
523 return rewriter.create<memref::CollapseShapeOp>(location: loc, args&: input, args&: reassociation);
524}
525
/// Returns the new indices that collapses the inner dimensions starting from
/// the `firstDimToCollapse` dimension.
static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
                                              Location loc,
                                              ArrayRef<int64_t> shape,
                                              ValueRange indices,
                                              int64_t firstDimToCollapse) {
  assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));

  // If all the collapsed indices are zero then no extra logic is needed.
  // Otherwise, a new offset/index has to be computed.
  SmallVector<Value> indicesAfterCollapsing(
      indices.begin(), indices.begin() + firstDimToCollapse);
  SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
                                       indices.end());
  if (llvm::all_of(Range&: indicesToCollapse, P: isZeroInteger)) {
    // All indices being collapsed are the constant zero: any one of them can
    // serve as the index into the single collapsed dimension.
    indicesAfterCollapsing.push_back(Elt: indicesToCollapse[0]);
    return indicesAfterCollapsing;
  }

  // Compute the remaining trailing index/offset required for reading from
  // the collapsed memref:
  //
  //    offset = 0
  //    for (i = firstDimToCollapse; i < outputRank; ++i)
  //      offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
  //
  // For this example:
  //   %2 = vector.transfer_read/write %arg4[%c0, %arg0, %c0] (...) :
  //      memref<1x43x2xi32>, vector<1x2xi32>
  // which would be collapsed to:
  //   %1 = vector.transfer_read/write %collapse_shape[%c0, %offset] (...) :
  //      memref<1x86xi32>, vector<2xi32>
  // one would get the following offset:
  //    %offset = %arg0 * 43
  OpFoldResult collapsedOffset =
      rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0).getResult();

  // Row-major strides of the dims being collapsed (suffix products of their
  // sizes).
  auto collapsedStrides = computeSuffixProduct(
      sizes: ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));

  // Compute the collapsed offset.
  auto &&[collapsedExpr, collapsedVals] =
      computeLinearIndex(sourceOffset: collapsedOffset, strides: collapsedStrides, indices: indicesToCollapse);
  collapsedOffset = affine::makeComposedFoldedAffineApply(
      b&: rewriter, loc, expr: collapsedExpr, operands: collapsedVals);

  if (auto value = dyn_cast<Value>(Val&: collapsedOffset)) {
    indicesAfterCollapsing.push_back(Elt: value);
  } else {
    // The affine apply folded to a constant attribute; materialize it as an
    // SSA value.
    indicesAfterCollapsing.push_back(Elt: rewriter.create<arith::ConstantIndexOp>(
        location: loc, args: *getConstantIntValue(ofr: collapsedOffset)));
  }

  return indicesAfterCollapsing;
}
582
583namespace {
/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
///
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the trailing dimension of the vector read is smaller than the provided
/// bitwidth.
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
                                               PatternBenefit benefit)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(Val: vector.getType());
    auto source = transferReadOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(Val: source.getType());

    // 0. Check pre-conditions
    // The contiguity check below is only implemented for memrefs; bail out on
    // tensor sources.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    // Only flatten when the trailing vector dim is narrower than the target
    // vector bitwidth.
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(memrefType: sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();

    // Determine the first memref dimension to collapse - just enough so we can
    // read a flattened vector. Leading unit dims of the vector do not require
    // collapsing.
    int64_t firstDimToCollapse =
        sourceType.getRank() -
        vectorType.getShape().drop_while(Pred: [](auto v) { return v == 1; }).size();

    // 1. Collapse the source memref
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, input: source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(Val: collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for a new vector.transfer_read that will read
    // from the collapsed memref.
    // 2.1. New dim exprs + affine map (minor identity on the collapsed dim).
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(position: firstDimToCollapse, context: rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(dimCount: collapsedRank, symbolCount: 0, results: dimExprs, context: rewriter.getContext());

    // 2.2 New indices
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, shape: sourceType.getShape(),
                            indices: transferReadOp.getIndices(), firstDimToCollapse);

    // 3. Create new vector.transfer_read that reads from the collapsed memref
    VectorType flatVectorType = VectorType::get(shape: {vectorType.getNumElements()},
                                                elementType: vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        location: loc, args&: flatVectorType, args&: collapsedSource, args&: collapsedIndices,
        args: transferReadOp.getPadding(), args&: collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr(values: {true}));

    // 4. Replace the old transfer_read with the new one reading from the
    // collapsed shape
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        op: transferReadOp, args: cast<VectorType>(Val: vector.getType()), args&: flatRead);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};
679
/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
///
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the trailing dimension of the vector to be written is smaller than the
/// provided bitwidth.
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
public:
  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
                                                unsigned vectorBitwidth,
                                                PatternBenefit benefit)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(Val: vector.getType());
    Value source = transferWriteOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(Val: source.getType());

    // 0. Check pre-conditions
    // The contiguity check below is only implemented for memrefs; bail out on
    // tensor destinations.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    // Only flatten when the trailing vector dim is narrower than the target
    // vector bitwidth.
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(memrefType: sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();

    // Determine the first memref dimension to collapse - just enough so we can
    // write a flattened vector. Leading unit dims of the vector do not require
    // collapsing.
    int64_t firstDimToCollapse =
        sourceType.getRank() -
        vectorType.getShape().drop_while(Pred: [](auto v) { return v == 1; }).size();

    // 1. Collapse the source memref
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, input: source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(Val: collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for a new vector.transfer_write that will write
    // to the collapsed memref.
    // 2.1. New dim exprs + affine map (minor identity on the collapsed dim).
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(position: firstDimToCollapse, context: rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(dimCount: collapsedRank, symbolCount: 0, results: dimExprs, context: rewriter.getContext());

    // 2.2 New indices
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, shape: sourceType.getShape(),
                            indices: transferWriteOp.getIndices(), firstDimToCollapse);

    // 3. Create new vector.transfer_write that writes to the collapsed memref
    VectorType flatVectorType = VectorType::get(shape: {vectorType.getNumElements()},
                                                elementType: vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(location: loc, args&: flatVectorType, args&: vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            location: loc, args&: flatVector, args&: collapsedSource, args&: collapsedIndices, args&: collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr(values: {true}));

    // 4. Replace the old transfer_write with the new one writing the
    // collapsed shape
    rewriter.eraseOp(op: transferWriteOp);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};
777
778/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
779///
780/// All the users of the transfer op must be `vector.extract` ops. If
781/// `allowMultipleUses` is set to true, rewrite transfer ops with any number of
782/// users. Otherwise, rewrite only if the extract op is the single user of the
783/// transfer op. Rewriting a single vector load with multiple scalar loads may
784/// negatively affect performance.
class RewriteScalarExtractOfTransferRead
    : public OpRewritePattern<vector::ExtractOp> {
public:
  RewriteScalarExtractOfTransferRead(MLIRContext *context,
                                     PatternBenefit benefit,
                                     bool allowMultipleUses)
      : OpRewritePattern(context, benefit),
        allowMultipleUses(allowMultipleUses) {}

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Match phase.
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Check that we are extracting a scalar and not a sub-vector.
    if (isa<VectorType>(Val: extractOp.getResult().getType()))
      return failure();
    // If multiple uses are not allowed, check if xfer has a single use.
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    // If multiple uses are allowed, check if all the xfer uses are extract ops.
    if (allowMultipleUses &&
        !llvm::all_of(Range: xferOp->getUses(), P: [](OpOperand &use) {
          return isa<vector::ExtractOp>(Val: use.getOwner());
        }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Cannot rewrite if the indices may be out of bounds.
    if (xferOp.hasOutOfBoundsDim())
      return failure();

    // Rewrite phase: construct scalar load.
    // Start from the transfer op's own indices; the extract position is added
    // onto the trailing indices (valid because the permutation map above is a
    // minor identity).
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (auto [i, pos] : llvm::enumerate(First: extractOp.getMixedPosition())) {
      // Index into `newIndices` that the i-th extract position applies to:
      // the extract positions line up with the trailing transfer indices.
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;

      // Compute affine expression `newIndices[idx] + pos` where `pos` can be
      // either a constant or a value.
      OpFoldResult composedIdx;
      if (auto attr = dyn_cast<Attribute>(Val&: pos)) {
        // Static position: fold `newIndices[idx] + offset`.
        int64_t offset = cast<IntegerAttr>(Val&: attr).getInt();
        composedIdx = affine::makeComposedFoldedAffineApply(
            b&: rewriter, loc: extractOp.getLoc(),
            expr: rewriter.getAffineSymbolExpr(position: 0) + offset, operands: {newIndices[idx]});
      } else {
        // Dynamic position: fold the sum of the two index values.
        Value dynamicOffset = cast<Value>(Val&: pos);
        AffineExpr sym0, sym1;
        bindSymbols(ctx: rewriter.getContext(), exprs&: sym0, exprs&: sym1);
        composedIdx = affine::makeComposedFoldedAffineApply(
            b&: rewriter, loc: extractOp.getLoc(), expr: sym0 + sym1,
            operands: {newIndices[idx], dynamicOffset});
      }

      // Update the corresponding index with the folded result.
      if (auto value = dyn_cast<Value>(Val&: composedIdx)) {
        newIndices[idx] = value;
      } else {
        // Folded to a constant: materialize it as an index constant.
        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
            location: extractOp.getLoc(), args: *getConstantIntValue(ofr: composedIdx));
      }
    }
    // Emit a scalar memref.load or tensor.extract depending on the source.
    if (isa<MemRefType>(Val: xferOp.getBase().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(op: extractOp, args: xferOp.getBase(),
                                                  args&: newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          op: extractOp, args: xferOp.getBase(), args&: newIndices);
    }

    return success();
  }

private:
  // If true, the pattern also fires when the transfer_read has several users
  // (all of which must be vector.extract ops).
  bool allowMultipleUses;
};
867
868/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
869/// to memref.store.
870class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
871 using OpRewritePattern::OpRewritePattern;
872
873 LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
874 PatternRewriter &rewriter) const override {
875 // Must be a scalar write.
876 auto vecType = xferOp.getVectorType();
877 if (!llvm::all_of(Range: vecType.getShape(), P: [](int64_t sz) { return sz == 1; }))
878 return failure();
879 // Mask not supported.
880 if (xferOp.getMask())
881 return failure();
882 // Map not supported.
883 if (!xferOp.getPermutationMap().isMinorIdentity())
884 return failure();
885 // Only float and integer element types are supported.
886 Value scalar =
887 rewriter.create<vector::ExtractOp>(location: xferOp.getLoc(), args: xferOp.getVector());
888 // Construct a scalar store.
889 if (isa<MemRefType>(Val: xferOp.getBase().getType())) {
890 rewriter.replaceOpWithNewOp<memref::StoreOp>(
891 op: xferOp, args&: scalar, args: xferOp.getBase(), args: xferOp.getIndices());
892 } else {
893 rewriter.replaceOpWithNewOp<tensor::InsertOp>(
894 op: xferOp, args&: scalar, args: xferOp.getBase(), args: xferOp.getIndices());
895 }
896 return success();
897 }
898};
899
900} // namespace
901
902void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
903 Operation *rootOp) {
904 TransferOptimization opt(rewriter, rootOp);
905 // Run store to load forwarding first since it can expose more dead store
906 // opportunity.
907 rootOp->walk(callback: [&](vector::TransferReadOp read) {
908 if (isa<MemRefType>(Val: read.getShapedType()))
909 opt.storeToLoadForwarding(read);
910 });
911 opt.removeDeadOp();
912 rootOp->walk(callback: [&](vector::TransferWriteOp write) {
913 if (isa<MemRefType>(Val: write.getShapedType()))
914 opt.deadStoreOp(write);
915 });
916 opt.removeDeadOp();
917}
918
919void mlir::vector::populateScalarVectorTransferLoweringPatterns(
920 RewritePatternSet &patterns, PatternBenefit benefit,
921 bool allowMultipleUses) {
922 patterns.add<RewriteScalarExtractOfTransferRead>(arg: patterns.getContext(),
923 args&: benefit, args&: allowMultipleUses);
924 patterns.add<RewriteScalarWrite>(arg: patterns.getContext(), args&: benefit);
925}
926
927void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
928 RewritePatternSet &patterns, PatternBenefit benefit) {
929 patterns
930 .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
931 arg: patterns.getContext(), args&: benefit);
932}
933
934void mlir::vector::populateFlattenVectorTransferPatterns(
935 RewritePatternSet &patterns, unsigned targetVectorBitwidth,
936 PatternBenefit benefit) {
937 patterns.add<FlattenContiguousRowMajorTransferReadPattern,
938 FlattenContiguousRowMajorTransferWritePattern>(
939 arg: patterns.getContext(), args&: targetVectorBitwidth, args&: benefit);
940 populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
941}
942

source code of mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp