//===- ExtractAddressComputations.cpp - Extract address computations -----===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | /// This transformation pass rewrites loading/storing from/to a memref with |
10 | /// offsets into loading/storing from/to a subview and without any offset on |
11 | /// the instruction itself. |
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
17 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
18 | #include "mlir/Dialect/MemRef/Transforms/Transforms.h" |
19 | #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" |
20 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
21 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
22 | #include "mlir/IR/PatternMatch.h" |
23 | |
24 | using namespace mlir; |
25 | |
26 | namespace { |
27 | |
28 | //===----------------------------------------------------------------------===// |
29 | // Helper functions for the `load base[off0...]` |
30 | // => `load (subview base[off0...])[0...]` pattern. |
31 | //===----------------------------------------------------------------------===// |
32 | |
33 | // Matches getFailureOrSrcMemRef specs for LoadOp. |
34 | // \see LoadStoreLikeOpRewriter. |
35 | static FailureOr<Value> getLoadOpSrcMemRef(memref::LoadOp loadOp) { |
36 | return loadOp.getMemRef(); |
37 | } |
38 | |
39 | // Matches rebuildOpFromAddressAndIndices specs for LoadOp. |
40 | // \see LoadStoreLikeOpRewriter. |
41 | static memref::LoadOp rebuildLoadOp(RewriterBase &rewriter, |
42 | memref::LoadOp loadOp, Value srcMemRef, |
43 | ArrayRef<Value> indices) { |
44 | Location loc = loadOp.getLoc(); |
45 | return rewriter.create<memref::LoadOp>(loc, srcMemRef, indices, |
46 | loadOp.getNontemporal()); |
47 | } |
48 | |
49 | // Matches getViewSizeForEachDim specs for LoadOp. |
50 | // \see LoadStoreLikeOpRewriter. |
51 | static SmallVector<OpFoldResult> |
52 | getLoadOpViewSizeForEachDim(RewriterBase &rewriter, memref::LoadOp loadOp) { |
53 | MemRefType ldTy = loadOp.getMemRefType(); |
54 | unsigned loadRank = ldTy.getRank(); |
55 | return SmallVector<OpFoldResult>(loadRank, rewriter.getIndexAttr(1)); |
56 | } |
57 | |
58 | //===----------------------------------------------------------------------===// |
59 | // Helper functions for the `store val, base[off0...]` |
60 | // => `store val, (subview base[off0...])[0...]` pattern. |
61 | //===----------------------------------------------------------------------===// |
62 | |
63 | // Matches getFailureOrSrcMemRef specs for StoreOp. |
64 | // \see LoadStoreLikeOpRewriter. |
65 | static FailureOr<Value> getStoreOpSrcMemRef(memref::StoreOp storeOp) { |
66 | return storeOp.getMemRef(); |
67 | } |
68 | |
69 | // Matches rebuildOpFromAddressAndIndices specs for StoreOp. |
70 | // \see LoadStoreLikeOpRewriter. |
71 | static memref::StoreOp rebuildStoreOp(RewriterBase &rewriter, |
72 | memref::StoreOp storeOp, Value srcMemRef, |
73 | ArrayRef<Value> indices) { |
74 | Location loc = storeOp.getLoc(); |
75 | return rewriter.create<memref::StoreOp>(loc, storeOp.getValueToStore(), |
76 | srcMemRef, indices, |
77 | storeOp.getNontemporal()); |
78 | } |
79 | |
80 | // Matches getViewSizeForEachDim specs for StoreOp. |
81 | // \see LoadStoreLikeOpRewriter. |
82 | static SmallVector<OpFoldResult> |
83 | getStoreOpViewSizeForEachDim(RewriterBase &rewriter, memref::StoreOp storeOp) { |
84 | MemRefType ldTy = storeOp.getMemRefType(); |
85 | unsigned loadRank = ldTy.getRank(); |
86 | return SmallVector<OpFoldResult>(loadRank, rewriter.getIndexAttr(1)); |
87 | } |
88 | |
89 | //===----------------------------------------------------------------------===// |
90 | // Helper functions for the `ldmatrix base[off0...]` |
91 | // => `ldmatrix (subview base[off0...])[0...]` pattern. |
92 | //===----------------------------------------------------------------------===// |
93 | |
94 | // Matches getFailureOrSrcMemRef specs for LdMatrixOp. |
95 | // \see LoadStoreLikeOpRewriter. |
96 | static FailureOr<Value> getLdMatrixOpSrcMemRef(nvgpu::LdMatrixOp ldMatrixOp) { |
97 | return ldMatrixOp.getSrcMemref(); |
98 | } |
99 | |
100 | // Matches rebuildOpFromAddressAndIndices specs for LdMatrixOp. |
101 | // \see LoadStoreLikeOpRewriter. |
102 | static nvgpu::LdMatrixOp rebuildLdMatrixOp(RewriterBase &rewriter, |
103 | nvgpu::LdMatrixOp ldMatrixOp, |
104 | Value srcMemRef, |
105 | ArrayRef<Value> indices) { |
106 | Location loc = ldMatrixOp.getLoc(); |
107 | return rewriter.create<nvgpu::LdMatrixOp>( |
108 | loc, ldMatrixOp.getResult().getType(), srcMemRef, indices, |
109 | ldMatrixOp.getTranspose(), ldMatrixOp.getNumTiles()); |
110 | } |
111 | |
112 | //===----------------------------------------------------------------------===// |
113 | // Helper functions for the `transfer_read base[off0...]` |
114 | // => `transfer_read (subview base[off0...])[0...]` pattern. |
115 | //===----------------------------------------------------------------------===// |
116 | |
117 | // Matches getFailureOrSrcMemRef specs for TransferReadOp. |
118 | // \see LoadStoreLikeOpRewriter. |
119 | template <typename TransferLikeOp> |
120 | static FailureOr<Value> |
121 | getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) { |
122 | Value src = transferLikeOp.getSource(); |
123 | if (isa<MemRefType>(Val: src.getType())) |
124 | return src; |
125 | return failure(); |
126 | } |
127 | |
128 | // Matches rebuildOpFromAddressAndIndices specs for TransferReadOp. |
129 | // \see LoadStoreLikeOpRewriter. |
130 | static vector::TransferReadOp |
131 | rebuildTransferReadOp(RewriterBase &rewriter, |
132 | vector::TransferReadOp transferReadOp, Value srcMemRef, |
133 | ArrayRef<Value> indices) { |
134 | Location loc = transferReadOp.getLoc(); |
135 | return rewriter.create<vector::TransferReadOp>( |
136 | loc, transferReadOp.getResult().getType(), srcMemRef, indices, |
137 | transferReadOp.getPermutationMap(), transferReadOp.getPadding(), |
138 | transferReadOp.getMask(), transferReadOp.getInBoundsAttr()); |
139 | } |
140 | |
141 | //===----------------------------------------------------------------------===// |
142 | // Helper functions for the `transfer_write base[off0...]` |
143 | // => `transfer_write (subview base[off0...])[0...]` pattern. |
144 | //===----------------------------------------------------------------------===// |
145 | |
146 | // Matches rebuildOpFromAddressAndIndices specs for TransferWriteOp. |
147 | // \see LoadStoreLikeOpRewriter. |
148 | static vector::TransferWriteOp |
149 | rebuildTransferWriteOp(RewriterBase &rewriter, |
150 | vector::TransferWriteOp transferWriteOp, Value srcMemRef, |
151 | ArrayRef<Value> indices) { |
152 | Location loc = transferWriteOp.getLoc(); |
153 | return rewriter.create<vector::TransferWriteOp>( |
154 | loc, transferWriteOp.getValue(), srcMemRef, indices, |
155 | transferWriteOp.getPermutationMapAttr(), transferWriteOp.getMask(), |
156 | transferWriteOp.getInBoundsAttr()); |
157 | } |
158 | |
159 | //===----------------------------------------------------------------------===// |
160 | // Generic helper functions used as default implementation in |
161 | // LoadStoreLikeOpRewriter. |
162 | //===----------------------------------------------------------------------===// |
163 | |
164 | /// Helper function to get the src memref. |
165 | /// It uses the already defined getFailureOrSrcMemRef but asserts |
166 | /// that the source is a memref. |
167 | template <typename LoadStoreLikeOp, |
168 | FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp)> |
169 | static Value getSrcMemRef(LoadStoreLikeOp loadStoreLikeOp) { |
170 | FailureOr<Value> failureOrSrcMemRef = getFailureOrSrcMemRef(loadStoreLikeOp); |
171 | assert(!failed(failureOrSrcMemRef) && "Generic getSrcMemRef cannot be used" ); |
172 | return *failureOrSrcMemRef; |
173 | } |
174 | |
175 | /// Helper function to get the sizes of the resulting view. |
176 | /// This function gets the sizes of the source memref then substracts the |
177 | /// offsets used within \p loadStoreLikeOp. This gives the maximal (for |
178 | /// inbound) sizes for the view. |
179 | /// The source memref is retrieved using getSrcMemRef on \p loadStoreLikeOp. |
180 | template <typename LoadStoreLikeOp, Value (*getSrcMemRef)(LoadStoreLikeOp)> |
181 | static SmallVector<OpFoldResult> |
182 | getGenericOpViewSizeForEachDim(RewriterBase &rewriter, |
183 | LoadStoreLikeOp loadStoreLikeOp) { |
184 | Location loc = loadStoreLikeOp.getLoc(); |
185 | auto = |
186 | rewriter.create<memref::ExtractStridedMetadataOp>( |
187 | loc, getSrcMemRef(loadStoreLikeOp)); |
188 | SmallVector<OpFoldResult> srcSizes = |
189 | extractStridedMetadataOp.getConstifiedMixedSizes(); |
190 | SmallVector<OpFoldResult> indices = |
191 | getAsOpFoldResult(loadStoreLikeOp.getIndices()); |
192 | SmallVector<OpFoldResult> finalSizes; |
193 | |
194 | AffineExpr s0 = rewriter.getAffineSymbolExpr(position: 0); |
195 | AffineExpr s1 = rewriter.getAffineSymbolExpr(position: 1); |
196 | |
197 | for (auto [srcSize, indice] : llvm::zip(srcSizes, indices)) { |
198 | finalSizes.push_back(affine::makeComposedFoldedAffineApply( |
199 | rewriter, loc, s0 - s1, {srcSize, indice})); |
200 | } |
201 | return finalSizes; |
202 | } |
203 | |
204 | /// Rewrite a store/load-like op so that all its indices are zeros. |
205 | /// E.g., %ld = memref.load %base[%off0]...[%offN] |
206 | /// => |
207 | /// %new_base = subview %base[%off0,.., %offN][1,..,1][1,..,1] |
208 | /// %ld = memref.load %new_base[0,..,0] : |
209 | /// memref<1x..x1xTy, strided<[1,..,1], offset: ?>> |
210 | /// |
211 | /// `getSrcMemRef` returns the source memref for the given load-like operation. |
212 | /// |
213 | /// `getViewSizeForEachDim` returns the sizes of view that is going to feed |
214 | /// new operation. This must return one size per dimension of the view. |
215 | /// The sizes of the view needs to be at least as big as what is actually |
216 | /// going to be accessed. Use the provided `loadStoreOp` to get the right |
217 | /// sizes. |
218 | /// |
219 | /// Using the given rewriter, `rebuildOpFromAddressAndIndices` creates a new |
220 | /// LoadStoreLikeOp that reads from srcMemRef[indices]. |
221 | /// The returned operation will be used to replace loadStoreOp. |
222 | template <typename LoadStoreLikeOp, |
223 | FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp), |
224 | LoadStoreLikeOp (*rebuildOpFromAddressAndIndices)( |
225 | RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/, |
226 | Value /*srcMemRef*/, ArrayRef<Value> /*indices*/), |
227 | SmallVector<OpFoldResult> (*getViewSizeForEachDim)( |
228 | RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/) = |
229 | getGenericOpViewSizeForEachDim< |
230 | LoadStoreLikeOp, |
231 | getSrcMemRef<LoadStoreLikeOp, getFailureOrSrcMemRef>>> |
232 | struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> { |
233 | using OpRewritePattern<LoadStoreLikeOp>::OpRewritePattern; |
234 | |
235 | LogicalResult matchAndRewrite(LoadStoreLikeOp loadStoreLikeOp, |
236 | PatternRewriter &rewriter) const override { |
237 | FailureOr<Value> failureOrSrcMemRef = |
238 | getFailureOrSrcMemRef(loadStoreLikeOp); |
239 | if (failed(result: failureOrSrcMemRef)) |
240 | return rewriter.notifyMatchFailure(loadStoreLikeOp, |
241 | "source is not a memref" ); |
242 | Value srcMemRef = *failureOrSrcMemRef; |
243 | auto ldStTy = cast<MemRefType>(srcMemRef.getType()); |
244 | unsigned loadStoreRank = ldStTy.getRank(); |
245 | // Don't waste compile time if there is nothing to rewrite. |
246 | if (loadStoreRank == 0) |
247 | return rewriter.notifyMatchFailure(loadStoreLikeOp, |
248 | "0-D accesses don't need rewriting" ); |
249 | |
250 | // If our load already has only zeros as indices there is nothing |
251 | // to do. |
252 | SmallVector<OpFoldResult> indices = |
253 | getAsOpFoldResult(loadStoreLikeOp.getIndices()); |
254 | if (std::all_of(indices.begin(), indices.end(), |
255 | [](const OpFoldResult &opFold) { |
256 | return isConstantIntValue(ofr: opFold, value: 0); |
257 | })) { |
258 | return rewriter.notifyMatchFailure( |
259 | loadStoreLikeOp, "no computation to extract: offsets are 0s" ); |
260 | } |
261 | |
262 | // Create the array of ones of the right size. |
263 | SmallVector<OpFoldResult> ones(loadStoreRank, rewriter.getIndexAttr(1)); |
264 | SmallVector<OpFoldResult> sizes = |
265 | getViewSizeForEachDim(rewriter, loadStoreLikeOp); |
266 | assert(sizes.size() == loadStoreRank && |
267 | "Expected one size per load dimension" ); |
268 | Location loc = loadStoreLikeOp.getLoc(); |
269 | // The subview inherits its strides from the original memref and will |
270 | // apply them properly to the input indices. |
271 | // Therefore the strides multipliers are simply ones. |
272 | auto subview = |
273 | rewriter.create<memref::SubViewOp>(loc, /*source=*/srcMemRef, |
274 | /*offsets=*/indices, |
275 | /*sizes=*/sizes, /*strides=*/ones); |
276 | // Rewrite the load/store with the subview as the base pointer. |
277 | SmallVector<Value> zeros(loadStoreRank, |
278 | rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0)); |
279 | LoadStoreLikeOp newLoadStore = rebuildOpFromAddressAndIndices( |
280 | rewriter, loadStoreLikeOp, subview.getResult(), zeros); |
281 | rewriter.replaceOp(loadStoreLikeOp, newLoadStore->getResults()); |
282 | return success(); |
283 | } |
284 | }; |
285 | } // namespace |
286 | |
287 | void memref::( |
288 | RewritePatternSet &patterns) { |
289 | patterns.add< |
290 | LoadStoreLikeOpRewriter< |
291 | memref::LoadOp, |
292 | /*getSrcMemRef=*/getLoadOpSrcMemRef, |
293 | /*rebuildOpFromAddressAndIndices=*/rebuildLoadOp, |
294 | /*getViewSizeForEachDim=*/getLoadOpViewSizeForEachDim>, |
295 | LoadStoreLikeOpRewriter< |
296 | memref::StoreOp, |
297 | /*getSrcMemRef=*/getStoreOpSrcMemRef, |
298 | /*rebuildOpFromAddressAndIndices=*/rebuildStoreOp, |
299 | /*getViewSizeForEachDim=*/getStoreOpViewSizeForEachDim>, |
300 | LoadStoreLikeOpRewriter< |
301 | nvgpu::LdMatrixOp, |
302 | /*getSrcMemRef=*/getLdMatrixOpSrcMemRef, |
303 | /*rebuildOpFromAddressAndIndices=*/rebuildLdMatrixOp>, |
304 | LoadStoreLikeOpRewriter< |
305 | vector::TransferReadOp, |
306 | /*getSrcMemRef=*/getTransferLikeOpSrcMemRef<vector::TransferReadOp>, |
307 | /*rebuildOpFromAddressAndIndices=*/rebuildTransferReadOp>, |
308 | LoadStoreLikeOpRewriter< |
309 | vector::TransferWriteOp, |
310 | /*getSrcMemRef=*/getTransferLikeOpSrcMemRef<vector::TransferWriteOp>, |
311 | /*rebuildOpFromAddressAndIndices=*/rebuildTransferWriteOp>>( |
312 | patterns.getContext()); |
313 | } |
314 | |