//===- FoldMemRefAliasOps.cpp - Fold memref alias ops -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loads/stores from/to memref aliasing ops
// (subview, expand_shape, collapse_shape) into loads/stores from/to the
// original memref.
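//
// For example (illustrative IR), a load through a subview
//
//   %0 = memref.subview %arg0[%o0, %o1] [4, 4] [1, 1]
//          : memref<16x16xf32> to memref<4x4xf32, strided<[16, 1], offset: ?>>
//   %1 = memref.load %0[%i, %j] : memref<4x4xf32, strided<[16, 1], offset: ?>>
//
// is rewritten to load directly from %arg0, with the indices resolved through
// the subview's offsets and strides (materialized with affine.apply ops),
// roughly %arg0[%o0 + %i, %o1 + %j].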
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "fold-memref-alias-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FOLDMEMREFALIASOPS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Given the 'indices' of a load/store operation where the memref is a result
/// of an expand_shape op, returns the indices w.r.t. the source memref of the
/// expand_shape op. For example
///
///   %0 = ... : memref<12x42xf32>
///   %1 = memref.expand_shape %0 [[0, 1], [2]]
///          : memref<12x42xf32> into memref<2x6x42xf32>
///   %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32>
///
/// could be folded into
///
///   %2 = load %0[6 * %i1 + %i2, %i3] :
///          memref<12x42xf32>
static LogicalResult
resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
                                memref::ExpandShapeOp expandShapeOp,
                                ValueRange indices,
                                SmallVectorImpl<Value> &sourceIndices) {
  // The implementation below uses computeSuffixProduct, which only supports
  // int64_t values (i.e., static shapes). Bail out on dynamic shapes.
  if (!expandShapeOp.getResultType().hasStaticShape())
    return failure();

  MLIRContext *ctx = rewriter.getContext();
  for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
    assert(!groups.empty() && "association indices groups cannot be empty");
    int64_t groupSize = groups.size();

    // Construct the expression for the index value w.r.t. the expand_shape op
    // source, corresponding to the indices w.r.t. the expand_shape op result.
    SmallVector<int64_t> sizes(groupSize);
    for (int64_t i = 0; i < groupSize; ++i)
      sizes[i] = expandShapeOp.getResultType().getDimSize(groups[i]);
    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
    SmallVector<AffineExpr> dims(groupSize);
    bindDimsList(ctx, MutableArrayRef{dims});
    AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct);
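    // E.g. for the group {0, 1} of memref<2x6x42xf32> in the example above,
    // sizes is [2, 6], the suffix product is [6, 1], and srcIndexExpr is
    // d0 * 6 + d1.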

    // Apply permutation and create AffineApplyOp.
    SmallVector<OpFoldResult> dynamicIndices(groupSize);
    for (int64_t i = 0; i < groupSize; i++)
      dynamicIndices[i] = indices[groups[i]];

    // Creating maximally folded and composed affine.apply ops composes better
    // with other transformations without interleaving canonicalization passes.
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc,
        AffineMap::get(/*numDims=*/groupSize,
                       /*numSymbols=*/0, srcIndexExpr),
        dynamicIndices);
    sourceIndices.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
  }
  return success();
}

/// Given the 'indices' of a load/store operation where the memref is a result
/// of a collapse_shape op, returns the indices w.r.t. the source memref of
/// the collapse_shape op. For example
///
///   %0 = ... : memref<2x6x42xf32>
///   %1 = memref.collapse_shape %0 [[0, 1], [2]]
///          : memref<2x6x42xf32> into memref<12x42xf32>
///   %2 = load %1[%i1, %i2] : memref<12x42xf32>
///
/// could be folded into
///
///   %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
///          memref<2x6x42xf32>
static LogicalResult
resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
                                  memref::CollapseShapeOp collapseShapeOp,
                                  ValueRange indices,
                                  SmallVectorImpl<Value> &sourceIndices) {
  int64_t cnt = 0;
  SmallVector<Value> tmp(indices.size());
  SmallVector<OpFoldResult> dynamicIndices;
  for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
    assert(!groups.empty() && "association indices groups cannot be empty");
    dynamicIndices.push_back(indices[cnt++]);
    int64_t groupSize = groups.size();

    // Calculate the suffix product of all collapse op source dimension sizes
    // except the most major one of each group.
    // We allow the most major source dimension to be dynamic but enforce all
    // others to be known statically.
    SmallVector<int64_t> sizes(groupSize, 1);
    for (int64_t i = 1; i < groupSize; ++i) {
      sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
      if (sizes[i] == ShapedType::kDynamic)
        return failure();
    }
    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);

    // Derive the index values along all dimensions of the source corresponding
    // to the index w.r.t. the collapse_shape op result.
    auto d0 = rewriter.getAffineDimExpr(0);
    SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);
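    // E.g. for the group {0, 1} of memref<2x6x42xf32> in the example above,
    // sizes is [1, 6] (the most major size is never used as a divisor), the
    // suffix product is [6, 1], and the delinearizing expressions are
    // {d0 floordiv 6, d0 mod 6}.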

    // Construct the AffineApplyOp for each delinearizingExpr.
    for (int64_t i = 0; i < groupSize; i++) {
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc,
          AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
                         delinearizingExprs[i]),
          dynamicIndices);
      sourceIndices.push_back(
          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
    }
    dynamicIndices.clear();
  }
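  // Note: an empty reassociation list (rank-0 collapse result) means the
  // load/store on the result carries no indices at all; in that case every
  // source index is simply the constant 0, handled below.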
  if (collapseShapeOp.getReassociationIndices().empty()) {
    auto zeroAffineMap = rewriter.getConstantAffineMap(0);
    int64_t srcRank =
        cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
    for (int64_t i = 0; i < srcRank; i++) {
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, zeroAffineMap, dynamicIndices);
      sourceIndices.push_back(
          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
    }
  }
  return success();
}

/// Helpers to access the memref operand for each op.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.getMemref();
}

static Value getMemRefOperand(vector::TransferReadOp op) {
  return op.getSource();
}

static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::TransferWriteOp op) {
  return op.getSource();
}

static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
  return op.getDstMemref();
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Merges subview operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges subview operation with store/transferWriteOp operation.
template <typename OpTy>
class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with store/transferWriteOp operation.
template <typename OpTy>
class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with store/transferWriteOp operation.
template <typename OpTy>
class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Folds subview(subview(x)) to a single subview(x).
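/// The matched subview's offsets are composed through its source subview's
/// offsets and strides. For example (illustrative IR, result layouts
/// abbreviated as "..."):
///
///   %1 = memref.subview %0[2] [8] [1] : memref<16xf32> to memref<8xf32, ...>
///   %2 = memref.subview %1[3] [4] [1] : memref<8xf32, ...> to memref<4xf32, ...>
///
/// folds to a single subview whose offset is 2 + 1 * 3 = 5:
///
///   %2 = memref.subview %0[5] [4] [1] : memref<16xf32> to memref<4xf32, ...>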
class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subView,
                                PatternRewriter &rewriter) const override {
    auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
    if (!srcSubView)
      return failure();

    // TODO: relax unit stride assumption.
    if (!subView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(subView, "requires unit strides");
    }
    if (!srcSubView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
    }

    // Resolve sizes according to dropped dims.
    SmallVector<OpFoldResult> resolvedSizes;
    llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
    affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
                                        subView.getMixedSizes(), srcDroppedDims,
                                        resolvedSizes);

    // Resolve offsets according to source offsets and strides.
    SmallVector<Value> resolvedOffsets;
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
        srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
        resolvedOffsets);

    // Replace original op.
    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subView, subView.getType(), srcSubView.getSource(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        srcSubView.getMixedStrides());

    return success();
  }
};

/// Folds nvgpu.device_async_copy subviews into the copy itself. This pattern
/// folds subviews on the src and dst memrefs of the copy.
class NvgpuAsyncCopyOpSubViewOpFolder final
    : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
public:
  using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};
} // namespace

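/// Applies each result expression of `affineMap` to `indices` (via folded
/// affine.apply ops) and returns the resulting values, i.e. the "actual"
/// memref indices accessed by an affine load/store op.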
static SmallVector<Value>
calculateExpandedAccessIndices(AffineMap affineMap,
                               const SmallVector<Value> &indices, Location loc,
                               PatternRewriter &rewriter) {
  SmallVector<OpFoldResult> indicesOfr(llvm::to_vector(
      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; })));
  SmallVector<Value> expandedIndices;
  for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) {
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc, affineMap.getSubMap({i}), indicesOfr);
    expandedIndices.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
  }
  return expandedIndices;
}

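/// Common preconditions for folding a subview into a vector transfer op: the
/// transfer must not have out-of-bounds dimensions, and the subview must have
/// unit strides (non-unit strides would have to be tracked on the folded
/// memref).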
template <typename XferOp>
static LogicalResult
preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
                               memref::SubViewOp subviewOp) {
  static_assert(llvm::is_one_of<XferOp, vector::TransferReadOp,
                                vector::TransferWriteOp>::value,
                "must be a vector transfer op");
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (!subviewOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride subview, need to track strides in folded memref");
  }
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                Operation *op,
                                                memref::SubViewOp subviewOp) {
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferReadOp readOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferWriteOp writeOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
}

template <typename OpTy>
LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(loadOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Case([&](vector::TransferReadOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getPadding(), op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Case([&](nvgpu::LdMatrixOp op) {
        rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getTranspose(), op.getNumTiles());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesExpandShape(
          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case<affine::AffineLoadOp, memref::LoadOp>([&](auto op) {
        rewriter.replaceOpWithNewOp<decltype(op)>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices);
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(loadOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesCollapseShape(
          loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case<affine::AffineLoadOp, memref::LoadOp>([&](auto op) {
        rewriter.replaceOpWithNewOp<decltype(op)>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices);
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(storeOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::TransferWriteOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, subViewOp.getSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
            op, op.getSrc(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesExpandShape(
          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case<affine::AffineStoreOp, memref::StoreOp>([&](auto op) {
        rewriter.replaceOpWithNewOp<decltype(op)>(storeOp, storeOp.getValue(),
                                                  expandShapeOp.getViewSource(),
                                                  sourceIndices);
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(storeOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesCollapseShape(
          storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case<affine::AffineStoreOp, memref::StoreOp>([&](auto op) {
        rewriter.replaceOpWithNewOp<decltype(op)>(
            storeOp, storeOp.getValue(), collapseShapeOp.getViewSource(),
            sourceIndices);
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

LogicalResult NvgpuAsyncCopyOpSubViewOpFolder::matchAndRewrite(
    nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {

  LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");

  auto srcSubViewOp =
      copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
  auto dstSubViewOp =
      copyOp.getDst().template getDefiningOp<memref::SubViewOp>();

  if (!(srcSubViewOp || dstSubViewOp))
    return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
                                               "source or destination");

  // If the source is a subview, we need to resolve the indices.
  SmallVector<Value> srcindices(copyOp.getSrcIndices().begin(),
                                copyOp.getSrcIndices().end());
  SmallVector<Value> foldedSrcIndices(srcindices);

  if (srcSubViewOp) {
    LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
        srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
        srcindices, foldedSrcIndices);
  }

  // If the destination is a subview, we need to resolve the indices.
  SmallVector<Value> dstindices(copyOp.getDstIndices().begin(),
                                copyOp.getDstIndices().end());
  SmallVector<Value> foldedDstIndices(dstindices);

  if (dstSubViewOp) {
    LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
        dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
        dstindices, foldedDstIndices);
  }

  // Replace the copy op with a new copy op that uses the sources of the folded
  // subviews (if any) together with the resolved indices.
  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
      copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
      (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
      foldedDstIndices,
      (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
      foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
      copyOp.getBypassL1Attr());

  return success();
}

void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
               LoadOpOfSubViewOpFolder<memref::LoadOp>,
               LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
               LoadOpOfSubViewOpFolder<vector::LoadOp>,
               LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
               LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
               LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
               StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
               StoreOpOfSubViewOpFolder<memref::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
               StoreOpOfSubViewOpFolder<vector::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
               StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
               LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
               StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
               LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
               StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
               SubViewOfSubViewFolder, NvgpuAsyncCopyOpSubViewOpFolder>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

struct FoldMemRefAliasOpsPass final
    : public memref::impl::FoldMemRefAliasOpsBase<FoldMemRefAliasOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldMemRefAliasOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldMemRefAliasOpPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}

std::unique_ptr<Pass> memref::createFoldMemRefAliasOpsPass() {
  return std::make_unique<FoldMemRefAliasOpsPass>();
}