1 | //===- FoldMemRefAliasOps.cpp - Fold memref alias ops -----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This transformation pass folds loading/storing from/to memref alias ops
// (subview, expand_shape, collapse_shape) into loading/storing from/to the
// original memref.
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
15 | #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h" |
16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
17 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
18 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
19 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
20 | #include "mlir/Dialect/MemRef/Transforms/Passes.h" |
21 | #include "mlir/Dialect/MemRef/Transforms/Transforms.h" |
22 | #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" |
23 | #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" |
24 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
25 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
26 | #include "mlir/IR/AffineMap.h" |
27 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
28 | #include "llvm/ADT/STLExtras.h" |
29 | #include "llvm/ADT/SmallBitVector.h" |
30 | #include "llvm/ADT/TypeSwitch.h" |
31 | #include "llvm/Support/Debug.h" |
32 | |
33 | #define DEBUG_TYPE "fold-memref-alias-ops" |
34 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
35 | |
36 | namespace mlir { |
37 | namespace memref { |
38 | #define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS |
39 | #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" |
40 | } // namespace memref |
41 | } // namespace mlir |
42 | |
43 | using namespace mlir; |
44 | |
45 | //===----------------------------------------------------------------------===// |
46 | // Utility functions |
47 | //===----------------------------------------------------------------------===// |
48 | |
/// Given the 'indices' of a load/store operation where the memref is a result
/// of an expand_shape op, returns the indices w.r.t. the source memref of the
/// expand_shape op. For example,
///
///   %0 = ... : memref<12x42xf32>
///   %1 = memref.expand_shape %0 [[0, 1], [2]]
///          : memref<12x42xf32> into memref<2x6x42xf32>
///   %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32>
///
/// could be folded into
///
///   %2 = load %0[6 * %i1 + %i2, %i3] :
///          memref<12x42xf32>
62 | static LogicalResult resolveSourceIndicesExpandShape( |
63 | Location loc, PatternRewriter &rewriter, |
64 | memref::ExpandShapeOp expandShapeOp, ValueRange indices, |
65 | SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) { |
66 | SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape(); |
67 | |
  // Traverse all reassociation groups to determine the appropriate indices
  // corresponding to each of them once the op has been folded.
70 | for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) { |
    assert(!group.empty() && "association indices groups cannot be empty");
72 | int64_t groupSize = group.size(); |
73 | if (groupSize == 1) { |
74 | sourceIndices.push_back(indices[group[0]]); |
75 | continue; |
76 | } |
77 | SmallVector<OpFoldResult> groupBasis = |
78 | llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; }); |
79 | SmallVector<Value> groupIndices = |
80 | llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; }); |
81 | Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>( |
82 | loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds); |
83 | sourceIndices.push_back(collapsedIndex); |
84 | } |
85 | return success(); |
86 | } |
87 | |
/// Given the 'indices' of a load/store operation where the memref is a result
/// of a collapse_shape op, returns the indices w.r.t. the source memref of
/// the collapse_shape op. For example,
91 | /// |
92 | /// %0 = ... : memref<2x6x42xf32> |
93 | /// %1 = memref.collapse_shape %0 [[0, 1], [2]] |
94 | /// : memref<2x6x42xf32> into memref<12x42xf32> |
95 | /// %2 = load %1[%i1, %i2] : memref<12x42xf32> |
96 | /// |
97 | /// could be folded into |
98 | /// |
99 | /// %2 = load %0[%i1 / 6, %i1 % 6, %i2] : |
100 | /// memref<2x6x42xf32> |
101 | static LogicalResult |
102 | resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, |
103 | memref::CollapseShapeOp collapseShapeOp, |
104 | ValueRange indices, |
105 | SmallVectorImpl<Value> &sourceIndices) { |
  // Note: collapse_shape requires a strided memref, so extracting the strided
  // metadata of its source is always valid here.
107 | auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>( |
108 | loc, collapseShapeOp.getSrc()); |
109 | SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes(); |
110 | for (auto [index, group] : |
111 | llvm::zip(indices, collapseShapeOp.getReassociationIndices())) { |
    assert(!group.empty() && "association indices groups cannot be empty");
113 | int64_t groupSize = group.size(); |
114 | |
115 | if (groupSize == 1) { |
116 | sourceIndices.push_back(index); |
117 | continue; |
118 | } |
119 | |
120 | SmallVector<OpFoldResult> basis = |
121 | llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; }); |
122 | auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>( |
123 | loc, index, basis, /*hasOuterBound=*/true); |
124 | llvm::append_range(sourceIndices, delinearize.getResults()); |
125 | } |
126 | if (collapseShapeOp.getReassociationIndices().empty()) { |
    auto zeroAffineMap = rewriter.getConstantAffineMap(0);
128 | int64_t srcRank = |
129 | cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank(); |
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
132 | for (int64_t i = 0; i < srcRank; i++) { |
      sourceIndices.push_back(
          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
135 | } |
136 | } |
137 | return success(); |
138 | } |
139 | |
140 | /// Helpers to access the memref operand for each op. |
141 | template <typename LoadOrStoreOpTy> |
142 | static Value getMemRefOperand(LoadOrStoreOpTy op) { |
143 | return op.getMemref(); |
144 | } |
145 | |
146 | static Value getMemRefOperand(vector::TransferReadOp op) { |
147 | return op.getBase(); |
148 | } |
149 | |
150 | static Value getMemRefOperand(nvgpu::LdMatrixOp op) { |
151 | return op.getSrcMemref(); |
152 | } |
153 | |
154 | static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); } |
155 | |
156 | static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); } |
157 | |
158 | static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); } |
159 | |
160 | static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); } |
161 | |
162 | static Value getMemRefOperand(vector::TransferWriteOp op) { |
163 | return op.getBase(); |
164 | } |
165 | |
166 | static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) { |
167 | return op.getSrcMemref(); |
168 | } |
169 | |
170 | static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) { |
171 | return op.getDstMemref(); |
172 | } |
173 | |
174 | //===----------------------------------------------------------------------===// |
175 | // Patterns |
176 | //===----------------------------------------------------------------------===// |
177 | |
178 | namespace { |
179 | /// Merges subview operation with load/transferRead operation. |
180 | template <typename OpTy> |
181 | class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> { |
182 | public: |
183 | using OpRewritePattern<OpTy>::OpRewritePattern; |
184 | |
185 | LogicalResult matchAndRewrite(OpTy loadOp, |
186 | PatternRewriter &rewriter) const override; |
187 | }; |
188 | |
189 | /// Merges expand_shape operation with load/transferRead operation. |
190 | template <typename OpTy> |
191 | class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> { |
192 | public: |
193 | using OpRewritePattern<OpTy>::OpRewritePattern; |
194 | |
195 | LogicalResult matchAndRewrite(OpTy loadOp, |
196 | PatternRewriter &rewriter) const override; |
197 | }; |
198 | |
199 | /// Merges collapse_shape operation with load/transferRead operation. |
200 | template <typename OpTy> |
201 | class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> { |
202 | public: |
203 | using OpRewritePattern<OpTy>::OpRewritePattern; |
204 | |
205 | LogicalResult matchAndRewrite(OpTy loadOp, |
206 | PatternRewriter &rewriter) const override; |
207 | }; |
208 | |
209 | /// Merges subview operation with store/transferWriteOp operation. |
210 | template <typename OpTy> |
211 | class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> { |
212 | public: |
213 | using OpRewritePattern<OpTy>::OpRewritePattern; |
214 | |
215 | LogicalResult matchAndRewrite(OpTy storeOp, |
216 | PatternRewriter &rewriter) const override; |
217 | }; |
218 | |
219 | /// Merges expand_shape operation with store/transferWriteOp operation. |
220 | template <typename OpTy> |
221 | class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> { |
222 | public: |
223 | using OpRewritePattern<OpTy>::OpRewritePattern; |
224 | |
225 | LogicalResult matchAndRewrite(OpTy storeOp, |
226 | PatternRewriter &rewriter) const override; |
227 | }; |
228 | |
229 | /// Merges collapse_shape operation with store/transferWriteOp operation. |
230 | template <typename OpTy> |
231 | class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> { |
232 | public: |
233 | using OpRewritePattern<OpTy>::OpRewritePattern; |
234 | |
235 | LogicalResult matchAndRewrite(OpTy storeOp, |
236 | PatternRewriter &rewriter) const override; |
237 | }; |
238 | |
239 | /// Folds subview(subview(x)) to a single subview(x). |
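///
/// For example (types elided, unit strides assumed; a sketch, not verbatim
/// output of the pattern):
///
///   %0 = memref.subview %arg[%a, %b] [8, 8] [1, 1] : ...
///   %1 = memref.subview %0[%c, %d] [4, 4] [1, 1] : ...
///
/// folds to
///
///   %1 = memref.subview %arg[%a + %c, %b + %d] [4, 4] [1, 1] : ...
///
/// where the offset additions are typically materialized via affine.apply
/// (or folded into constants when possible).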
240 | class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> { |
241 | public: |
242 | using OpRewritePattern<memref::SubViewOp>::OpRewritePattern; |
243 | |
244 | LogicalResult matchAndRewrite(memref::SubViewOp subView, |
245 | PatternRewriter &rewriter) const override { |
246 | auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>(); |
247 | if (!srcSubView) |
248 | return failure(); |
249 | |
250 | // TODO: relax unit stride assumption. |
251 | if (!subView.hasUnitStride()) { |
      return rewriter.notifyMatchFailure(subView, "requires unit strides");
253 | } |
254 | if (!srcSubView.hasUnitStride()) { |
      return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
256 | } |
257 | |
258 | // Resolve sizes according to dropped dims. |
259 | SmallVector<OpFoldResult> resolvedSizes; |
260 | llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims(); |
    affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
                                        subView.getMixedSizes(), srcDroppedDims,
                                        resolvedSizes);
264 | |
265 | // Resolve offsets according to source offsets and strides. |
266 | SmallVector<Value> resolvedOffsets; |
267 | affine::resolveIndicesIntoOpWithOffsetsAndStrides( |
268 | rewriter, subView.getLoc(), srcSubView.getMixedOffsets(), |
269 | srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(), |
270 | resolvedOffsets); |
271 | |
272 | // Replace original op. |
273 | rewriter.replaceOpWithNewOp<memref::SubViewOp>( |
274 | subView, subView.getType(), srcSubView.getSource(), |
275 | getAsOpFoldResult(resolvedOffsets), resolvedSizes, |
276 | srcSubView.getMixedStrides()); |
277 | |
278 | return success(); |
279 | } |
280 | }; |
281 | |
/// Folds nvgpu.device_async_copy subviews into the copy itself. This pattern
/// folds subviews on both the src and dst memrefs of the copy.
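///
/// Rough sketch (types and non-essential operands elided; SSA names are
/// illustrative only):
///
///   %sv = memref.subview %src[%i, %j] [1, 4] [1, 1] : ...
///   %t = nvgpu.device_async_copy %sv[%c0, %c0], %dst[%k, %l], 4 : ...
///
/// becomes
///
///   %t = nvgpu.device_async_copy %src[%i, %j], %dst[%k, %l], 4 : ...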
284 | class NVGPUAsyncCopyOpSubViewOpFolder final |
285 | : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> { |
286 | public: |
287 | using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern; |
288 | |
289 | LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp, |
290 | PatternRewriter &rewriter) const override; |
291 | }; |
292 | } // namespace |
293 | |
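/// For an affine.load/affine.store access, applies each result of `affineMap`
/// to `indices` (the map operands) and materializes the results as explicit
/// index values, so the access can be rewritten directly on the aliased
/// memref.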
294 | static SmallVector<Value> |
295 | calculateExpandedAccessIndices(AffineMap affineMap, |
296 | const SmallVector<Value> &indices, Location loc, |
297 | PatternRewriter &rewriter) { |
  SmallVector<OpFoldResult> indicesOfr(llvm::to_vector(
      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; })));
300 | SmallVector<Value> expandedIndices; |
301 | for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) { |
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc, affineMap.getSubMap({i}), indicesOfr);
    expandedIndices.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
306 | } |
307 | return expandedIndices; |
308 | } |
309 | |
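/// Common preconditions for folding a producing subview into a vector
/// transfer op: the transfer must not have any out-of-bounds dimensions and
/// the subview must have unit strides (otherwise the folded indices would
/// have to account for the subview's strides).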
310 | template <typename XferOp> |
311 | static LogicalResult |
312 | preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp, |
313 | memref::SubViewOp subviewOp) { |
  static_assert(llvm::is_one_of<XferOp, vector::TransferReadOp,
                                vector::TransferWriteOp>::value,
                "must be a vector transfer op");
317 | if (xferOp.hasOutOfBoundsDim()) |
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
319 | if (!subviewOp.hasUnitStride()) { |
320 | return rewriter.notifyMatchFailure( |
        xferOp, "non-1 stride subview, need to track strides in folded memref");
322 | } |
323 | return success(); |
324 | } |
325 | |
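/// Generic case: ops other than vector transfers impose no extra
/// preconditions beyond having a subview producer.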
326 | static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter, |
327 | Operation *op, |
328 | memref::SubViewOp subviewOp) { |
329 | return success(); |
330 | } |
331 | |
332 | static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter, |
333 | vector::TransferReadOp readOp, |
334 | memref::SubViewOp subviewOp) { |
335 | return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp); |
336 | } |
337 | |
338 | static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter, |
339 | vector::TransferWriteOp writeOp, |
340 | memref::SubViewOp subviewOp) { |
341 | return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp); |
342 | } |
343 | |
344 | template <typename OpTy> |
345 | LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite( |
346 | OpTy loadOp, PatternRewriter &rewriter) const { |
347 | auto subViewOp = |
348 | getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>(); |
349 | |
350 | if (!subViewOp) |
    return rewriter.notifyMatchFailure(loadOp, "not a subview producer");
352 | |
353 | LogicalResult preconditionResult = |
354 | preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp); |
  if (failed(preconditionResult))
356 | return preconditionResult; |
357 | |
358 | SmallVector<Value> indices(loadOp.getIndices().begin(), |
359 | loadOp.getIndices().end()); |
  // For affine ops, we need to apply the affine map to its operands to get
  // the "actual" indices.
362 | if (auto affineLoadOp = |
363 | dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) { |
364 | AffineMap affineMap = affineLoadOp.getAffineMap(); |
365 | auto expandedIndices = calculateExpandedAccessIndices( |
366 | affineMap, indices, loadOp.getLoc(), rewriter); |
367 | indices.assign(expandedIndices.begin(), expandedIndices.end()); |
368 | } |
369 | SmallVector<Value> sourceIndices; |
370 | affine::resolveIndicesIntoOpWithOffsetsAndStrides( |
371 | rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(), |
372 | subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices, |
373 | sourceIndices); |
374 | |
375 | llvm::TypeSwitch<Operation *, void>(loadOp) |
376 | .Case([&](affine::AffineLoadOp op) { |
377 | rewriter.replaceOpWithNewOp<affine::AffineLoadOp>( |
378 | loadOp, subViewOp.getSource(), sourceIndices); |
379 | }) |
380 | .Case([&](memref::LoadOp op) { |
381 | rewriter.replaceOpWithNewOp<memref::LoadOp>( |
382 | loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal()); |
383 | }) |
384 | .Case([&](vector::LoadOp op) { |
385 | rewriter.replaceOpWithNewOp<vector::LoadOp>( |
386 | op, op.getType(), subViewOp.getSource(), sourceIndices); |
387 | }) |
388 | .Case([&](vector::MaskedLoadOp op) { |
389 | rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>( |
390 | op, op.getType(), subViewOp.getSource(), sourceIndices, |
391 | op.getMask(), op.getPassThru()); |
392 | }) |
393 | .Case([&](vector::TransferReadOp op) { |
394 | rewriter.replaceOpWithNewOp<vector::TransferReadOp>( |
395 | op, op.getVectorType(), subViewOp.getSource(), sourceIndices, |
396 | AffineMapAttr::get(expandDimsToRank( |
397 | op.getPermutationMap(), subViewOp.getSourceType().getRank(), |
398 | subViewOp.getDroppedDims())), |
399 | op.getPadding(), op.getMask(), op.getInBoundsAttr()); |
400 | }) |
401 | .Case([&](gpu::SubgroupMmaLoadMatrixOp op) { |
402 | rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>( |
403 | op, op.getType(), subViewOp.getSource(), sourceIndices, |
404 | op.getLeadDimension(), op.getTransposeAttr()); |
405 | }) |
406 | .Case([&](nvgpu::LdMatrixOp op) { |
407 | rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>( |
408 | op, op.getType(), subViewOp.getSource(), sourceIndices, |
409 | op.getTranspose(), op.getNumTiles()); |
410 | }) |
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
412 | return success(); |
413 | } |
414 | |
415 | template <typename OpTy> |
416 | LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite( |
417 | OpTy loadOp, PatternRewriter &rewriter) const { |
418 | auto expandShapeOp = |
419 | getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>(); |
420 | |
421 | if (!expandShapeOp) |
422 | return failure(); |
423 | |
424 | SmallVector<Value> indices(loadOp.getIndices().begin(), |
425 | loadOp.getIndices().end()); |
  // For affine ops, we need to apply the affine map to its operands to get
  // the "actual" indices.
428 | if (auto affineLoadOp = |
429 | dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) { |
430 | AffineMap affineMap = affineLoadOp.getAffineMap(); |
431 | auto expandedIndices = calculateExpandedAccessIndices( |
432 | affineMap, indices, loadOp.getLoc(), rewriter); |
433 | indices.assign(expandedIndices.begin(), expandedIndices.end()); |
434 | } |
435 | SmallVector<Value> sourceIndices; |
  // memref.load and affine.load guarantee that their indices are in bounds,
  // while the vector operations do not. This affects whether the
  // linearization below can be marked `disjoint`.
439 | if (failed(resolveSourceIndicesExpandShape( |
440 | loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices, |
441 | isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation())))) |
442 | return failure(); |
443 | llvm::TypeSwitch<Operation *, void>(loadOp) |
444 | .Case([&](affine::AffineLoadOp op) { |
445 | rewriter.replaceOpWithNewOp<affine::AffineLoadOp>( |
446 | loadOp, expandShapeOp.getViewSource(), sourceIndices); |
447 | }) |
448 | .Case([&](memref::LoadOp op) { |
449 | rewriter.replaceOpWithNewOp<memref::LoadOp>( |
450 | loadOp, expandShapeOp.getViewSource(), sourceIndices, |
451 | op.getNontemporal()); |
452 | }) |
453 | .Case([&](vector::LoadOp op) { |
454 | rewriter.replaceOpWithNewOp<vector::LoadOp>( |
455 | op, op.getType(), expandShapeOp.getViewSource(), sourceIndices, |
456 | op.getNontemporal()); |
457 | }) |
458 | .Case([&](vector::MaskedLoadOp op) { |
459 | rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>( |
460 | op, op.getType(), expandShapeOp.getViewSource(), sourceIndices, |
461 | op.getMask(), op.getPassThru()); |
462 | }) |
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
464 | return success(); |
465 | } |
466 | |
467 | template <typename OpTy> |
468 | LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite( |
469 | OpTy loadOp, PatternRewriter &rewriter) const { |
470 | auto collapseShapeOp = getMemRefOperand(loadOp) |
471 | .template getDefiningOp<memref::CollapseShapeOp>(); |
472 | |
473 | if (!collapseShapeOp) |
474 | return failure(); |
475 | |
476 | SmallVector<Value> indices(loadOp.getIndices().begin(), |
477 | loadOp.getIndices().end()); |
  // For affine ops, we need to apply the affine map to its operands to get
  // the "actual" indices.
480 | if (auto affineLoadOp = |
481 | dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) { |
482 | AffineMap affineMap = affineLoadOp.getAffineMap(); |
483 | auto expandedIndices = calculateExpandedAccessIndices( |
484 | affineMap, indices, loadOp.getLoc(), rewriter); |
485 | indices.assign(expandedIndices.begin(), expandedIndices.end()); |
486 | } |
487 | SmallVector<Value> sourceIndices; |
488 | if (failed(resolveSourceIndicesCollapseShape( |
489 | loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices))) |
490 | return failure(); |
491 | llvm::TypeSwitch<Operation *, void>(loadOp) |
492 | .Case([&](affine::AffineLoadOp op) { |
493 | rewriter.replaceOpWithNewOp<affine::AffineLoadOp>( |
494 | loadOp, collapseShapeOp.getViewSource(), sourceIndices); |
495 | }) |
496 | .Case([&](memref::LoadOp op) { |
497 | rewriter.replaceOpWithNewOp<memref::LoadOp>( |
498 | loadOp, collapseShapeOp.getViewSource(), sourceIndices, |
499 | op.getNontemporal()); |
500 | }) |
501 | .Case([&](vector::LoadOp op) { |
502 | rewriter.replaceOpWithNewOp<vector::LoadOp>( |
503 | op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices, |
504 | op.getNontemporal()); |
505 | }) |
506 | .Case([&](vector::MaskedLoadOp op) { |
507 | rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>( |
508 | op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices, |
509 | op.getMask(), op.getPassThru()); |
510 | }) |
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
512 | return success(); |
513 | } |
514 | |
515 | template <typename OpTy> |
516 | LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite( |
517 | OpTy storeOp, PatternRewriter &rewriter) const { |
518 | auto subViewOp = |
519 | getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>(); |
520 | |
521 | if (!subViewOp) |
    return rewriter.notifyMatchFailure(storeOp, "not a subview producer");
523 | |
524 | LogicalResult preconditionResult = |
525 | preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp); |
  if (failed(preconditionResult))
527 | return preconditionResult; |
528 | |
529 | SmallVector<Value> indices(storeOp.getIndices().begin(), |
530 | storeOp.getIndices().end()); |
  // For affine ops, we need to apply the affine map to its operands to get
  // the "actual" indices.
533 | if (auto affineStoreOp = |
534 | dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) { |
535 | AffineMap affineMap = affineStoreOp.getAffineMap(); |
536 | auto expandedIndices = calculateExpandedAccessIndices( |
537 | affineMap, indices, storeOp.getLoc(), rewriter); |
538 | indices.assign(expandedIndices.begin(), expandedIndices.end()); |
539 | } |
540 | SmallVector<Value> sourceIndices; |
541 | affine::resolveIndicesIntoOpWithOffsetsAndStrides( |
542 | rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(), |
543 | subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices, |
544 | sourceIndices); |
545 | |
546 | llvm::TypeSwitch<Operation *, void>(storeOp) |
547 | .Case([&](affine::AffineStoreOp op) { |
548 | rewriter.replaceOpWithNewOp<affine::AffineStoreOp>( |
549 | op, op.getValue(), subViewOp.getSource(), sourceIndices); |
550 | }) |
551 | .Case([&](memref::StoreOp op) { |
552 | rewriter.replaceOpWithNewOp<memref::StoreOp>( |
553 | op, op.getValue(), subViewOp.getSource(), sourceIndices, |
554 | op.getNontemporal()); |
555 | }) |
556 | .Case([&](vector::TransferWriteOp op) { |
557 | rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( |
558 | op, op.getValue(), subViewOp.getSource(), sourceIndices, |
559 | AffineMapAttr::get(expandDimsToRank( |
560 | op.getPermutationMap(), subViewOp.getSourceType().getRank(), |
561 | subViewOp.getDroppedDims())), |
562 | op.getMask(), op.getInBoundsAttr()); |
563 | }) |
564 | .Case([&](vector::StoreOp op) { |
565 | rewriter.replaceOpWithNewOp<vector::StoreOp>( |
566 | op, op.getValueToStore(), subViewOp.getSource(), sourceIndices); |
567 | }) |
568 | .Case([&](vector::MaskedStoreOp op) { |
569 | rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>( |
570 | op, subViewOp.getSource(), sourceIndices, op.getMask(), |
571 | op.getValueToStore()); |
572 | }) |
573 | .Case([&](gpu::SubgroupMmaStoreMatrixOp op) { |
574 | rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>( |
575 | op, op.getSrc(), subViewOp.getSource(), sourceIndices, |
576 | op.getLeadDimension(), op.getTransposeAttr()); |
577 | }) |
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
579 | return success(); |
580 | } |
581 | |
582 | template <typename OpTy> |
583 | LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite( |
584 | OpTy storeOp, PatternRewriter &rewriter) const { |
585 | auto expandShapeOp = |
586 | getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>(); |
587 | |
588 | if (!expandShapeOp) |
589 | return failure(); |
590 | |
591 | SmallVector<Value> indices(storeOp.getIndices().begin(), |
592 | storeOp.getIndices().end()); |
  // For affine ops, we need to apply the affine map to its operands to get
  // the "actual" indices.
595 | if (auto affineStoreOp = |
596 | dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) { |
597 | AffineMap affineMap = affineStoreOp.getAffineMap(); |
598 | auto expandedIndices = calculateExpandedAccessIndices( |
599 | affineMap, indices, storeOp.getLoc(), rewriter); |
600 | indices.assign(expandedIndices.begin(), expandedIndices.end()); |
601 | } |
602 | SmallVector<Value> sourceIndices; |
  // memref.store and affine.store guarantee that their indices are in bounds,
  // while the vector operations do not. This affects whether the
  // linearization below can be marked `disjoint`.
606 | if (failed(resolveSourceIndicesExpandShape( |
607 | storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices, |
608 | isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation())))) |
609 | return failure(); |
610 | llvm::TypeSwitch<Operation *, void>(storeOp) |
611 | .Case([&](affine::AffineStoreOp op) { |
612 | rewriter.replaceOpWithNewOp<affine::AffineStoreOp>( |
613 | storeOp, op.getValueToStore(), expandShapeOp.getViewSource(), |
614 | sourceIndices); |
615 | }) |
616 | .Case([&](memref::StoreOp op) { |
617 | rewriter.replaceOpWithNewOp<memref::StoreOp>( |
618 | storeOp, op.getValueToStore(), expandShapeOp.getViewSource(), |
619 | sourceIndices, op.getNontemporal()); |
620 | }) |
621 | .Case([&](vector::StoreOp op) { |
622 | rewriter.replaceOpWithNewOp<vector::StoreOp>( |
623 | op, op.getValueToStore(), expandShapeOp.getViewSource(), |
624 | sourceIndices, op.getNontemporal()); |
625 | }) |
626 | .Case([&](vector::MaskedStoreOp op) { |
627 | rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>( |
628 | op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(), |
629 | op.getValueToStore()); |
630 | }) |
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
632 | return success(); |
633 | } |
634 | |
635 | template <typename OpTy> |
636 | LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite( |
637 | OpTy storeOp, PatternRewriter &rewriter) const { |
638 | auto collapseShapeOp = getMemRefOperand(storeOp) |
639 | .template getDefiningOp<memref::CollapseShapeOp>(); |
640 | |
641 | if (!collapseShapeOp) |
642 | return failure(); |
643 | |
644 | SmallVector<Value> indices(storeOp.getIndices().begin(), |
645 | storeOp.getIndices().end()); |
  // For affine ops, we need to apply the affine map to its operands to get
  // the "actual" indices.
648 | if (auto affineStoreOp = |
649 | dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) { |
650 | AffineMap affineMap = affineStoreOp.getAffineMap(); |
651 | auto expandedIndices = calculateExpandedAccessIndices( |
652 | affineMap, indices, storeOp.getLoc(), rewriter); |
653 | indices.assign(expandedIndices.begin(), expandedIndices.end()); |
654 | } |
655 | SmallVector<Value> sourceIndices; |
656 | if (failed(resolveSourceIndicesCollapseShape( |
657 | storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices))) |
658 | return failure(); |
659 | llvm::TypeSwitch<Operation *, void>(storeOp) |
660 | .Case([&](affine::AffineStoreOp op) { |
661 | rewriter.replaceOpWithNewOp<affine::AffineStoreOp>( |
662 | storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(), |
663 | sourceIndices); |
664 | }) |
665 | .Case([&](memref::StoreOp op) { |
666 | rewriter.replaceOpWithNewOp<memref::StoreOp>( |
667 | storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(), |
668 | sourceIndices, op.getNontemporal()); |
669 | }) |
670 | .Case([&](vector::StoreOp op) { |
671 | rewriter.replaceOpWithNewOp<vector::StoreOp>( |
672 | op, op.getValueToStore(), collapseShapeOp.getViewSource(), |
673 | sourceIndices, op.getNontemporal()); |
674 | }) |
675 | .Case([&](vector::MaskedStoreOp op) { |
676 | rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>( |
677 | op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(), |
678 | op.getValueToStore()); |
679 | }) |
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
681 | return success(); |
682 | } |
683 | |
684 | LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite( |
685 | nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const { |
686 | |
687 | LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n" ); |
688 | |
689 | auto srcSubViewOp = |
690 | copyOp.getSrc().template getDefiningOp<memref::SubViewOp>(); |
691 | auto dstSubViewOp = |
692 | copyOp.getDst().template getDefiningOp<memref::SubViewOp>(); |
693 | |
694 | if (!(srcSubViewOp || dstSubViewOp)) |
    return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
                                               "source or destination");
697 | |
698 | // If the source is a subview, we need to resolve the indices. |
699 | SmallVector<Value> srcindices(copyOp.getSrcIndices().begin(), |
700 | copyOp.getSrcIndices().end()); |
701 | SmallVector<Value> foldedSrcIndices(srcindices); |
702 | |
703 | if (srcSubViewOp) { |
704 | LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n" ); |
705 | affine::resolveIndicesIntoOpWithOffsetsAndStrides( |
706 | rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(), |
707 | srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(), |
708 | srcindices, foldedSrcIndices); |
709 | } |
710 | |
711 | // If the destination is a subview, we need to resolve the indices. |
712 | SmallVector<Value> dstindices(copyOp.getDstIndices().begin(), |
713 | copyOp.getDstIndices().end()); |
714 | SmallVector<Value> foldedDstIndices(dstindices); |
715 | |
716 | if (dstSubViewOp) { |
717 | LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n" ); |
718 | affine::resolveIndicesIntoOpWithOffsetsAndStrides( |
719 | rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(), |
720 | dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(), |
721 | dstindices, foldedDstIndices); |
722 | } |
723 | |
  // Replace the copy op with a new one that operates directly on the original
  // source/destination memrefs of the folded subviews (when present).
726 | rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>( |
727 | copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()), |
728 | (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()), |
729 | foldedDstIndices, |
730 | (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()), |
731 | foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(), |
732 | copyOp.getBypassL1Attr()); |
733 | |
734 | return success(); |
735 | } |
736 | |
737 | void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) { |
738 | patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>, |
739 | LoadOpOfSubViewOpFolder<memref::LoadOp>, |
740 | LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>, |
741 | LoadOpOfSubViewOpFolder<vector::LoadOp>, |
742 | LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>, |
743 | LoadOpOfSubViewOpFolder<vector::TransferReadOp>, |
744 | LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>, |
745 | StoreOpOfSubViewOpFolder<affine::AffineStoreOp>, |
746 | StoreOpOfSubViewOpFolder<memref::StoreOp>, |
747 | StoreOpOfSubViewOpFolder<vector::TransferWriteOp>, |
748 | StoreOpOfSubViewOpFolder<vector::StoreOp>, |
749 | StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>, |
750 | StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>, |
751 | LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>, |
752 | LoadOpOfExpandShapeOpFolder<memref::LoadOp>, |
753 | LoadOpOfExpandShapeOpFolder<vector::LoadOp>, |
754 | LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>, |
755 | StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>, |
756 | StoreOpOfExpandShapeOpFolder<memref::StoreOp>, |
757 | StoreOpOfExpandShapeOpFolder<vector::StoreOp>, |
758 | StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>, |
759 | LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>, |
760 | LoadOpOfCollapseShapeOpFolder<memref::LoadOp>, |
761 | LoadOpOfCollapseShapeOpFolder<vector::LoadOp>, |
762 | LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>, |
763 | StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>, |
764 | StoreOpOfCollapseShapeOpFolder<memref::StoreOp>, |
765 | StoreOpOfCollapseShapeOpFolder<vector::StoreOp>, |
766 | StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>, |
767 | SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>( |
768 | patterns.getContext()); |
769 | } |
770 | |
771 | //===----------------------------------------------------------------------===// |
772 | // Pass registration |
773 | //===----------------------------------------------------------------------===// |
774 | |
775 | namespace { |
776 | |
777 | struct FoldMemRefAliasOpsPass final |
778 | : public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> { |
779 | void runOnOperation() override; |
780 | }; |
781 | |
782 | } // namespace |
783 | |
784 | void FoldMemRefAliasOpsPass::runOnOperation() { |
785 | RewritePatternSet patterns(&getContext()); |
786 | memref::populateFoldMemRefAliasOpPatterns(patterns); |
787 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
788 | } |
789 | |