//===- ExtractAddressComputations.cpp - Extract address computations -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// This transformation pass rewrites loading/storing from/to a memref with
10/// offsets into loading/storing from/to a subview and without any offset on
11/// the instruction itself.
12//
13//===----------------------------------------------------------------------===//
14
15#include "mlir/Dialect/Affine/IR/AffineOps.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
19#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
20#include "mlir/Dialect/Utils/StaticValueUtils.h"
21#include "mlir/Dialect/Vector/IR/VectorOps.h"
22#include "mlir/IR/PatternMatch.h"
23
24using namespace mlir;
25
26namespace {
27
28//===----------------------------------------------------------------------===//
29// Helper functions for the `load base[off0...]`
30// => `load (subview base[off0...])[0...]` pattern.
31//===----------------------------------------------------------------------===//
32
33// Matches getFailureOrSrcMemRef specs for LoadOp.
34// \see LoadStoreLikeOpRewriter.
35static FailureOr<Value> getLoadOpSrcMemRef(memref::LoadOp loadOp) {
36 return loadOp.getMemRef();
37}
38
39// Matches rebuildOpFromAddressAndIndices specs for LoadOp.
40// \see LoadStoreLikeOpRewriter.
41static memref::LoadOp rebuildLoadOp(RewriterBase &rewriter,
42 memref::LoadOp loadOp, Value srcMemRef,
43 ArrayRef<Value> indices) {
44 Location loc = loadOp.getLoc();
45 return rewriter.create<memref::LoadOp>(loc, srcMemRef, indices,
46 loadOp.getNontemporal());
47}
48
49// Matches getViewSizeForEachDim specs for LoadOp.
50// \see LoadStoreLikeOpRewriter.
51static SmallVector<OpFoldResult>
52getLoadOpViewSizeForEachDim(RewriterBase &rewriter, memref::LoadOp loadOp) {
53 MemRefType ldTy = loadOp.getMemRefType();
54 unsigned loadRank = ldTy.getRank();
55 return SmallVector<OpFoldResult>(loadRank, rewriter.getIndexAttr(1));
56}
57
58//===----------------------------------------------------------------------===//
59// Helper functions for the `store val, base[off0...]`
60// => `store val, (subview base[off0...])[0...]` pattern.
61//===----------------------------------------------------------------------===//
62
63// Matches getFailureOrSrcMemRef specs for StoreOp.
64// \see LoadStoreLikeOpRewriter.
65static FailureOr<Value> getStoreOpSrcMemRef(memref::StoreOp storeOp) {
66 return storeOp.getMemRef();
67}
68
69// Matches rebuildOpFromAddressAndIndices specs for StoreOp.
70// \see LoadStoreLikeOpRewriter.
71static memref::StoreOp rebuildStoreOp(RewriterBase &rewriter,
72 memref::StoreOp storeOp, Value srcMemRef,
73 ArrayRef<Value> indices) {
74 Location loc = storeOp.getLoc();
75 return rewriter.create<memref::StoreOp>(loc, storeOp.getValueToStore(),
76 srcMemRef, indices,
77 storeOp.getNontemporal());
78}
79
80// Matches getViewSizeForEachDim specs for StoreOp.
81// \see LoadStoreLikeOpRewriter.
82static SmallVector<OpFoldResult>
83getStoreOpViewSizeForEachDim(RewriterBase &rewriter, memref::StoreOp storeOp) {
84 MemRefType ldTy = storeOp.getMemRefType();
85 unsigned loadRank = ldTy.getRank();
86 return SmallVector<OpFoldResult>(loadRank, rewriter.getIndexAttr(1));
87}
88
89//===----------------------------------------------------------------------===//
90// Helper functions for the `ldmatrix base[off0...]`
91// => `ldmatrix (subview base[off0...])[0...]` pattern.
92//===----------------------------------------------------------------------===//
93
94// Matches getFailureOrSrcMemRef specs for LdMatrixOp.
95// \see LoadStoreLikeOpRewriter.
96static FailureOr<Value> getLdMatrixOpSrcMemRef(nvgpu::LdMatrixOp ldMatrixOp) {
97 return ldMatrixOp.getSrcMemref();
98}
99
100// Matches rebuildOpFromAddressAndIndices specs for LdMatrixOp.
101// \see LoadStoreLikeOpRewriter.
102static nvgpu::LdMatrixOp rebuildLdMatrixOp(RewriterBase &rewriter,
103 nvgpu::LdMatrixOp ldMatrixOp,
104 Value srcMemRef,
105 ArrayRef<Value> indices) {
106 Location loc = ldMatrixOp.getLoc();
107 return rewriter.create<nvgpu::LdMatrixOp>(
108 loc, ldMatrixOp.getResult().getType(), srcMemRef, indices,
109 ldMatrixOp.getTranspose(), ldMatrixOp.getNumTiles());
110}
111
112//===----------------------------------------------------------------------===//
113// Helper functions for the `transfer_read base[off0...]`
114// => `transfer_read (subview base[off0...])[0...]` pattern.
115//===----------------------------------------------------------------------===//
116
117// Matches getFailureOrSrcMemRef specs for TransferReadOp.
118// \see LoadStoreLikeOpRewriter.
119template <typename TransferLikeOp>
120static FailureOr<Value>
121getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {
122 Value src = transferLikeOp.getSource();
123 if (isa<MemRefType>(Val: src.getType()))
124 return src;
125 return failure();
126}
127
128// Matches rebuildOpFromAddressAndIndices specs for TransferReadOp.
129// \see LoadStoreLikeOpRewriter.
130static vector::TransferReadOp
131rebuildTransferReadOp(RewriterBase &rewriter,
132 vector::TransferReadOp transferReadOp, Value srcMemRef,
133 ArrayRef<Value> indices) {
134 Location loc = transferReadOp.getLoc();
135 return rewriter.create<vector::TransferReadOp>(
136 loc, transferReadOp.getResult().getType(), srcMemRef, indices,
137 transferReadOp.getPermutationMap(), transferReadOp.getPadding(),
138 transferReadOp.getMask(), transferReadOp.getInBoundsAttr());
139}
140
141//===----------------------------------------------------------------------===//
142// Helper functions for the `transfer_write base[off0...]`
143// => `transfer_write (subview base[off0...])[0...]` pattern.
144//===----------------------------------------------------------------------===//
145
146// Matches rebuildOpFromAddressAndIndices specs for TransferWriteOp.
147// \see LoadStoreLikeOpRewriter.
148static vector::TransferWriteOp
149rebuildTransferWriteOp(RewriterBase &rewriter,
150 vector::TransferWriteOp transferWriteOp, Value srcMemRef,
151 ArrayRef<Value> indices) {
152 Location loc = transferWriteOp.getLoc();
153 return rewriter.create<vector::TransferWriteOp>(
154 loc, transferWriteOp.getValue(), srcMemRef, indices,
155 transferWriteOp.getPermutationMapAttr(), transferWriteOp.getMask(),
156 transferWriteOp.getInBoundsAttr());
157}
158
159//===----------------------------------------------------------------------===//
160// Generic helper functions used as default implementation in
161// LoadStoreLikeOpRewriter.
162//===----------------------------------------------------------------------===//
163
164/// Helper function to get the src memref.
165/// It uses the already defined getFailureOrSrcMemRef but asserts
166/// that the source is a memref.
167template <typename LoadStoreLikeOp,
168 FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp)>
169static Value getSrcMemRef(LoadStoreLikeOp loadStoreLikeOp) {
170 FailureOr<Value> failureOrSrcMemRef = getFailureOrSrcMemRef(loadStoreLikeOp);
171 assert(!failed(failureOrSrcMemRef) && "Generic getSrcMemRef cannot be used");
172 return *failureOrSrcMemRef;
173}
174
175/// Helper function to get the sizes of the resulting view.
176/// This function gets the sizes of the source memref then substracts the
177/// offsets used within \p loadStoreLikeOp. This gives the maximal (for
178/// inbound) sizes for the view.
179/// The source memref is retrieved using getSrcMemRef on \p loadStoreLikeOp.
180template <typename LoadStoreLikeOp, Value (*getSrcMemRef)(LoadStoreLikeOp)>
181static SmallVector<OpFoldResult>
182getGenericOpViewSizeForEachDim(RewriterBase &rewriter,
183 LoadStoreLikeOp loadStoreLikeOp) {
184 Location loc = loadStoreLikeOp.getLoc();
185 auto extractStridedMetadataOp =
186 rewriter.create<memref::ExtractStridedMetadataOp>(
187 loc, getSrcMemRef(loadStoreLikeOp));
188 SmallVector<OpFoldResult> srcSizes =
189 extractStridedMetadataOp.getConstifiedMixedSizes();
190 SmallVector<OpFoldResult> indices =
191 getAsOpFoldResult(loadStoreLikeOp.getIndices());
192 SmallVector<OpFoldResult> finalSizes;
193
194 AffineExpr s0 = rewriter.getAffineSymbolExpr(position: 0);
195 AffineExpr s1 = rewriter.getAffineSymbolExpr(position: 1);
196
197 for (auto [srcSize, indice] : llvm::zip(srcSizes, indices)) {
198 finalSizes.push_back(affine::makeComposedFoldedAffineApply(
199 rewriter, loc, s0 - s1, {srcSize, indice}));
200 }
201 return finalSizes;
202}
203
204/// Rewrite a store/load-like op so that all its indices are zeros.
205/// E.g., %ld = memref.load %base[%off0]...[%offN]
206/// =>
207/// %new_base = subview %base[%off0,.., %offN][1,..,1][1,..,1]
208/// %ld = memref.load %new_base[0,..,0] :
209/// memref<1x..x1xTy, strided<[1,..,1], offset: ?>>
210///
211/// `getSrcMemRef` returns the source memref for the given load-like operation.
212///
213/// `getViewSizeForEachDim` returns the sizes of view that is going to feed
214/// new operation. This must return one size per dimension of the view.
215/// The sizes of the view needs to be at least as big as what is actually
216/// going to be accessed. Use the provided `loadStoreOp` to get the right
217/// sizes.
218///
219/// Using the given rewriter, `rebuildOpFromAddressAndIndices` creates a new
220/// LoadStoreLikeOp that reads from srcMemRef[indices].
221/// The returned operation will be used to replace loadStoreOp.
222template <typename LoadStoreLikeOp,
223 FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp),
224 LoadStoreLikeOp (*rebuildOpFromAddressAndIndices)(
225 RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/,
226 Value /*srcMemRef*/, ArrayRef<Value> /*indices*/),
227 SmallVector<OpFoldResult> (*getViewSizeForEachDim)(
228 RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/) =
229 getGenericOpViewSizeForEachDim<
230 LoadStoreLikeOp,
231 getSrcMemRef<LoadStoreLikeOp, getFailureOrSrcMemRef>>>
232struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
233 using OpRewritePattern<LoadStoreLikeOp>::OpRewritePattern;
234
235 LogicalResult matchAndRewrite(LoadStoreLikeOp loadStoreLikeOp,
236 PatternRewriter &rewriter) const override {
237 FailureOr<Value> failureOrSrcMemRef =
238 getFailureOrSrcMemRef(loadStoreLikeOp);
239 if (failed(result: failureOrSrcMemRef))
240 return rewriter.notifyMatchFailure(loadStoreLikeOp,
241 "source is not a memref");
242 Value srcMemRef = *failureOrSrcMemRef;
243 auto ldStTy = cast<MemRefType>(srcMemRef.getType());
244 unsigned loadStoreRank = ldStTy.getRank();
245 // Don't waste compile time if there is nothing to rewrite.
246 if (loadStoreRank == 0)
247 return rewriter.notifyMatchFailure(loadStoreLikeOp,
248 "0-D accesses don't need rewriting");
249
250 // If our load already has only zeros as indices there is nothing
251 // to do.
252 SmallVector<OpFoldResult> indices =
253 getAsOpFoldResult(loadStoreLikeOp.getIndices());
254 if (std::all_of(indices.begin(), indices.end(),
255 [](const OpFoldResult &opFold) {
256 return isConstantIntValue(ofr: opFold, value: 0);
257 })) {
258 return rewriter.notifyMatchFailure(
259 loadStoreLikeOp, "no computation to extract: offsets are 0s");
260 }
261
262 // Create the array of ones of the right size.
263 SmallVector<OpFoldResult> ones(loadStoreRank, rewriter.getIndexAttr(1));
264 SmallVector<OpFoldResult> sizes =
265 getViewSizeForEachDim(rewriter, loadStoreLikeOp);
266 assert(sizes.size() == loadStoreRank &&
267 "Expected one size per load dimension");
268 Location loc = loadStoreLikeOp.getLoc();
269 // The subview inherits its strides from the original memref and will
270 // apply them properly to the input indices.
271 // Therefore the strides multipliers are simply ones.
272 auto subview =
273 rewriter.create<memref::SubViewOp>(loc, /*source=*/srcMemRef,
274 /*offsets=*/indices,
275 /*sizes=*/sizes, /*strides=*/ones);
276 // Rewrite the load/store with the subview as the base pointer.
277 SmallVector<Value> zeros(loadStoreRank,
278 rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0));
279 LoadStoreLikeOp newLoadStore = rebuildOpFromAddressAndIndices(
280 rewriter, loadStoreLikeOp, subview.getResult(), zeros);
281 rewriter.replaceOp(loadStoreLikeOp, newLoadStore->getResults());
282 return success();
283 }
284};
285} // namespace
286
287void memref::populateExtractAddressComputationsPatterns(
288 RewritePatternSet &patterns) {
289 patterns.add<
290 LoadStoreLikeOpRewriter<
291 memref::LoadOp,
292 /*getSrcMemRef=*/getLoadOpSrcMemRef,
293 /*rebuildOpFromAddressAndIndices=*/rebuildLoadOp,
294 /*getViewSizeForEachDim=*/getLoadOpViewSizeForEachDim>,
295 LoadStoreLikeOpRewriter<
296 memref::StoreOp,
297 /*getSrcMemRef=*/getStoreOpSrcMemRef,
298 /*rebuildOpFromAddressAndIndices=*/rebuildStoreOp,
299 /*getViewSizeForEachDim=*/getStoreOpViewSizeForEachDim>,
300 LoadStoreLikeOpRewriter<
301 nvgpu::LdMatrixOp,
302 /*getSrcMemRef=*/getLdMatrixOpSrcMemRef,
303 /*rebuildOpFromAddressAndIndices=*/rebuildLdMatrixOp>,
304 LoadStoreLikeOpRewriter<
305 vector::TransferReadOp,
306 /*getSrcMemRef=*/getTransferLikeOpSrcMemRef<vector::TransferReadOp>,
307 /*rebuildOpFromAddressAndIndices=*/rebuildTransferReadOp>,
308 LoadStoreLikeOpRewriter<
309 vector::TransferWriteOp,
310 /*getSrcMemRef=*/getTransferLikeOpSrcMemRef<vector::TransferWriteOp>,
311 /*rebuildOpFromAddressAndIndices=*/rebuildTransferWriteOp>>(
312 patterns.getContext());
313}
314

// source code of mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp