//===- OptimizeSharedMemory.cpp - MLIR NVGPU pass implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements transforms to optimize accesses to shared memory.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/NVGPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"

namespace mlir {
namespace nvgpu {
#define GEN_PASS_DEF_OPTIMIZESHAREDMEMORY
#include "mlir/Dialect/NVGPU/Transforms/Passes.h.inc"
} // namespace nvgpu
} // namespace mlir

using namespace mlir;
using namespace mlir::nvgpu;

/// The size of a shared memory line according to NV documentation.
constexpr int64_t kSharedMemoryLineSizeBytes = 128;
/// We optimize for 128-bit accesses, but this can be made an argument in the
/// future.
constexpr int64_t kDefaultVectorSizeBits = 128;

/// Uses `srcIndexValue` to permute `tgtIndexValue` via
/// `result = xor(floordiv(srcIdxVal,permuteEveryN),
///               floordiv(tgtIdxVal,vectorSize))
///            + tgtIdxVal % vectorSize`
/// This is done using an optimized sequence of `arith` operations.
static Value permuteVectorOffset(OpBuilder &b, Location loc,
                                 ArrayRef<Value> indices, MemRefType memrefTy,
                                 int64_t srcDim, int64_t tgtDim) {
  // Adjust the src index to change how often the permutation changes
  // if necessary.
  Value src = indices[srcDim];

  // We only want to permute every N iterations of the target dim where N is
  // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
  const int64_t permuteEveryN = std::max<int64_t>(
      1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
                                        memrefTy.getElementTypeBitWidth()) /
                                       8));
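  // As a worked example (the memref shape here is illustrative, not taken
  // from this file): for a `memref<?x32xf16>` the target row occupies
  // 32 * 16 / 8 = 64 bytes, so permuteEveryN = max(1, 128 / 64) = 2, i.e. the
  // permutation pattern only advances every second row.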

  // clang-format off
  // Index bit representation (b0 = least significant bit) for dim(1)
  // of a `memref<?x?xDT>` is as follows:
  //   N := log2(128/elementSizeBits)
  //   M := log2(dimSize(1))
  // then
  //   bits[0:N] = sub-vector element offset
  //   bits[N:M] = vector index
  // clang-format on
  int64_t n =
      llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
  int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));
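  // Continuing the worked example above: with f16 elements,
  // n = log2(128 / 16) = 3 (eight elements per 128-bit vector), and with a
  // row of 32 elements, m = log2(32) = 5, so bits [3:5) of the target index
  // select the vector within the row.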

  // Capture bits[0:(M-N)] of src by first creating a (M-N) mask.
  int64_t mask = (1LL << (m - n)) - 1;
  if (permuteEveryN > 1)
    mask = mask << llvm::Log2_64(permuteEveryN);
  Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
  srcBits = b.create<arith::AndIOp>(loc, src, srcBits);
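  // In the running example: mask = (1 << (5 - 3)) - 1 = 0b11; because
  // permuteEveryN = 2, the mask is shifted left by log2(2) = 1 to 0b110, so
  // bits [1:3) of the source index are captured.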

  // Use the src bits to permute the target bits b[N:M] containing the
  // vector offset.
  if (permuteEveryN > 1) {
    int64_t shlBits = n - llvm::Log2_64(permuteEveryN);
    if (shlBits > 0) {
      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, shlBits);
      srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
    } else if (shlBits < 0) {
      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, -1 * shlBits);
      srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
    }
  } else {
    Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, n);
    srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
  }
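  // In the running example, shlBits = 3 - log2(2) = 2, so the captured source
  // bits land on positions [3:5) and line up with the vector-index bits of the
  // target offset before the xor below.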

  Value permutedVectorIdx =
      b.create<arith::XOrIOp>(loc, indices[tgtDim], srcBits);
  return permutedVectorIdx;
}

static void transformIndices(OpBuilder &builder, Location loc,
                             SmallVector<Value, 4> &indices,
                             MemRefType memrefTy, int64_t srcDim,
                             int64_t tgtDim) {
  indices[tgtDim] =
      permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
}

/// Return all operations within `parentOp` that read from or write to
/// `shmMemRef`.
static LogicalResult
getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
                      SmallVector<Operation *, 16> &readOps,
                      SmallVector<Operation *, 16> &writeOps) {
  parentOp->walk([&](Operation *op) {
    MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
    if (!iface)
      return;
    std::optional<MemoryEffects::EffectInstance> effect =
        iface.getEffectOnValue<MemoryEffects::Read>(shmMemRef);
    if (effect) {
      readOps.push_back(op);
      return;
    }
    effect = iface.getEffectOnValue<MemoryEffects::Write>(shmMemRef);
    if (effect)
      writeOps.push_back(op);
  });

  // Restrict to a supported set of ops. We also require at least 2D access,
  // although this could be relaxed.
  if (llvm::any_of(readOps, [](Operation *op) {
        return !isa<memref::LoadOp, vector::LoadOp, nvgpu::LdMatrixOp>(op) ||
               getIndices(op).size() < 2;
      }))
    return failure();
  if (llvm::any_of(writeOps, [](Operation *op) {
        return !isa<memref::StoreOp, vector::StoreOp, nvgpu::DeviceAsyncCopyOp>(
                   op) ||
               getIndices(op).size() < 2;
      }))
    return failure();

  return success();
}

llvm::LogicalResult
mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                Value memrefValue) {
  auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
  if (!memRefType || !NVGPUDialect::hasSharedMemoryAddressSpace(memRefType))
    return failure();

  // 0-D memrefs are not supported.
  if (memRefType.getRank() == 0)
    return failure();

  // Abort if the given value has any sub-views; we do not do any alias
  // analysis.
  bool hasSubView = false;
  parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
  if (hasSubView)
    return failure();

  // Given the assumption of 128-bit accesses, check whether the optimization
  // is necessary: if dim[rank-1] is small enough that 8 or more rows fit in a
  // single 128-byte shared memory line, skip the swizzle.
  const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
  const int64_t rowsPerLine =
      (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
      rowSize;
  const int64_t threadGroupSize =
      1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
  if (rowsPerLine >= threadGroupSize)
    return failure();
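  // For example (an illustrative shape, not taken from this file): with f16
  // elements a 128-byte line holds 8 * 128 / 16 = 64 elements, so a row of 32
  // gives rowsPerLine = 2; threadGroupSize = 1 << (7 - log2(128 / 8)) = 8, and
  // since 2 < 8 the swizzle is applied.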

  // Get sets of operations within the function that read/write to shared
  // memory.
  SmallVector<Operation *, 16> shmReadOps;
  SmallVector<Operation *, 16> shmWriteOps;
  if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
                                   shmWriteOps)))
    return failure();

  if (shmReadOps.empty() || shmWriteOps.empty())
    return failure();

  OpBuilder builder(parentOp->getContext());

  int64_t tgtDim = memRefType.getRank() - 1;
  int64_t srcDim = memRefType.getRank() - 2;

  // Transform indices for the ops writing to shared memory.
  while (!shmWriteOps.empty()) {
    Operation *shmWriteOp = shmWriteOps.back();
    shmWriteOps.pop_back();
    builder.setInsertionPoint(shmWriteOp);

    auto indices = getIndices(shmWriteOp);
    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
    transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    setIndices(shmWriteOp, transformedIndices);
  }

  // Transform indices for the ops reading from shared memory.
  while (!shmReadOps.empty()) {
    Operation *shmReadOp = shmReadOps.back();
    shmReadOps.pop_back();
    builder.setInsertionPoint(shmReadOp);

    auto indices = getIndices(shmReadOp);
    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
    transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    setIndices(shmReadOp, transformedIndices);
  }

  return success();
}

namespace {
class OptimizeSharedMemoryPass
    : public nvgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
public:
  OptimizeSharedMemoryPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    SmallVector<memref::AllocOp> shmAllocOps;
    op->walk([&](memref::AllocOp allocOp) {
      if (!NVGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
        return;
      shmAllocOps.push_back(allocOp);
    });
    for (auto allocOp : shmAllocOps) {
      if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
                                                    allocOp.getMemref())))
        return;
    }
  }
};
} // namespace

std::unique_ptr<Pass> mlir::nvgpu::createOptimizeSharedMemoryPass() {
  return std::make_unique<OptimizeSharedMemoryPass>();
}
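
// A minimal usage sketch (assuming the standard mlir::PassManager API; the
// `module` operation below is hypothetical) for running this pass
// programmatically:
//
//   mlir::PassManager pm(module->getContext());
//   pm.addPass(mlir::nvgpu::createOptimizeSharedMemoryPass());
//   if (failed(pm.run(module)))
//     llvm::errs() << "optimize-shared-memory failed\n";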