//===- ExtractAddressComputations.cpp - Extract address computations ------===//
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | /// This transformation pass rewrites loading/storing from/to a memref with |
| 10 | /// offsets into loading/storing from/to a subview and without any offset on |
| 11 | /// the instruction itself. |
| 12 | // |
| 13 | //===----------------------------------------------------------------------===// |
| 14 | |
| 15 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 17 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 18 | #include "mlir/Dialect/MemRef/Transforms/Transforms.h" |
| 19 | #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" |
| 20 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
| 21 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 22 | #include "mlir/IR/PatternMatch.h" |
| 23 | |
| 24 | using namespace mlir; |
| 25 | |
| 26 | namespace { |
| 27 | |
| 28 | //===----------------------------------------------------------------------===// |
| 29 | // Helper functions for the `load base[off0...]` |
| 30 | // => `load (subview base[off0...])[0...]` pattern. |
| 31 | //===----------------------------------------------------------------------===// |
| 32 | |
| 33 | // Matches getFailureOrSrcMemRef specs for LoadOp. |
| 34 | // \see LoadStoreLikeOpRewriter. |
| 35 | static FailureOr<Value> getLoadOpSrcMemRef(memref::LoadOp loadOp) { |
| 36 | return loadOp.getMemRef(); |
| 37 | } |
| 38 | |
| 39 | // Matches rebuildOpFromAddressAndIndices specs for LoadOp. |
| 40 | // \see LoadStoreLikeOpRewriter. |
| 41 | static memref::LoadOp rebuildLoadOp(RewriterBase &rewriter, |
| 42 | memref::LoadOp loadOp, Value srcMemRef, |
| 43 | ArrayRef<Value> indices) { |
| 44 | Location loc = loadOp.getLoc(); |
| 45 | return rewriter.create<memref::LoadOp>(loc, srcMemRef, indices, |
| 46 | loadOp.getNontemporal()); |
| 47 | } |
| 48 | |
| 49 | // Matches getViewSizeForEachDim specs for LoadOp. |
| 50 | // \see LoadStoreLikeOpRewriter. |
| 51 | static SmallVector<OpFoldResult> |
| 52 | getLoadOpViewSizeForEachDim(RewriterBase &rewriter, memref::LoadOp loadOp) { |
| 53 | MemRefType ldTy = loadOp.getMemRefType(); |
| 54 | unsigned loadRank = ldTy.getRank(); |
| 55 | return SmallVector<OpFoldResult>(loadRank, rewriter.getIndexAttr(1)); |
| 56 | } |
| 57 | |
| 58 | //===----------------------------------------------------------------------===// |
| 59 | // Helper functions for the `store val, base[off0...]` |
| 60 | // => `store val, (subview base[off0...])[0...]` pattern. |
| 61 | //===----------------------------------------------------------------------===// |
| 62 | |
| 63 | // Matches getFailureOrSrcMemRef specs for StoreOp. |
| 64 | // \see LoadStoreLikeOpRewriter. |
| 65 | static FailureOr<Value> getStoreOpSrcMemRef(memref::StoreOp storeOp) { |
| 66 | return storeOp.getMemRef(); |
| 67 | } |
| 68 | |
| 69 | // Matches rebuildOpFromAddressAndIndices specs for StoreOp. |
| 70 | // \see LoadStoreLikeOpRewriter. |
| 71 | static memref::StoreOp rebuildStoreOp(RewriterBase &rewriter, |
| 72 | memref::StoreOp storeOp, Value srcMemRef, |
| 73 | ArrayRef<Value> indices) { |
| 74 | Location loc = storeOp.getLoc(); |
| 75 | return rewriter.create<memref::StoreOp>(loc, storeOp.getValueToStore(), |
| 76 | srcMemRef, indices, |
| 77 | storeOp.getNontemporal()); |
| 78 | } |
| 79 | |
| 80 | // Matches getViewSizeForEachDim specs for StoreOp. |
| 81 | // \see LoadStoreLikeOpRewriter. |
| 82 | static SmallVector<OpFoldResult> |
| 83 | getStoreOpViewSizeForEachDim(RewriterBase &rewriter, memref::StoreOp storeOp) { |
| 84 | MemRefType ldTy = storeOp.getMemRefType(); |
| 85 | unsigned loadRank = ldTy.getRank(); |
| 86 | return SmallVector<OpFoldResult>(loadRank, rewriter.getIndexAttr(1)); |
| 87 | } |
| 88 | |
| 89 | //===----------------------------------------------------------------------===// |
| 90 | // Helper functions for the `ldmatrix base[off0...]` |
| 91 | // => `ldmatrix (subview base[off0...])[0...]` pattern. |
| 92 | //===----------------------------------------------------------------------===// |
| 93 | |
| 94 | // Matches getFailureOrSrcMemRef specs for LdMatrixOp. |
| 95 | // \see LoadStoreLikeOpRewriter. |
| 96 | static FailureOr<Value> getLdMatrixOpSrcMemRef(nvgpu::LdMatrixOp ldMatrixOp) { |
| 97 | return ldMatrixOp.getSrcMemref(); |
| 98 | } |
| 99 | |
| 100 | // Matches rebuildOpFromAddressAndIndices specs for LdMatrixOp. |
| 101 | // \see LoadStoreLikeOpRewriter. |
| 102 | static nvgpu::LdMatrixOp rebuildLdMatrixOp(RewriterBase &rewriter, |
| 103 | nvgpu::LdMatrixOp ldMatrixOp, |
| 104 | Value srcMemRef, |
| 105 | ArrayRef<Value> indices) { |
| 106 | Location loc = ldMatrixOp.getLoc(); |
| 107 | return rewriter.create<nvgpu::LdMatrixOp>( |
| 108 | loc, ldMatrixOp.getResult().getType(), srcMemRef, indices, |
| 109 | ldMatrixOp.getTranspose(), ldMatrixOp.getNumTiles()); |
| 110 | } |
| 111 | |
| 112 | //===----------------------------------------------------------------------===// |
| 113 | // Helper functions for the `transfer_read base[off0...]` |
| 114 | // => `transfer_read (subview base[off0...])[0...]` pattern. |
| 115 | //===----------------------------------------------------------------------===// |
| 116 | |
| 117 | // Matches getFailureOrSrcMemRef specs for TransferReadOp. |
| 118 | // \see LoadStoreLikeOpRewriter. |
| 119 | template <typename TransferLikeOp> |
| 120 | static FailureOr<Value> |
| 121 | getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) { |
| 122 | Value src = transferLikeOp.getBase(); |
| 123 | if (isa<MemRefType>(Val: src.getType())) |
| 124 | return src; |
| 125 | return failure(); |
| 126 | } |
| 127 | |
| 128 | // Matches rebuildOpFromAddressAndIndices specs for TransferReadOp. |
| 129 | // \see LoadStoreLikeOpRewriter. |
| 130 | static vector::TransferReadOp |
| 131 | rebuildTransferReadOp(RewriterBase &rewriter, |
| 132 | vector::TransferReadOp transferReadOp, Value srcMemRef, |
| 133 | ArrayRef<Value> indices) { |
| 134 | Location loc = transferReadOp.getLoc(); |
| 135 | return rewriter.create<vector::TransferReadOp>( |
| 136 | loc, transferReadOp.getResult().getType(), srcMemRef, indices, |
| 137 | transferReadOp.getPermutationMap(), transferReadOp.getPadding(), |
| 138 | transferReadOp.getMask(), transferReadOp.getInBoundsAttr()); |
| 139 | } |
| 140 | |
| 141 | //===----------------------------------------------------------------------===// |
| 142 | // Helper functions for the `transfer_write base[off0...]` |
| 143 | // => `transfer_write (subview base[off0...])[0...]` pattern. |
| 144 | //===----------------------------------------------------------------------===// |
| 145 | |
| 146 | // Matches rebuildOpFromAddressAndIndices specs for TransferWriteOp. |
| 147 | // \see LoadStoreLikeOpRewriter. |
| 148 | static vector::TransferWriteOp |
| 149 | rebuildTransferWriteOp(RewriterBase &rewriter, |
| 150 | vector::TransferWriteOp transferWriteOp, Value srcMemRef, |
| 151 | ArrayRef<Value> indices) { |
| 152 | Location loc = transferWriteOp.getLoc(); |
| 153 | return rewriter.create<vector::TransferWriteOp>( |
| 154 | loc, transferWriteOp.getValue(), srcMemRef, indices, |
| 155 | transferWriteOp.getPermutationMapAttr(), transferWriteOp.getMask(), |
| 156 | transferWriteOp.getInBoundsAttr()); |
| 157 | } |
| 158 | |
| 159 | //===----------------------------------------------------------------------===// |
| 160 | // Generic helper functions used as default implementation in |
| 161 | // LoadStoreLikeOpRewriter. |
| 162 | //===----------------------------------------------------------------------===// |
| 163 | |
| 164 | /// Helper function to get the src memref. |
| 165 | /// It uses the already defined getFailureOrSrcMemRef but asserts |
| 166 | /// that the source is a memref. |
| 167 | template <typename LoadStoreLikeOp, |
| 168 | FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp)> |
| 169 | static Value getSrcMemRef(LoadStoreLikeOp loadStoreLikeOp) { |
| 170 | FailureOr<Value> failureOrSrcMemRef = getFailureOrSrcMemRef(loadStoreLikeOp); |
| 171 | assert(!failed(failureOrSrcMemRef) && "Generic getSrcMemRef cannot be used" ); |
| 172 | return *failureOrSrcMemRef; |
| 173 | } |
| 174 | |
| 175 | /// Helper function to get the sizes of the resulting view. |
| 176 | /// This function gets the sizes of the source memref then substracts the |
| 177 | /// offsets used within \p loadStoreLikeOp. This gives the maximal (for |
| 178 | /// inbound) sizes for the view. |
| 179 | /// The source memref is retrieved using getSrcMemRef on \p loadStoreLikeOp. |
| 180 | template <typename LoadStoreLikeOp, Value (*getSrcMemRef)(LoadStoreLikeOp)> |
| 181 | static SmallVector<OpFoldResult> |
| 182 | getGenericOpViewSizeForEachDim(RewriterBase &rewriter, |
| 183 | LoadStoreLikeOp loadStoreLikeOp) { |
| 184 | Location loc = loadStoreLikeOp.getLoc(); |
| 185 | auto = |
| 186 | rewriter.create<memref::ExtractStridedMetadataOp>( |
| 187 | loc, getSrcMemRef(loadStoreLikeOp)); |
| 188 | SmallVector<OpFoldResult> srcSizes = |
| 189 | extractStridedMetadataOp.getConstifiedMixedSizes(); |
| 190 | SmallVector<OpFoldResult> indices = |
| 191 | getAsOpFoldResult(loadStoreLikeOp.getIndices()); |
| 192 | SmallVector<OpFoldResult> finalSizes; |
| 193 | |
| 194 | AffineExpr s0 = rewriter.getAffineSymbolExpr(position: 0); |
| 195 | AffineExpr s1 = rewriter.getAffineSymbolExpr(position: 1); |
| 196 | |
| 197 | for (auto [srcSize, indice] : llvm::zip(srcSizes, indices)) { |
| 198 | finalSizes.push_back(affine::makeComposedFoldedAffineApply( |
| 199 | rewriter, loc, s0 - s1, {srcSize, indice})); |
| 200 | } |
| 201 | return finalSizes; |
| 202 | } |
| 203 | |
| 204 | /// Rewrite a store/load-like op so that all its indices are zeros. |
| 205 | /// E.g., %ld = memref.load %base[%off0]...[%offN] |
| 206 | /// => |
| 207 | /// %new_base = subview %base[%off0,.., %offN][1,..,1][1,..,1] |
| 208 | /// %ld = memref.load %new_base[0,..,0] : |
| 209 | /// memref<1x..x1xTy, strided<[1,..,1], offset: ?>> |
| 210 | /// |
| 211 | /// `getSrcMemRef` returns the source memref for the given load-like operation. |
| 212 | /// |
| 213 | /// `getViewSizeForEachDim` returns the sizes of view that is going to feed |
| 214 | /// new operation. This must return one size per dimension of the view. |
| 215 | /// The sizes of the view needs to be at least as big as what is actually |
| 216 | /// going to be accessed. Use the provided `loadStoreOp` to get the right |
| 217 | /// sizes. |
| 218 | /// |
| 219 | /// Using the given rewriter, `rebuildOpFromAddressAndIndices` creates a new |
| 220 | /// LoadStoreLikeOp that reads from srcMemRef[indices]. |
| 221 | /// The returned operation will be used to replace loadStoreOp. |
| 222 | template <typename LoadStoreLikeOp, |
| 223 | FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp), |
| 224 | LoadStoreLikeOp (*rebuildOpFromAddressAndIndices)( |
| 225 | RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/, |
| 226 | Value /*srcMemRef*/, ArrayRef<Value> /*indices*/), |
| 227 | SmallVector<OpFoldResult> (*getViewSizeForEachDim)( |
| 228 | RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/) = |
| 229 | getGenericOpViewSizeForEachDim< |
| 230 | LoadStoreLikeOp, |
| 231 | getSrcMemRef<LoadStoreLikeOp, getFailureOrSrcMemRef>>> |
| 232 | struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> { |
| 233 | using OpRewritePattern<LoadStoreLikeOp>::OpRewritePattern; |
| 234 | |
| 235 | LogicalResult matchAndRewrite(LoadStoreLikeOp loadStoreLikeOp, |
| 236 | PatternRewriter &rewriter) const override { |
| 237 | FailureOr<Value> failureOrSrcMemRef = |
| 238 | getFailureOrSrcMemRef(loadStoreLikeOp); |
| 239 | if (failed(Result: failureOrSrcMemRef)) |
| 240 | return rewriter.notifyMatchFailure(loadStoreLikeOp, |
| 241 | "source is not a memref" ); |
| 242 | Value srcMemRef = *failureOrSrcMemRef; |
| 243 | auto ldStTy = cast<MemRefType>(srcMemRef.getType()); |
| 244 | unsigned loadStoreRank = ldStTy.getRank(); |
| 245 | // Don't waste compile time if there is nothing to rewrite. |
| 246 | if (loadStoreRank == 0) |
| 247 | return rewriter.notifyMatchFailure(loadStoreLikeOp, |
| 248 | "0-D accesses don't need rewriting" ); |
| 249 | |
| 250 | // If our load already has only zeros as indices there is nothing |
| 251 | // to do. |
| 252 | SmallVector<OpFoldResult> indices = |
| 253 | getAsOpFoldResult(loadStoreLikeOp.getIndices()); |
| 254 | if (llvm::all_of(Range&: indices, P: isZeroInteger)) { |
| 255 | return rewriter.notifyMatchFailure( |
| 256 | loadStoreLikeOp, "no computation to extract: offsets are 0s" ); |
| 257 | } |
| 258 | |
| 259 | // Create the array of ones of the right size. |
| 260 | SmallVector<OpFoldResult> ones(loadStoreRank, rewriter.getIndexAttr(1)); |
| 261 | SmallVector<OpFoldResult> sizes = |
| 262 | getViewSizeForEachDim(rewriter, loadStoreLikeOp); |
| 263 | assert(sizes.size() == loadStoreRank && |
| 264 | "Expected one size per load dimension" ); |
| 265 | Location loc = loadStoreLikeOp.getLoc(); |
| 266 | // The subview inherits its strides from the original memref and will |
| 267 | // apply them properly to the input indices. |
| 268 | // Therefore the strides multipliers are simply ones. |
| 269 | auto subview = |
| 270 | rewriter.create<memref::SubViewOp>(loc, /*source=*/srcMemRef, |
| 271 | /*offsets=*/indices, |
| 272 | /*sizes=*/sizes, /*strides=*/ones); |
| 273 | // Rewrite the load/store with the subview as the base pointer. |
| 274 | SmallVector<Value> zeros(loadStoreRank, |
| 275 | rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0)); |
| 276 | LoadStoreLikeOp newLoadStore = rebuildOpFromAddressAndIndices( |
| 277 | rewriter, loadStoreLikeOp, subview.getResult(), zeros); |
| 278 | rewriter.replaceOp(loadStoreLikeOp, newLoadStore->getResults()); |
| 279 | return success(); |
| 280 | } |
| 281 | }; |
| 282 | } // namespace |
| 283 | |
| 284 | void memref::( |
| 285 | RewritePatternSet &patterns) { |
| 286 | patterns.add< |
| 287 | LoadStoreLikeOpRewriter< |
| 288 | memref::LoadOp, |
| 289 | /*getSrcMemRef=*/getLoadOpSrcMemRef, |
| 290 | /*rebuildOpFromAddressAndIndices=*/rebuildLoadOp, |
| 291 | /*getViewSizeForEachDim=*/getLoadOpViewSizeForEachDim>, |
| 292 | LoadStoreLikeOpRewriter< |
| 293 | memref::StoreOp, |
| 294 | /*getSrcMemRef=*/getStoreOpSrcMemRef, |
| 295 | /*rebuildOpFromAddressAndIndices=*/rebuildStoreOp, |
| 296 | /*getViewSizeForEachDim=*/getStoreOpViewSizeForEachDim>, |
| 297 | LoadStoreLikeOpRewriter< |
| 298 | nvgpu::LdMatrixOp, |
| 299 | /*getSrcMemRef=*/getLdMatrixOpSrcMemRef, |
| 300 | /*rebuildOpFromAddressAndIndices=*/rebuildLdMatrixOp>, |
| 301 | LoadStoreLikeOpRewriter< |
| 302 | vector::TransferReadOp, |
| 303 | /*getSrcMemRef=*/getTransferLikeOpSrcMemRef<vector::TransferReadOp>, |
| 304 | /*rebuildOpFromAddressAndIndices=*/rebuildTransferReadOp>, |
| 305 | LoadStoreLikeOpRewriter< |
| 306 | vector::TransferWriteOp, |
| 307 | /*getSrcMemRef=*/getTransferLikeOpSrcMemRef<vector::TransferWriteOp>, |
| 308 | /*rebuildOpFromAddressAndIndices=*/rebuildTransferWriteOp>>( |
| 309 | patterns.getContext()); |
| 310 | } |
| 311 | |