//===- FoldMemRefAliasOps.cpp - Fold memref alias ops -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to memref aliasing ops
// (subview, expand_shape, collapse_shape) into loading/storing from/to the
// original memref.
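//
// For example (an illustrative sketch; the value names are made up and the
// exact result types/offsets depend on the input IR):
//
//   %0 = memref.subview %arg0[4, 8] [4, 4] [1, 1]
//          : memref<12x42xf32> to memref<4x4xf32, strided<[42, 1], offset: 176>>
//   %1 = memref.load %0[%i, %j] : memref<4x4xf32, strided<[42, 1], offset: 176>>
//
// is rewritten into a load from the original memref at rebased indices:
//
//   %1 = memref.load %arg0[%i + 4, %j + 8] : memref<12x42xf32>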
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "fold-memref-alias-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Given the 'indices' of a load/store operation where the memref is a result
/// of an expand_shape op, returns the indices w.r.t. the source memref of the
/// expand_shape op. For example:
///
/// %0 = ... : memref<12x42xf32>
/// %1 = memref.expand_shape %0 [[0, 1], [2]]
///    : memref<12x42xf32> into memref<2x6x42xf32>
/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32>
///
/// could be folded into
///
/// %2 = load %0[6 * %i1 + %i2, %i3] :
///          memref<12x42xf32>
static LogicalResult resolveSourceIndicesExpandShape(
    Location loc, PatternRewriter &rewriter,
    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
  SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();

  // Traverse all reassociation groups to determine the appropriate indices
  // corresponding to each one of them post op folding.
  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
    assert(!group.empty() && "association indices groups cannot be empty");
    int64_t groupSize = group.size();
    if (groupSize == 1) {
      sourceIndices.push_back(indices[group[0]]);
      continue;
    }
    SmallVector<OpFoldResult> groupBasis =
        llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
    SmallVector<Value> groupIndices =
        llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
    Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
        loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
    sourceIndices.push_back(collapsedIndex);
  }
  return success();
}

/// Given the 'indices' of a load/store operation where the memref is a result
/// of a collapse_shape op, returns the indices w.r.t. the source memref of
/// the collapse_shape op. For example:
///
/// %0 = ... : memref<2x6x42xf32>
/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
///    : memref<2x6x42xf32> into memref<12x42xf32>
/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
///
/// could be folded into
///
/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
///          memref<2x6x42xf32>
static LogicalResult
resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
                                  memref::CollapseShapeOp collapseShapeOp,
                                  ValueRange indices,
                                  SmallVectorImpl<Value> &sourceIndices) {
  // Note: because collapse_shape requires a strided memref as its source, we
  // can always extract its strided metadata here.
  auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
      loc, collapseShapeOp.getSrc());
  SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
  for (auto [index, group] :
       llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
    assert(!group.empty() && "association indices groups cannot be empty");
    int64_t groupSize = group.size();

    if (groupSize == 1) {
      sourceIndices.push_back(index);
      continue;
    }

    SmallVector<OpFoldResult> basis =
        llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
    auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
        loc, index, basis, /*hasOuterBound=*/true);
    llvm::append_range(sourceIndices, delinearize.getResults());
  }
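  // A collapse into a zero-rank memref has no reassociation groups at all; in
  // that case every dimension of the source is addressed with a constant-zero
  // index.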
  if (collapseShapeOp.getReassociationIndices().empty()) {
    auto zeroAffineMap = rewriter.getConstantAffineMap(0);
    int64_t srcRank =
        cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
    for (int64_t i = 0; i < srcRank; i++) {
      sourceIndices.push_back(
          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
    }
  }
  return success();
}

/// Helpers to access the memref operand for each op.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.getMemref();
}

static Value getMemRefOperand(vector::TransferReadOp op) {
  return op.getBase();
}

static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::TransferWriteOp op) {
  return op.getBase();
}

static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
  return op.getDstMemref();
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Merges subview operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges subview operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Folds subview(subview(x)) to a single subview(x).
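///
/// For example (an illustrative sketch with unit strides; result types are
/// elided):
///
///   %1 = memref.subview %0[2, 2] [8, 8] [1, 1] : memref<16x16xf32> to ...
///   %2 = memref.subview %1[3, 3] [4, 4] [1, 1] : ... to ...
///
/// becomes a single subview with the offsets composed:
///
///   %2 = memref.subview %0[5, 5] [4, 4] [1, 1] : memref<16x16xf32> to ...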
class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subView,
                                PatternRewriter &rewriter) const override {
    auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
    if (!srcSubView)
      return failure();

    // TODO: relax unit stride assumption.
    if (!subView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(subView, "requires unit strides");
    }
    if (!srcSubView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
    }

    // Resolve sizes according to dropped dims.
    SmallVector<OpFoldResult> resolvedSizes;
    llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
    affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
                                        subView.getMixedSizes(), srcDroppedDims,
                                        resolvedSizes);

    // Resolve offsets according to source offsets and strides.
    SmallVector<Value> resolvedOffsets;
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
        srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
        resolvedOffsets);

    // Replace original op.
    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subView, subView.getType(), srcSubView.getSource(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        srcSubView.getMixedStrides());

    return success();
  }
};

/// Folds subviews on the source and destination memrefs of an
/// nvgpu.device_async_copy into the copy itself.
class NVGPUAsyncCopyOpSubViewOpFolder final
    : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
public:
  using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};
} // namespace

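/// Applies 'affineMap' to 'indices' by materializing one affine.apply (or
/// constant) per map result, yielding the expanded "actual" access indices
/// that the folding patterns below operate on.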
static SmallVector<Value>
calculateExpandedAccessIndices(AffineMap affineMap,
                               const SmallVector<Value> &indices, Location loc,
                               PatternRewriter &rewriter) {
  SmallVector<OpFoldResult> indicesOfr(llvm::to_vector(
      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; })));
  SmallVector<Value> expandedIndices;
  for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) {
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc, affineMap.getSubMap({i}), indicesOfr);
    expandedIndices.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
  }
  return expandedIndices;
}

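/// Preconditions for folding a subview into a vector transfer op: the
/// transfer must not have any out-of-bounds dimensions and the subview must
/// have unit strides (otherwise the folded indices would need rescaling).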
template <typename XferOp>
static LogicalResult
preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
                               memref::SubViewOp subviewOp) {
  static_assert(llvm::is_one_of<XferOp, vector::TransferReadOp,
                                vector::TransferWriteOp>::value,
                "must be a vector transfer op");
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (!subviewOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride subview, need to track strides in folded memref");
  }
  return success();
}

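/// Default overload: ops other than vector transfers have no additional
/// preconditions for folding a subview.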
static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                Operation *op,
                                                memref::SubViewOp subviewOp) {
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferReadOp readOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferWriteOp writeOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
}

template <typename OpTy>
LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(loadOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the affine map to the operands to get
  // the "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Case([&](vector::TransferReadOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getPadding(), op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Case([&](nvgpu::LdMatrixOp op) {
        rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getTranspose(), op.getNumTiles());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the affine map to the operands to get
  // the "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  // memref.load and affine.load guarantee that indices start inbounds,
  // while the vector operations don't. This affects whether our
  // linearization is `disjoint`.
  if (failed(resolveSourceIndicesExpandShape(
          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
          isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(loadOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the affine map to the operands to get
  // the "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesCollapseShape(
          loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(storeOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the affine map to the operands to get
  // the "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::TransferWriteOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, subViewOp.getSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
            op, op.getSrc(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the affine map to the operands to get
  // the "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  // memref.store and affine.store guarantee that indices start inbounds,
  // while the vector operations don't. This affects whether our
  // linearization is `disjoint`.
  if (failed(resolveSourceIndicesExpandShape(
          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
          isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(storeOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the affine map to the operands to get
  // the "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesCollapseShape(
          storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
    nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {

  LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");

  auto srcSubViewOp =
      copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
  auto dstSubViewOp =
      copyOp.getDst().template getDefiningOp<memref::SubViewOp>();

  if (!(srcSubViewOp || dstSubViewOp))
    return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
                                               "source or destination");

  // If the source is a subview, we need to resolve the indices.
  SmallVector<Value> srcindices(copyOp.getSrcIndices().begin(),
                                copyOp.getSrcIndices().end());
  SmallVector<Value> foldedSrcIndices(srcindices);

  if (srcSubViewOp) {
    LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
        srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
        srcindices, foldedSrcIndices);
  }

  // If the destination is a subview, we need to resolve the indices.
  SmallVector<Value> dstindices(copyOp.getDstIndices().begin(),
                                copyOp.getDstIndices().end());
  SmallVector<Value> foldedDstIndices(dstindices);

  if (dstSubViewOp) {
    LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
        dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
        dstindices, foldedDstIndices);
  }

  // Replace the copy op with a new copy that operates on the sources of the
  // subviews (when present) using the folded indices.
  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
      copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
      (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
      foldedDstIndices,
      (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
      foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
      copyOp.getBypassL1Attr());

  return success();
}

void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
               LoadOpOfSubViewOpFolder<memref::LoadOp>,
               LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
               LoadOpOfSubViewOpFolder<vector::LoadOp>,
               LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
               LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
               LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
               StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
               StoreOpOfSubViewOpFolder<memref::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
               StoreOpOfSubViewOpFolder<vector::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
               StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
               LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
               StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
               LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
               StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
               SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

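/// Pass that greedily applies the alias-folding patterns above to the
/// operation it is run on.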
struct FoldMemRefAliasOpsPass final
    : public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldMemRefAliasOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldMemRefAliasOpPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
