//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;

/// Return the ancestor op in the region or nullptr if the region is not
/// an ancestor of the op.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  for (; op != nullptr && op->getParentRegion() != region;
       op = op->getParentOp())
    ;
  return op;
}

namespace {

class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      rewriter.eraseOp(op);
    opToErase.clear();
  }

private:
  RewriterBase &rewriter;
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

} // namespace
/// Return true if there is a path from start operation to dest operation,
/// otherwise return false. The operations have to be in the same region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  return start->getBlock()->isReachable(dest->getBlock());
}

/// For a transfer_write to fully overwrite another transfer_write, it must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they are all post-dominating the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates as
/// this will be the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we find such an overwriting transfer_write, we know that the original
/// transfer_write is dead if all reads that can be reached from the potentially
/// dead transfer_write are dominated by the overwriting transfer_write.
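///
/// Illustrative example (names and shapes are made up, not taken from a test):
///   vector.transfer_write %v0, %buf[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, memref<4x8xf32>
///   vector.transfer_write %v1, %buf[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, memref<4x8xf32>
/// With no read of %buf reachable between the two writes, the first
/// transfer_write is dead and gets scheduled for erasure.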
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getBase()));
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (isa<ViewLikeOpInterface>(user)) {
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user))
      continue;
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check whether this candidate can overwrite the store.
      if (memref::isSameViewOrTrivialAlias(
              cast<MemrefValue>(nextWrite.getBase()),
              cast<MemrefValue>(write.getBase())) &&
          checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      // Don't need to consider disjoint accesses.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(transferOp.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
    }
    blockingAccesses.push_back(user);
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *access : blockingAccesses) {
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    // TODO: if the access and write have the same ancestor we could recurse in
    // the region to know if the access is reachable with more precision.
    if (accessAncestor == nullptr ||
        !isReachable(writeAncestor, accessAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}

/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
/// transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the others
/// since they are all dominating the same transfer_read. We only consider the
/// transfer_write dominated by all the other candidates as this will be the
/// last transfer_write executed before the transfer_read.
/// If we find such a candidate, we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are post-dominated
/// by the transfer_write.
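///
/// Illustrative example (names and shapes are made up, not taken from a test):
///   vector.transfer_write %v, %buf[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, memref<4x8xf32>
///   %r = vector.transfer_read %buf[%c0, %c0], %pad {in_bounds = [true, true]}
///       : memref<4x8xf32>, vector<4x8xf32>
/// With no aliasing write in between, all uses of %r are replaced by %v and
/// the transfer_read is scheduled for erasure.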
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getBase()));
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (isa<ViewLikeOpInterface>(user)) {
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint, we can
      // ignore the write.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
      if (memref::isSameViewOrTrivialAlias(
              cast<MemrefValue>(read.getBase()),
              cast<MemrefValue>(write.getBase())) &&
          dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}

/// Converts OpFoldResults to int64_t shape without unit dims.
static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<int64_t> reducedShape;
  for (const auto size : mixedSizes) {
    if (llvm::dyn_cast_if_present<Value>(size)) {
      reducedShape.push_back(ShapedType::kDynamic);
      continue;
    }

    auto value = cast<IntegerAttr>(cast<Attribute>(size)).getValue();
    if (value == 1)
      continue;
    reducedShape.push_back(value.getSExtValue());
  }
  return reducedShape;
}

/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes,
                               ArrayRef<OpFoldResult> strides) {
  auto targetShape = getReducedShape(sizes);
  MemRefType rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return rankReducedType.canonicalizeStridedLayout();
}

/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input, or just returns the input if it already has no unit dims.
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = cast<MemRefType>(input.getType());
  SmallVector<OpFoldResult> offsets(inputType.getRank(),
                                    rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
  SmallVector<OpFoldResult> strides(inputType.getRank(),
                                    rewriter.getIndexAttr(1));
  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);

  if (resultType.canonicalizeStridedLayout() ==
      inputType.canonicalizeStridedLayout())
    return input;
  return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
                                            sizes, strides);
}

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Trims non-scalable unit dimensions from `oldType` and returns the result
/// type.
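///
/// Illustrative example (made-up type): vector<1x[4]x1x8xf32> becomes
/// vector<[4]x8xf32>, while a scalable unit dim such as the `[1]` in
/// vector<[1]x8xf32> is kept.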
static VectorType trimNonScalableUnitDims(VectorType oldType) {
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
    if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
      continue;
    newShape.push_back(dimSize);
    newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}

// Rewrites vector.create_mask 'op' to drop non-scalable unit dimensions.
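//
// Illustrative example (made-up operands); the mask operand for each dropped
// unit dim must be a constant 1 for the rewrite to apply:
//   %m = vector.create_mask %c1, %dim : vector<1x[4]xi1>
// becomes
//   %m = vector.create_mask %dim : vector<[4]xi1>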
static FailureOr<Value>
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
                                  vector::CreateMaskOp op) {
  auto type = op.getType();
  VectorType reducedType = trimNonScalableUnitDims(type);
  if (reducedType.getRank() == type.getRank())
    return failure();

  SmallVector<Value> reducedOperands;
  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
           type.getShape(), type.getScalableDims(), op.getOperands())) {
    if (dim == 1 && !dimIsScalable) {
      // If the mask for the unit dim is not a constant of 1, do nothing.
      auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
      if (!constant || (constant.value() != 1))
        return failure();
      continue;
    }
    reducedOperands.push_back(operand);
  }
  return rewriter
      .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
      .getResult();
}

namespace {

/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims. The vector shapes are
/// also reduced accordingly.
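///
/// Illustrative example (made-up shapes; exact result types may differ):
///   %r = vector.transfer_read %mem[%c0, %c0, %c0], %pad
///       {in_bounds = [true, true, true]}
///       : memref<1x4x1xf32>, vector<1x4x1xf32>
/// becomes, roughly:
///   %sv = memref.subview %mem[0, 0, 0] [1, 4, 1] [1, 1, 1]
///       : memref<1x4x1xf32> to memref<4xf32>
///   %r0 = vector.transfer_read %sv[%c0], %pad {in_bounds = [true]}
///       : memref<4xf32>, vector<4xf32>
///   %r = vector.shape_cast %r0 : vector<4xf32> to vector<1x4x1xf32>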
class TransferReadDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the source shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
    // out.
    if (reducedRank == 0 && maskingOp)
      return failure();
    // Check if the reduced vector shape matches the reduced source shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferReadOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
                            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }

    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
        transferReadOp.getPadding(), maskOp,
        rewriter.getBoolArrayAttr(inBounds));

    if (maskingOp) {
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newTransferReadOp = mlir::vector::maskOperation(
          rewriter, newTransferReadOp, shapeCastMask);
    }

    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp->getResults()[0]);

    return shapeCast;
  }
};

/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
/// vector shapes are also reduced accordingly.
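///
/// Illustrative example (made-up shapes; exact result types may differ):
///   vector.transfer_write %v, %mem[%c0, %c0, %c0]
///       {in_bounds = [true, true, true]}
///       : vector<1x4x1xf32>, memref<1x4x1xf32>
/// becomes, roughly:
///   %sv = memref.subview %mem[0, 0, 0] [1, 4, 1] [1, 1, 1]
///       : memref<1x4x1xf32> to memref<4xf32>
///   %v0 = vector.shape_cast %v : vector<1x4x1xf32> to vector<4xf32>
///   vector.transfer_write %v0, %sv[%c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<4xf32>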
class TransferWriteDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor type.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the destination shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
    // out.
    if (reducedRank == 0 && maskingOp)
      return failure();
    // Check if the reduced vector shape matches the reduced destination shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferWriteOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferWriteOp,
            "unsupported mask op, only 'vector.create_mask' is "
            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
        loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
        maskOp, rewriter.getBoolArrayAttr(inBounds));

    if (maskingOp) {
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newXferWrite =
          mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
    }

    if (transferWriteOp.hasPureTensorSemantics())
      return newXferWrite->getResults()[0];

    // With Memref semantics, there's no return value. Use empty value to
    // signal success.
    return Value();
  }
};

} // namespace

/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
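///
/// Illustrative example (made-up shape): with `firstDimToCollapse` = 1,
/// memref<2x3x4xf32> is collapsed with reassociation [[0], [1, 2]] into
/// memref<2x12xf32>.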
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = cast<ShapedType>(input.getType());
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}

/// Returns the new indices after collapsing the inner dimensions starting from
/// the `firstDimToCollapse` dimension.
static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
                                              Location loc,
                                              ArrayRef<int64_t> shape,
                                              ValueRange indices,
                                              int64_t firstDimToCollapse) {
  assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));

  // If all the collapsed indices are zero then no extra logic is needed.
  // Otherwise, a new offset/index has to be computed.
  SmallVector<Value> indicesAfterCollapsing(
      indices.begin(), indices.begin() + firstDimToCollapse);
  SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
                                       indices.end());
  if (llvm::all_of(indicesToCollapse, isZeroInteger)) {
    indicesAfterCollapsing.push_back(indicesToCollapse[0]);
    return indicesAfterCollapsing;
  }

  // Compute the remaining trailing index/offset required for reading from
  // the collapsed memref:
  //
  //    offset = 0
  //    for (i = firstDimToCollapse; i < outputRank; ++i)
  //      offset = offset * sourceType.getDimSize(i) + transferReadOp.indices[i]
  //
  // For this example:
  //   %2 = vector.transfer_read/write %arg4[%c0, %arg0, %c0] (...) :
  //      memref<1x43x2xi32>, vector<1x2xi32>
  // which would be collapsed to:
  //   %1 = vector.transfer_read/write %collapse_shape[%c0, %offset] (...) :
  //      memref<1x86xi32>, vector<2xi32>
  // one would get the following offset:
  //    %offset = %arg0 * 2
  OpFoldResult collapsedOffset =
      rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();

  auto collapsedStrides = computeSuffixProduct(
      ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));

  // Compute the collapsed offset.
  auto &&[collapsedExpr, collapsedVals] =
      computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
  collapsedOffset = affine::makeComposedFoldedAffineApply(
      rewriter, loc, collapsedExpr, collapsedVals);

  if (auto value = dyn_cast<Value>(collapsedOffset)) {
    indicesAfterCollapsing.push_back(value);
  } else {
    indicesAfterCollapsing.push_back(rewriter.create<arith::ConstantIndexOp>(
        loc, *getConstantIntValue(collapsedOffset)));
  }

  return indicesAfterCollapsing;
}

namespace {

/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
///
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the bitwidth of the trailing dimension of the vector read is smaller than
/// the provided bitwidth.
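///
/// Illustrative example (made-up shapes; assumes the source is contiguous):
///   %r = vector.transfer_read %mem[%c0, %c0], %pad {in_bounds = [true, true]}
///       : memref<4x8xf32>, vector<4x8xf32>
/// becomes, roughly:
///   %flat = memref.collapse_shape %mem [[0, 1]]
///       : memref<4x8xf32> into memref<32xf32>
///   %r0 = vector.transfer_read %flat[%c0], %pad {in_bounds = [true]}
///       : memref<32xf32>, vector<32xf32>
///   %r = vector.shape_cast %r0 : vector<32xf32> to vector<4x8xf32>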
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
                                               PatternBenefit benefit)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    auto source = transferReadOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions
    // The contiguity check below only supports memrefs; bail on tensors.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();

    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

    // 1. Collapse the source memref
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for a new vector.transfer_read that will read
    // from the collapsed memref.
    // 2.1. New dim exprs + affine map
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2 New indices
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferReadOp.getIndices(), firstDimToCollapse);

    // 3. Create new vector.transfer_read that reads from the collapsed memref
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Replace the old transfer_read with the new one reading from the
    // collapsed shape
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};

/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
///
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the bitwidth of the trailing dimension of the vector write is smaller than
/// the provided bitwidth.
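///
/// Illustrative example (made-up shapes; assumes the destination is
/// contiguous):
///   vector.transfer_write %v, %mem[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, memref<4x8xf32>
/// becomes, roughly:
///   %flat = memref.collapse_shape %mem [[0, 1]]
///       : memref<4x8xf32> into memref<32xf32>
///   %v0 = vector.shape_cast %v : vector<4x8xf32> to vector<32xf32>
///   vector.transfer_write %v0, %flat[%c0] {in_bounds = [true]}
///       : vector<32xf32>, memref<32xf32>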
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
public:
  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
                                                unsigned vectorBitwidth,
                                                PatternBenefit benefit)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions
    // The contiguity check below only supports memrefs; bail on tensors.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();

    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

    // 1. Collapse the source memref
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for a new vector.transfer_write that will write
    // to the collapsed memref.
    // 2.1. New dim exprs + affine map
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2 New indices
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferWriteOp.getIndices(), firstDimToCollapse);

    // 3. Create new vector.transfer_write that writes to the collapsed memref
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Replace the old transfer_write with the new one writing the
    // collapsed shape
    rewriter.eraseOp(transferWriteOp);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};

/// Base class for `vector.extract/vector.extractelement(vector.transfer_read)`
/// to `memref.load` patterns. The `match` method is shared for both
/// `vector.extract` and `vector.extractelement`.
template <class VectorExtractOp>
class RewriteScalarExtractOfTransferReadBase
    : public OpRewritePattern<VectorExtractOp> {
  using Base = OpRewritePattern<VectorExtractOp>;

public:
  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
                                         PatternBenefit benefit,
                                         bool allowMultipleUses)
      : Base(context, benefit), allowMultipleUses(allowMultipleUses) {}

  LogicalResult match(VectorExtractOp extractOp) const {
    auto xferOp =
        extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Check that we are extracting a scalar and not a sub-vector.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();
    // If multiple uses are not allowed, check if xfer has a single use.
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    // If multiple uses are allowed, check if all the xfer uses are extract ops.
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp, vector::ExtractElementOp>(
              use.getOwner());
        }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Cannot rewrite if the indices may be out of bounds.
    if (xferOp.hasOutOfBoundsDim())
      return failure();
    return success();
  }

private:
  bool allowMultipleUses;
};

/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
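///
/// Illustrative example (made-up operands):
///   %v = vector.transfer_read %mem[%i], %pad {in_bounds = [true]}
///       : memref<?xf32>, vector<4xf32>
///   %e = vector.extractelement %v[%j : index] : vector<4xf32>
/// becomes, roughly:
///   %idx = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%i, %j]
///   %e = memref.load %mem[%idx] : memref<?xf32>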
class RewriteScalarExtractElementOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp,
                                PatternRewriter &rewriter) const override {
    if (failed(match(extractOp)))
      return failure();

    // Construct scalar load.
    auto loc = extractOp.getLoc();
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    if (extractOp.getPosition()) {
      AffineExpr sym0, sym1;
      bindSymbols(extractOp.getContext(), sym0, sym1);
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, sym0 + sym1,
          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
      if (auto value = dyn_cast<Value>(ofr)) {
        newIndices[newIndices.size() - 1] = value;
      } else {
        newIndices[newIndices.size() - 1] =
            rewriter.create<arith::ConstantIndexOp>(loc,
                                                    *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getBase().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getBase(), newIndices);
    }

    return success();
  }
};

/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
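///
/// Illustrative example (made-up operands; the extract positions must be
/// constants):
///   %v = vector.transfer_read %mem[%i, %j], %pad {in_bounds = [true, true]}
///       : memref<?x?xf32>, vector<2x4xf32>
///   %e = vector.extract %v[1, 2] : f32 from vector<2x4xf32>
/// becomes, roughly:
///   %i1 = affine.apply affine_map<()[s0] -> (s0 + 1)>()[%i]
///   %j2 = affine.apply affine_map<()[s0] -> (s0 + 2)>()[%j]
///   %e = memref.load %mem[%i1, %j2] : memref<?x?xf32>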
class RewriteScalarExtractOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    if (failed(match(extractOp)))
      return failure();

    // Construct scalar load.
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      assert(isa<Attribute>(pos) && "Unexpected non-constant index");
      int64_t offset = cast<IntegerAttr>(cast<Attribute>(pos)).getInt();
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(),
          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      if (auto value = dyn_cast<Value>(ofr)) {
        newIndices[idx] = value;
      } else {
        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
            extractOp.getLoc(), *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getBase().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getBase(), newIndices);
    }

    return success();
  }
};

/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
/// to memref.store.
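///
/// Illustrative example (made-up operands):
///   vector.transfer_write %v, %mem[%i, %j] {in_bounds = [true, true]}
///       : vector<1x1xf32>, memref<?x?xf32>
/// becomes, roughly:
///   %s = vector.extract %v[0, 0] : f32 from vector<1x1xf32>
///   memref.store %s, %mem[%i, %j] : memref<?x?xf32>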
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Must be a scalar write.
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Only float and integer element types are supported.
    Value scalar =
        rewriter.create<vector::ExtractOp>(xferOp.getLoc(), xferOp.getVector());
    // Construct a scalar store.
    if (isa<MemRefType>(xferOp.getBase().getType())) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
    }
    return success();
  }
};

} // namespace

void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  TransferOptimization opt(rewriter, rootOp);
  // Run store to load forwarding first since it can expose more dead store
  // opportunities.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (isa<MemRefType>(read.getShapedType()))
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (isa<MemRefType>(write.getShapedType()))
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}

void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit,
    bool allowMultipleUses) {
  patterns.add<RewriteScalarExtractElementOfTransferRead,
               RewriteScalarExtractOfTransferRead>(patterns.getContext(),
                                                   benefit, allowMultipleUses);
  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
    PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), targetVectorBitwidth, benefit);
  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
}