//===- FoldMemRefAliasOps.cpp - Fold memref alias ops -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to subview ops into
// loading/storing from/to the original memref.
//
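// For instance (illustrative IR; value names and types are chosen for
// exposition only), a load through a subview such as
//
//   %0 = memref.subview %arg0[%o0, %o1] [4, 4] [1, 1]
//          : memref<12x32xf32> to memref<4x4xf32, strided<[32, 1], offset: ?>>
//   %1 = memref.load %0[%i, %j]
//
// is rewritten to load directly from the original memref,
//
//   %1 = memref.load %arg0[%o0 + %i, %o1 + %j] : memref<12x32xf32>
//
// with the index arithmetic materialized through affine.apply ops.
//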
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "fold-memref-alias-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FOLDMEMREFALIASOPS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Given the 'indices' of a load/store operation where the memref is a result
/// of an expand_shape op, returns the indices w.r.t. the source memref of the
/// expand_shape op. For example:
///
/// %0 = ... : memref<12x42xf32>
/// %1 = memref.expand_shape %0 [[0, 1], [2]]
///    : memref<12x42xf32> into memref<2x6x42xf32>
/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32>
///
/// could be folded into
///
/// %2 = load %0[6 * %i1 + %i2, %i3] :
///          memref<12x42xf32>
static LogicalResult
resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
                                memref::ExpandShapeOp expandShapeOp,
                                ValueRange indices,
                                SmallVectorImpl<Value> &sourceIndices) {
  // The implementation below uses computeSuffixProduct, which only supports
  // static (int64_t) sizes. Bail out if the result type has dynamic shapes.
  if (!expandShapeOp.getResultType().hasStaticShape())
    return failure();

  MLIRContext *ctx = rewriter.getContext();
  for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
    assert(!groups.empty() && "association indices groups cannot be empty");
    int64_t groupSize = groups.size();

    // Construct the expression for the index value w.r.t. the expand_shape op
    // source, corresponding to the indices w.r.t. the expand_shape op result.
    SmallVector<int64_t> sizes(groupSize);
    for (int64_t i = 0; i < groupSize; ++i)
      sizes[i] = expandShapeOp.getResultType().getDimSize(groups[i]);
    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
    SmallVector<AffineExpr> dims(groupSize);
    bindDimsList(ctx, MutableArrayRef{dims});
    AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct);

    // Gather the indices of this group; they feed the affine.apply below.
    SmallVector<OpFoldResult> dynamicIndices(groupSize);
    for (int64_t i = 0; i < groupSize; i++)
      dynamicIndices[i] = indices[groups[i]];

    // Creating a maximally folded and composed affine.apply composes better
    // with other transformations without interleaving canonicalization passes.
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc,
        AffineMap::get(/*numDims=*/groupSize,
                       /*numSymbols=*/0, srcIndexExpr),
        dynamicIndices);
    sourceIndices.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
  }
  return success();
}

/// Given the 'indices' of a load/store operation where the memref is a result
/// of a collapse_shape op, returns the indices w.r.t. the source memref of
/// the collapse_shape op. For example:
///
/// %0 = ... : memref<2x6x42xf32>
/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
///    : memref<2x6x42xf32> into memref<12x42xf32>
/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
///
/// could be folded into
///
/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
///          memref<2x6x42xf32>
static LogicalResult
resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
                                  memref::CollapseShapeOp collapseShapeOp,
                                  ValueRange indices,
                                  SmallVectorImpl<Value> &sourceIndices) {
  int64_t cnt = 0;
  SmallVector<Value> tmp(indices.size());
  SmallVector<OpFoldResult> dynamicIndices;
  for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
    assert(!groups.empty() && "association indices groups cannot be empty");
    dynamicIndices.push_back(indices[cnt++]);
    int64_t groupSize = groups.size();

    // Calculate the suffix product of all collapse op source dimension sizes
    // except the most major one of each group.
    // We allow the most major source dimension to be dynamic but enforce all
    // others to be known statically.
    SmallVector<int64_t> sizes(groupSize, 1);
    for (int64_t i = 1; i < groupSize; ++i) {
      sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
      if (sizes[i] == ShapedType::kDynamic)
        return failure();
    }
    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);

    // Derive the index values along all dimensions of the source corresponding
    // to the index w.r.t. the collapse_shape op result.
    auto d0 = rewriter.getAffineDimExpr(0);
    SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);

    // Construct the AffineApplyOp for each delinearizingExpr.
    for (int64_t i = 0; i < groupSize; i++) {
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc,
          AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
                         delinearizingExprs[i]),
          dynamicIndices);
      sourceIndices.push_back(
          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
    }
    dynamicIndices.clear();
  }
  if (collapseShapeOp.getReassociationIndices().empty()) {
    auto zeroAffineMap = rewriter.getConstantAffineMap(0);
    int64_t srcRank =
        cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
    for (int64_t i = 0; i < srcRank; i++) {
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, zeroAffineMap, dynamicIndices);
      sourceIndices.push_back(
          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
    }
  }
  return success();
}

/// Helpers to access the memref operand for each op.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.getMemref();
}

static Value getMemRefOperand(vector::TransferReadOp op) {
  return op.getSource();
}

static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::TransferWriteOp op) {
  return op.getSource();
}

static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
  return op.getDstMemref();
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Merges subview operation with load/transferRead operation.
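/// As a schematic illustration (types and constants here are hypothetical,
/// chosen only for exposition), a load through a rank-reducing subview
///
///   %sv = memref.subview %src[4, %o, 0] [1, 16, 32] [1, 1, 1]
///       : memref<8x64x32xf32> to memref<16x32xf32, strided<[32, 1], offset: ?>>
///   %v  = memref.load %sv[%i, %j]
///
/// is rewritten to load from the original memref, with the dropped dimension
/// restored and the offsets folded in, i.e. from %src[4, %o + %i, %j].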
template <typename OpTy>
class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges subview operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Folds subview(subview(x)) to a single subview(x).
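/// For example (an illustrative case with static offsets and unit strides):
///
///   %a = memref.subview %src[2, 4] [8, 8] [1, 1]
///       : memref<16x16xf32> to memref<8x8xf32, strided<[16, 1], offset: 36>>
///   %b = memref.subview %a[1, 1] [4, 4] [1, 1]
///       : memref<8x8xf32, strided<[16, 1], offset: 36>>
///         to memref<4x4xf32, strided<[16, 1], offset: 53>>
///
/// folds into the single op
///
///   %b = memref.subview %src[3, 5] [4, 4] [1, 1]
///       : memref<16x16xf32> to memref<4x4xf32, strided<[16, 1], offset: 53>>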
class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subView,
                                PatternRewriter &rewriter) const override {
    auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
    if (!srcSubView)
      return failure();

    // TODO: relax unit stride assumption.
    if (!subView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(subView, "requires unit strides");
    }
    if (!srcSubView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
    }

    // Resolve sizes according to dropped dims.
    SmallVector<OpFoldResult> resolvedSizes;
    llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
    affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
                                        subView.getMixedSizes(), srcDroppedDims,
                                        resolvedSizes);

    // Resolve offsets according to source offsets and strides.
    SmallVector<Value> resolvedOffsets;
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
        srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
        resolvedOffsets);

    // Replace original op.
    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subView, subView.getType(), srcSubView.getSource(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        srcSubView.getMixedStrides());

    return success();
  }
};

/// Folds the subviews feeding an nvgpu.device_async_copy into the copy itself.
/// This pattern folds subviews on both the src and dst memrefs of the copy.
class NvgpuAsyncCopyOpSubViewOpFolder final
    : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
public:
  using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};
} // namespace

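/// For accesses that carry an affine map (affine.load/affine.store), applies
/// `affineMap` to `indices` and returns one value per map result: the
/// "expanded" indices that address the memref directly. For example, a map
/// (d0, d1) -> (d0 + 1, d1) with indices (%i, %j) yields (%i + 1, %j),
/// materialized via composed affine.apply ops (folded to constants when
/// possible).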
static SmallVector<Value>
calculateExpandedAccessIndices(AffineMap affineMap,
                               const SmallVector<Value> &indices, Location loc,
                               PatternRewriter &rewriter) {
  SmallVector<OpFoldResult> indicesOfr(llvm::to_vector(
      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; })));
  SmallVector<Value> expandedIndices;
  for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) {
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc, affineMap.getSubMap({i}), indicesOfr);
    expandedIndices.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
  }
  return expandedIndices;
}

template <typename XferOp>
static LogicalResult
preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
                               memref::SubViewOp subviewOp) {
  static_assert(
      llvm::is_one_of<XferOp, vector::TransferReadOp,
                      vector::TransferWriteOp>::value,
      "must be a vector transfer op");
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (!subviewOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride subview, need to track strides in folded memref");
  }
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                Operation *op,
                                                memref::SubViewOp subviewOp) {
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferReadOp readOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferWriteOp writeOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
}

template <typename OpTy>
LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(loadOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Case([&](vector::TransferReadOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getPadding(), op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Case([&](nvgpu::LdMatrixOp op) {
        rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getTranspose(), op.getNumTiles());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesExpandShape(
          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case<affine::AffineLoadOp, memref::LoadOp>([&](auto op) {
        rewriter.replaceOpWithNewOp<decltype(op)>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices);
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(loadOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesCollapseShape(
          loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case<affine::AffineLoadOp, memref::LoadOp>([&](auto op) {
        rewriter.replaceOpWithNewOp<decltype(op)>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices);
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(storeOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::TransferWriteOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, subViewOp.getSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
            op, op.getSrc(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesExpandShape(
          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case<affine::AffineStoreOp, memref::StoreOp>([&](auto op) {
        rewriter.replaceOpWithNewOp<decltype(op)>(storeOp, storeOp.getValue(),
                                                  expandShapeOp.getViewSource(),
                                                  sourceIndices);
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(storeOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesCollapseShape(
          storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case<affine::AffineStoreOp, memref::StoreOp>([&](auto op) {
        rewriter.replaceOpWithNewOp<decltype(op)>(
            storeOp, storeOp.getValue(), collapseShapeOp.getViewSource(),
            sourceIndices);
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

LogicalResult NvgpuAsyncCopyOpSubViewOpFolder::matchAndRewrite(
    nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {

  LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");

  auto srcSubViewOp =
      copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
  auto dstSubViewOp =
      copyOp.getDst().template getDefiningOp<memref::SubViewOp>();

  if (!(srcSubViewOp || dstSubViewOp))
    return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
                                               "source or destination");

  // If the source is a subview, we need to resolve the indices.
  SmallVector<Value> srcindices(copyOp.getSrcIndices().begin(),
                                copyOp.getSrcIndices().end());
  SmallVector<Value> foldedSrcIndices(srcindices);

  if (srcSubViewOp) {
    LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
        srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
        srcindices, foldedSrcIndices);
  }

  // If the destination is a subview, we need to resolve the indices.
  SmallVector<Value> dstindices(copyOp.getDstIndices().begin(),
                                copyOp.getDstIndices().end());
  SmallVector<Value> foldedDstIndices(dstindices);

  if (dstSubViewOp) {
    LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
        dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
        dstindices, foldedDstIndices);
  }

  // Replace the copy op with a new copy op that operates on the sources of the
  // subviews (or on the original memrefs where no subview is present).
  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
      copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
      (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
      foldedDstIndices,
      (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
      foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
      copyOp.getBypassL1Attr());

  return success();
}

void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
               LoadOpOfSubViewOpFolder<memref::LoadOp>,
               LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
               LoadOpOfSubViewOpFolder<vector::LoadOp>,
               LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
               LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
               LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
               StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
               StoreOpOfSubViewOpFolder<memref::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
               StoreOpOfSubViewOpFolder<vector::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
               StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
               LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
               StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
               LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
               StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
               SubViewOfSubViewFolder, NvgpuAsyncCopyOpSubViewOpFolder>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

struct FoldMemRefAliasOpsPass final
    : public memref::impl::FoldMemRefAliasOpsBase<FoldMemRefAliasOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldMemRefAliasOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldMemRefAliasOpPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}

std::unique_ptr<Pass> memref::createFoldMemRefAliasOpsPass() {
  return std::make_unique<FoldMemRefAliasOpsPass>();
}
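
// Usage sketch (assuming the standard tablegen registration for this pass in
// MemRef/Transforms/Passes.td): the pass can be run standalone, e.g. via
//   mlir-opt --fold-memref-alias-ops input.mlir
// or its patterns can be reused inside another pass through
// populateFoldMemRefAliasOpPatterns, as runOnOperation above does.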