//===- FoldMemRefAliasOps.cpp - Fold memref alias ops -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to subview ops into
// loading/storing from/to the original memref.
//
//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Affine/IR/AffineOps.h"
15#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
16#include "mlir/Dialect/Arith/Utils/Utils.h"
17#include "mlir/Dialect/GPU/IR/GPUDialect.h"
18#include "mlir/Dialect/MemRef/IR/MemRef.h"
19#include "mlir/Dialect/MemRef/Transforms/Passes.h"
20#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
21#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
22#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
23#include "mlir/Dialect/Vector/IR/VectorOps.h"
24#include "mlir/IR/AffineMap.h"
25#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26#include "llvm/ADT/STLExtras.h"
27#include "llvm/ADT/SmallBitVector.h"
28#include "llvm/ADT/TypeSwitch.h"
29#include "llvm/Support/Debug.h"
30
31#define DEBUG_TYPE "fold-memref-alias-ops"
32#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
33
34namespace mlir {
35namespace memref {
36#define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS
37#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
38} // namespace memref
39} // namespace mlir
40
41using namespace mlir;
42
//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//
46
/// Given the 'indices' of a load/store operation where the memref is a result
/// of an expand_shape op, returns the indices w.r.t. the source memref of the
/// expand_shape op. For example
///
/// %0 = ... : memref<12x42xf32>
/// %1 = memref.expand_shape %0 [[0, 1], [2]]
///          : memref<12x42xf32> into memref<2x6x42xf32>
/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32>
///
/// could be folded into
///
/// %2 = load %0[6 * i1 + i2, %i3] :
///          memref<12x42xf32>
60static LogicalResult resolveSourceIndicesExpandShape(
61 Location loc, PatternRewriter &rewriter,
62 memref::ExpandShapeOp expandShapeOp, ValueRange indices,
63 SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
64 SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
65
66 // Traverse all reassociation groups to determine the appropriate indices
67 // corresponding to each one of them post op folding.
68 for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
69 assert(!group.empty() && "association indices groups cannot be empty");
70 int64_t groupSize = group.size();
71 if (groupSize == 1) {
72 sourceIndices.push_back(Elt: indices[group[0]]);
73 continue;
74 }
75 SmallVector<OpFoldResult> groupBasis =
76 llvm::map_to_vector(C&: group, F: [&](int64_t d) { return destShape[d]; });
77 SmallVector<Value> groupIndices =
78 llvm::map_to_vector(C&: group, F: [&](int64_t d) { return indices[d]; });
79 Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
80 location: loc, args&: groupIndices, args&: groupBasis, /*disjoint=*/args&: startsInbounds);
81 sourceIndices.push_back(Elt: collapsedIndex);
82 }
83 return success();
84}
85
/// Given the 'indices' of a load/store operation where the memref is a result
/// of a collapse_shape op, returns the indices w.r.t. the source memref of
/// the collapse_shape op. For example
///
/// %0 = ... : memref<2x6x42xf32>
/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
///          : memref<2x6x42xf32> into memref<12x42xf32>
/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
///
/// could be folded into
///
/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
///          memref<2x6x42xf32>
99static LogicalResult
100resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
101 memref::CollapseShapeOp collapseShapeOp,
102 ValueRange indices,
103 SmallVectorImpl<Value> &sourceIndices) {
104 // Note: collapse_shape requires a strided memref, we can do this.
105 auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
106 location: loc, args: collapseShapeOp.getSrc());
107 SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
108 for (auto [index, group] :
109 llvm::zip(t&: indices, u: collapseShapeOp.getReassociationIndices())) {
110 assert(!group.empty() && "association indices groups cannot be empty");
111 int64_t groupSize = group.size();
112
113 if (groupSize == 1) {
114 sourceIndices.push_back(Elt: index);
115 continue;
116 }
117
118 SmallVector<OpFoldResult> basis =
119 llvm::map_to_vector(C&: group, F: [&](int64_t d) { return sourceSizes[d]; });
120 auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
121 location: loc, args&: index, args&: basis, /*hasOuterBound=*/args: true);
122 llvm::append_range(C&: sourceIndices, R: delinearize.getResults());
123 }
124 if (collapseShapeOp.getReassociationIndices().empty()) {
125 auto zeroAffineMap = rewriter.getConstantAffineMap(val: 0);
126 int64_t srcRank =
127 cast<MemRefType>(Val: collapseShapeOp.getViewSource().getType()).getRank();
128 OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
129 b&: rewriter, loc, map: zeroAffineMap, operands: ArrayRef<OpFoldResult>{});
130 for (int64_t i = 0; i < srcRank; i++) {
131 sourceIndices.push_back(
132 Elt: getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr));
133 }
134 }
135 return success();
136}
137
138/// Helpers to access the memref operand for each op.
139template <typename LoadOrStoreOpTy>
140static Value getMemRefOperand(LoadOrStoreOpTy op) {
141 return op.getMemref();
142}
143
144static Value getMemRefOperand(vector::TransferReadOp op) {
145 return op.getBase();
146}
147
148static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
149 return op.getSrcMemref();
150}
151
152static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }
153
154static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); }
155
156static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }
157
158static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); }
159
160static Value getMemRefOperand(vector::TransferWriteOp op) {
161 return op.getBase();
162}
163
164static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
165 return op.getSrcMemref();
166}
167
168static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
169 return op.getDstMemref();
170}
171
//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//
175
176namespace {
177/// Merges subview operation with load/transferRead operation.
178template <typename OpTy>
179class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
180public:
181 using OpRewritePattern<OpTy>::OpRewritePattern;
182
183 LogicalResult matchAndRewrite(OpTy loadOp,
184 PatternRewriter &rewriter) const override;
185};
186
187/// Merges expand_shape operation with load/transferRead operation.
188template <typename OpTy>
189class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
190public:
191 using OpRewritePattern<OpTy>::OpRewritePattern;
192
193 LogicalResult matchAndRewrite(OpTy loadOp,
194 PatternRewriter &rewriter) const override;
195};
196
197/// Merges collapse_shape operation with load/transferRead operation.
198template <typename OpTy>
199class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
200public:
201 using OpRewritePattern<OpTy>::OpRewritePattern;
202
203 LogicalResult matchAndRewrite(OpTy loadOp,
204 PatternRewriter &rewriter) const override;
205};
206
207/// Merges subview operation with store/transferWriteOp operation.
208template <typename OpTy>
209class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
210public:
211 using OpRewritePattern<OpTy>::OpRewritePattern;
212
213 LogicalResult matchAndRewrite(OpTy storeOp,
214 PatternRewriter &rewriter) const override;
215};
216
217/// Merges expand_shape operation with store/transferWriteOp operation.
218template <typename OpTy>
219class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
220public:
221 using OpRewritePattern<OpTy>::OpRewritePattern;
222
223 LogicalResult matchAndRewrite(OpTy storeOp,
224 PatternRewriter &rewriter) const override;
225};
226
227/// Merges collapse_shape operation with store/transferWriteOp operation.
228template <typename OpTy>
229class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
230public:
231 using OpRewritePattern<OpTy>::OpRewritePattern;
232
233 LogicalResult matchAndRewrite(OpTy storeOp,
234 PatternRewriter &rewriter) const override;
235};
236
237/// Folds subview(subview(x)) to a single subview(x).
238class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
239public:
240 using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;
241
242 LogicalResult matchAndRewrite(memref::SubViewOp subView,
243 PatternRewriter &rewriter) const override {
244 auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
245 if (!srcSubView)
246 return failure();
247
248 // TODO: relax unit stride assumption.
249 if (!subView.hasUnitStride()) {
250 return rewriter.notifyMatchFailure(arg&: subView, msg: "requires unit strides");
251 }
252 if (!srcSubView.hasUnitStride()) {
253 return rewriter.notifyMatchFailure(arg&: srcSubView, msg: "requires unit strides");
254 }
255
256 // Resolve sizes according to dropped dims.
257 SmallVector<OpFoldResult> resolvedSizes;
258 llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
259 affine::resolveSizesIntoOpWithSizes(sourceSizes: srcSubView.getMixedSizes(),
260 destSizes: subView.getMixedSizes(), rankReducedSourceDims: srcDroppedDims,
261 resolvedSizes);
262
263 // Resolve offsets according to source offsets and strides.
264 SmallVector<Value> resolvedOffsets;
265 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
266 rewriter, loc: subView.getLoc(), mixedSourceOffsets: srcSubView.getMixedOffsets(),
267 mixedSourceStrides: srcSubView.getMixedStrides(), rankReducedDims: srcDroppedDims, consumerIndices: subView.getMixedOffsets(),
268 resolvedIndices&: resolvedOffsets);
269
270 // Replace original op.
271 rewriter.replaceOpWithNewOp<memref::SubViewOp>(
272 op: subView, args: subView.getType(), args: srcSubView.getSource(),
273 args: getAsOpFoldResult(values: resolvedOffsets), args&: resolvedSizes,
274 args: srcSubView.getMixedStrides());
275
276 return success();
277 }
278};
279
280/// Folds nvgpu.device_async_copy subviews into the copy itself. This pattern
281/// is folds subview on src and dst memref of the copy.
282class NVGPUAsyncCopyOpSubViewOpFolder final
283 : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
284public:
285 using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;
286
287 LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
288 PatternRewriter &rewriter) const override;
289};
290} // namespace
291
292static SmallVector<Value>
293calculateExpandedAccessIndices(AffineMap affineMap,
294 const SmallVector<Value> &indices, Location loc,
295 PatternRewriter &rewriter) {
296 SmallVector<OpFoldResult> indicesOfr(llvm::to_vector(
297 Range: llvm::map_range(C: indices, F: [](Value v) -> OpFoldResult { return v; })));
298 SmallVector<Value> expandedIndices;
299 for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) {
300 OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
301 b&: rewriter, loc, map: affineMap.getSubMap(resultPos: {i}), operands: indicesOfr);
302 expandedIndices.push_back(
303 Elt: getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr));
304 }
305 return expandedIndices;
306}
307
308template <typename XferOp>
309static LogicalResult
310preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
311 memref::SubViewOp subviewOp) {
312 static_assert(
313 !llvm::is_one_of<vector::TransferReadOp, vector::TransferWriteOp>::value,
314 "must be a vector transfer op");
315 if (xferOp.hasOutOfBoundsDim())
316 return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
317 if (!subviewOp.hasUnitStride()) {
318 return rewriter.notifyMatchFailure(
319 xferOp, "non-1 stride subview, need to track strides in folded memref");
320 }
321 return success();
322}
323
324static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
325 Operation *op,
326 memref::SubViewOp subviewOp) {
327 return success();
328}
329
330static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
331 vector::TransferReadOp readOp,
332 memref::SubViewOp subviewOp) {
333 return preconditionsFoldSubViewOpImpl(rewriter, xferOp: readOp, subviewOp);
334}
335
336static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
337 vector::TransferWriteOp writeOp,
338 memref::SubViewOp subviewOp) {
339 return preconditionsFoldSubViewOpImpl(rewriter, xferOp: writeOp, subviewOp);
340}
341
342template <typename OpTy>
343LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
344 OpTy loadOp, PatternRewriter &rewriter) const {
345 auto subViewOp =
346 getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
347
348 if (!subViewOp)
349 return rewriter.notifyMatchFailure(loadOp, "not a subview producer");
350
351 LogicalResult preconditionResult =
352 preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
353 if (failed(Result: preconditionResult))
354 return preconditionResult;
355
356 SmallVector<Value> indices(loadOp.getIndices().begin(),
357 loadOp.getIndices().end());
358 // For affine ops, we need to apply the map to get the operands to get the
359 // "actual" indices.
360 if (auto affineLoadOp =
361 dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
362 AffineMap affineMap = affineLoadOp.getAffineMap();
363 auto expandedIndices = calculateExpandedAccessIndices(
364 affineMap, indices, loadOp.getLoc(), rewriter);
365 indices.assign(expandedIndices.begin(), expandedIndices.end());
366 }
367 SmallVector<Value> sourceIndices;
368 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
369 rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
370 subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
371 sourceIndices);
372
373 llvm::TypeSwitch<Operation *, void>(loadOp)
374 .Case([&](affine::AffineLoadOp op) {
375 rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
376 loadOp, subViewOp.getSource(), sourceIndices);
377 })
378 .Case([&](memref::LoadOp op) {
379 rewriter.replaceOpWithNewOp<memref::LoadOp>(
380 loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
381 })
382 .Case([&](vector::LoadOp op) {
383 rewriter.replaceOpWithNewOp<vector::LoadOp>(
384 op, op.getType(), subViewOp.getSource(), sourceIndices);
385 })
386 .Case([&](vector::MaskedLoadOp op) {
387 rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
388 op, op.getType(), subViewOp.getSource(), sourceIndices,
389 op.getMask(), op.getPassThru());
390 })
391 .Case([&](vector::TransferReadOp op) {
392 rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
393 op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
394 AffineMapAttr::get(value: expandDimsToRank(
395 op.getPermutationMap(), subViewOp.getSourceType().getRank(),
396 subViewOp.getDroppedDims())),
397 op.getPadding(), op.getMask(), op.getInBoundsAttr());
398 })
399 .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
400 rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
401 op, op.getType(), subViewOp.getSource(), sourceIndices,
402 op.getLeadDimension(), op.getTransposeAttr());
403 })
404 .Case([&](nvgpu::LdMatrixOp op) {
405 rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
406 op, op.getType(), subViewOp.getSource(), sourceIndices,
407 op.getTranspose(), op.getNumTiles());
408 })
409 .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
410 return success();
411}
412
413template <typename OpTy>
414LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
415 OpTy loadOp, PatternRewriter &rewriter) const {
416 auto expandShapeOp =
417 getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();
418
419 if (!expandShapeOp)
420 return failure();
421
422 SmallVector<Value> indices(loadOp.getIndices().begin(),
423 loadOp.getIndices().end());
424 // For affine ops, we need to apply the map to get the operands to get the
425 // "actual" indices.
426 if (auto affineLoadOp =
427 dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
428 AffineMap affineMap = affineLoadOp.getAffineMap();
429 auto expandedIndices = calculateExpandedAccessIndices(
430 affineMap, indices, loadOp.getLoc(), rewriter);
431 indices.assign(expandedIndices.begin(), expandedIndices.end());
432 }
433 SmallVector<Value> sourceIndices;
434 // memref.load and affine.load guarantee that indexes start inbounds
435 // while the vector operations don't. This impacts if our linearization
436 // is `disjoint`
437 if (failed(resolveSourceIndicesExpandShape(
438 loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
439 isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
440 return failure();
441 llvm::TypeSwitch<Operation *, void>(loadOp)
442 .Case([&](affine::AffineLoadOp op) {
443 rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
444 loadOp, expandShapeOp.getViewSource(), sourceIndices);
445 })
446 .Case([&](memref::LoadOp op) {
447 rewriter.replaceOpWithNewOp<memref::LoadOp>(
448 loadOp, expandShapeOp.getViewSource(), sourceIndices,
449 op.getNontemporal());
450 })
451 .Case([&](vector::LoadOp op) {
452 rewriter.replaceOpWithNewOp<vector::LoadOp>(
453 op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
454 op.getNontemporal());
455 })
456 .Case([&](vector::MaskedLoadOp op) {
457 rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
458 op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
459 op.getMask(), op.getPassThru());
460 })
461 .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
462 return success();
463}
464
465template <typename OpTy>
466LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
467 OpTy loadOp, PatternRewriter &rewriter) const {
468 auto collapseShapeOp = getMemRefOperand(loadOp)
469 .template getDefiningOp<memref::CollapseShapeOp>();
470
471 if (!collapseShapeOp)
472 return failure();
473
474 SmallVector<Value> indices(loadOp.getIndices().begin(),
475 loadOp.getIndices().end());
476 // For affine ops, we need to apply the map to get the operands to get the
477 // "actual" indices.
478 if (auto affineLoadOp =
479 dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
480 AffineMap affineMap = affineLoadOp.getAffineMap();
481 auto expandedIndices = calculateExpandedAccessIndices(
482 affineMap, indices, loadOp.getLoc(), rewriter);
483 indices.assign(expandedIndices.begin(), expandedIndices.end());
484 }
485 SmallVector<Value> sourceIndices;
486 if (failed(resolveSourceIndicesCollapseShape(
487 loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
488 return failure();
489 llvm::TypeSwitch<Operation *, void>(loadOp)
490 .Case([&](affine::AffineLoadOp op) {
491 rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
492 loadOp, collapseShapeOp.getViewSource(), sourceIndices);
493 })
494 .Case([&](memref::LoadOp op) {
495 rewriter.replaceOpWithNewOp<memref::LoadOp>(
496 loadOp, collapseShapeOp.getViewSource(), sourceIndices,
497 op.getNontemporal());
498 })
499 .Case([&](vector::LoadOp op) {
500 rewriter.replaceOpWithNewOp<vector::LoadOp>(
501 op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
502 op.getNontemporal());
503 })
504 .Case([&](vector::MaskedLoadOp op) {
505 rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
506 op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
507 op.getMask(), op.getPassThru());
508 })
509 .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
510 return success();
511}
512
513template <typename OpTy>
514LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
515 OpTy storeOp, PatternRewriter &rewriter) const {
516 auto subViewOp =
517 getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
518
519 if (!subViewOp)
520 return rewriter.notifyMatchFailure(storeOp, "not a subview producer");
521
522 LogicalResult preconditionResult =
523 preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
524 if (failed(Result: preconditionResult))
525 return preconditionResult;
526
527 SmallVector<Value> indices(storeOp.getIndices().begin(),
528 storeOp.getIndices().end());
529 // For affine ops, we need to apply the map to get the operands to get the
530 // "actual" indices.
531 if (auto affineStoreOp =
532 dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
533 AffineMap affineMap = affineStoreOp.getAffineMap();
534 auto expandedIndices = calculateExpandedAccessIndices(
535 affineMap, indices, storeOp.getLoc(), rewriter);
536 indices.assign(expandedIndices.begin(), expandedIndices.end());
537 }
538 SmallVector<Value> sourceIndices;
539 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
540 rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
541 subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
542 sourceIndices);
543
544 llvm::TypeSwitch<Operation *, void>(storeOp)
545 .Case([&](affine::AffineStoreOp op) {
546 rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
547 op, op.getValue(), subViewOp.getSource(), sourceIndices);
548 })
549 .Case([&](memref::StoreOp op) {
550 rewriter.replaceOpWithNewOp<memref::StoreOp>(
551 op, op.getValue(), subViewOp.getSource(), sourceIndices,
552 op.getNontemporal());
553 })
554 .Case([&](vector::TransferWriteOp op) {
555 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
556 op, op.getValue(), subViewOp.getSource(), sourceIndices,
557 AffineMapAttr::get(value: expandDimsToRank(
558 op.getPermutationMap(), subViewOp.getSourceType().getRank(),
559 subViewOp.getDroppedDims())),
560 op.getMask(), op.getInBoundsAttr());
561 })
562 .Case([&](vector::StoreOp op) {
563 rewriter.replaceOpWithNewOp<vector::StoreOp>(
564 op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
565 })
566 .Case([&](vector::MaskedStoreOp op) {
567 rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
568 op, subViewOp.getSource(), sourceIndices, op.getMask(),
569 op.getValueToStore());
570 })
571 .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
572 rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
573 op, op.getSrc(), subViewOp.getSource(), sourceIndices,
574 op.getLeadDimension(), op.getTransposeAttr());
575 })
576 .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
577 return success();
578}
579
580template <typename OpTy>
581LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
582 OpTy storeOp, PatternRewriter &rewriter) const {
583 auto expandShapeOp =
584 getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();
585
586 if (!expandShapeOp)
587 return failure();
588
589 SmallVector<Value> indices(storeOp.getIndices().begin(),
590 storeOp.getIndices().end());
591 // For affine ops, we need to apply the map to get the operands to get the
592 // "actual" indices.
593 if (auto affineStoreOp =
594 dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
595 AffineMap affineMap = affineStoreOp.getAffineMap();
596 auto expandedIndices = calculateExpandedAccessIndices(
597 affineMap, indices, storeOp.getLoc(), rewriter);
598 indices.assign(expandedIndices.begin(), expandedIndices.end());
599 }
600 SmallVector<Value> sourceIndices;
601 // memref.store and affine.store guarantee that indexes start inbounds
602 // while the vector operations don't. This impacts if our linearization
603 // is `disjoint`
604 if (failed(resolveSourceIndicesExpandShape(
605 storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
606 isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
607 return failure();
608 llvm::TypeSwitch<Operation *, void>(storeOp)
609 .Case([&](affine::AffineStoreOp op) {
610 rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
611 storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
612 sourceIndices);
613 })
614 .Case([&](memref::StoreOp op) {
615 rewriter.replaceOpWithNewOp<memref::StoreOp>(
616 storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
617 sourceIndices, op.getNontemporal());
618 })
619 .Case([&](vector::StoreOp op) {
620 rewriter.replaceOpWithNewOp<vector::StoreOp>(
621 op, op.getValueToStore(), expandShapeOp.getViewSource(),
622 sourceIndices, op.getNontemporal());
623 })
624 .Case([&](vector::MaskedStoreOp op) {
625 rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
626 op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
627 op.getValueToStore());
628 })
629 .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
630 return success();
631}
632
633template <typename OpTy>
634LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
635 OpTy storeOp, PatternRewriter &rewriter) const {
636 auto collapseShapeOp = getMemRefOperand(storeOp)
637 .template getDefiningOp<memref::CollapseShapeOp>();
638
639 if (!collapseShapeOp)
640 return failure();
641
642 SmallVector<Value> indices(storeOp.getIndices().begin(),
643 storeOp.getIndices().end());
644 // For affine ops, we need to apply the map to get the operands to get the
645 // "actual" indices.
646 if (auto affineStoreOp =
647 dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
648 AffineMap affineMap = affineStoreOp.getAffineMap();
649 auto expandedIndices = calculateExpandedAccessIndices(
650 affineMap, indices, storeOp.getLoc(), rewriter);
651 indices.assign(expandedIndices.begin(), expandedIndices.end());
652 }
653 SmallVector<Value> sourceIndices;
654 if (failed(resolveSourceIndicesCollapseShape(
655 storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
656 return failure();
657 llvm::TypeSwitch<Operation *, void>(storeOp)
658 .Case([&](affine::AffineStoreOp op) {
659 rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
660 storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
661 sourceIndices);
662 })
663 .Case([&](memref::StoreOp op) {
664 rewriter.replaceOpWithNewOp<memref::StoreOp>(
665 storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
666 sourceIndices, op.getNontemporal());
667 })
668 .Case([&](vector::StoreOp op) {
669 rewriter.replaceOpWithNewOp<vector::StoreOp>(
670 op, op.getValueToStore(), collapseShapeOp.getViewSource(),
671 sourceIndices, op.getNontemporal());
672 })
673 .Case([&](vector::MaskedStoreOp op) {
674 rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
675 op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
676 op.getValueToStore());
677 })
678 .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
679 return success();
680}
681
682LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
683 nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {
684
685 LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");
686
687 auto srcSubViewOp =
688 copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
689 auto dstSubViewOp =
690 copyOp.getDst().template getDefiningOp<memref::SubViewOp>();
691
692 if (!(srcSubViewOp || dstSubViewOp))
693 return rewriter.notifyMatchFailure(arg&: copyOp, msg: "does not use subview ops for "
694 "source or destination");
695
696 // If the source is a subview, we need to resolve the indices.
697 SmallVector<Value> srcindices(copyOp.getSrcIndices().begin(),
698 copyOp.getSrcIndices().end());
699 SmallVector<Value> foldedSrcIndices(srcindices);
700
701 if (srcSubViewOp) {
702 LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
703 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
704 rewriter, loc: copyOp.getLoc(), mixedSourceOffsets: srcSubViewOp.getMixedOffsets(),
705 mixedSourceStrides: srcSubViewOp.getMixedStrides(), rankReducedDims: srcSubViewOp.getDroppedDims(),
706 consumerIndices: srcindices, resolvedIndices&: foldedSrcIndices);
707 }
708
709 // If the destination is a subview, we need to resolve the indices.
710 SmallVector<Value> dstindices(copyOp.getDstIndices().begin(),
711 copyOp.getDstIndices().end());
712 SmallVector<Value> foldedDstIndices(dstindices);
713
714 if (dstSubViewOp) {
715 LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
716 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
717 rewriter, loc: copyOp.getLoc(), mixedSourceOffsets: dstSubViewOp.getMixedOffsets(),
718 mixedSourceStrides: dstSubViewOp.getMixedStrides(), rankReducedDims: dstSubViewOp.getDroppedDims(),
719 consumerIndices: dstindices, resolvedIndices&: foldedDstIndices);
720 }
721
722 // Replace the copy op with a new copy op that uses the source and destination
723 // of the subview.
724 rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
725 op: copyOp, args: nvgpu::DeviceAsyncTokenType::get(ctx: copyOp.getContext()),
726 args: (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
727 args&: foldedDstIndices,
728 args: (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
729 args&: foldedSrcIndices, args: copyOp.getDstElements(), args: copyOp.getSrcElements(),
730 args: copyOp.getBypassL1Attr());
731
732 return success();
733}
734
735void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
736 patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
737 LoadOpOfSubViewOpFolder<memref::LoadOp>,
738 LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
739 LoadOpOfSubViewOpFolder<vector::LoadOp>,
740 LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
741 LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
742 LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
743 StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
744 StoreOpOfSubViewOpFolder<memref::StoreOp>,
745 StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
746 StoreOpOfSubViewOpFolder<vector::StoreOp>,
747 StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
748 StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
749 LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
750 LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
751 LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
752 LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
753 StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
754 StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
755 StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
756 StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
757 LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
758 LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
759 LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
760 LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
761 StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
762 StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
763 StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
764 StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
765 SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
766 arg: patterns.getContext());
767}
768
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//
772
773namespace {
774
775struct FoldMemRefAliasOpsPass final
776 : public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> {
777 void runOnOperation() override;
778};
779
780} // namespace
781
782void FoldMemRefAliasOpsPass::runOnOperation() {
783 RewritePatternSet patterns(&getContext());
784 memref::populateFoldMemRefAliasOpPatterns(patterns);
785 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
786}
787

source code of mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp