//===- OptimizeSharedMemory.cpp - MLIR NVGPU pass implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements transforms to optimize accesses to shared memory.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/NVGPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"

namespace mlir {
namespace nvgpu {
#define GEN_PASS_DEF_OPTIMIZESHAREDMEMORY
#include "mlir/Dialect/NVGPU/Transforms/Passes.h.inc"
} // namespace nvgpu
} // namespace mlir

using namespace mlir;
using namespace mlir::nvgpu;

/// The size of a shared memory line according to NV documentation.
constexpr int64_t kSharedMemoryLineSizeBytes = 128;
/// We optimize for 128bit accesses, but this can be made an argument in the
/// future.
constexpr int64_t kDefaultVectorSizeBits = 128;
/// Uses `srcIndexValue` to permute `tgtIndexValue` via
/// `result = xor(floordiv(srcIdxVal,permuteEveryN),
///               floordiv(tgtIdxVal,vectorSize))
///           + tgtIdxVal % vectorSize`
/// This is done using an optimized sequence of `arith` operations.
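/// For example, for a `memref<?x64xf16>` with 128-bit vector accesses:
/// vectorSize = 128/16 = 8 elements, N = 3, M = log2(64) = 6, and a row is
/// 64*2 = 128 bytes wide, so permuteEveryN = 1. Row `r` then XORs the
/// vector-index bits of the column index with `r % 8`, producing the usual
/// bank-conflict-avoiding swizzle that repeats every 8 rows.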
static Value permuteVectorOffset(OpBuilder &b, Location loc,
                                 ArrayRef<Value> indices, MemRefType memrefTy,
                                 int64_t srcDim, int64_t tgtDim) {
  // Adjust the src index to change how often the permutation changes
  // if necessary.
  Value src = indices[srcDim];

  // We only want to permute every N iterations of the target dim where N is
  // max(1, sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)), using integer
  // (floor) division.
  const int64_t permuteEveryN = std::max<int64_t>(
      1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
                                        memrefTy.getElementTypeBitWidth()) /
                                       8));
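  // E.g., a 32-element f16 row occupies 32*2 = 64 bytes, so permuteEveryN = 2:
  // two consecutive rows share one 128-byte line and receive the same
  // permutation.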

  // clang-format off
  // Index bit representation (b0 = least significant bit) for dim(1)
  // of a `memref<?x?xDT>` is as follows:
  // N := log2(128/elementSizeBits)
  // M := log2(dimSize(1))
  // then
  // bits[0:N] = sub-vector element offset
  // bits[N:M] = vector index
  // clang-format on
  int64_t n =
      llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
  int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));

  // Capture bits[0:(M-N)] of src by first creating a (M-N) mask.
  int64_t mask = (1LL << (m - n)) - 1;
  if (permuteEveryN > 1)
    mask = mask << llvm::Log2_64(permuteEveryN);
  Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
  srcBits = b.create<arith::AndIOp>(loc, src, srcBits);
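  // E.g., with m = 6, n = 3, and permuteEveryN = 2, mask = 0b111 << 1 =
  // 0b1110: bit 0 of src is ignored, so the permutation only changes every
  // second row.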

  // Use the src bits to permute the target bits b[N:M] containing the
  // vector offset.
  if (permuteEveryN > 1) {
    int64_t shlBits = n - llvm::Log2_64(permuteEveryN);
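    // The masked src bits currently sit at bit offset log2(permuteEveryN);
    // shift them left or right so they align with the vector-index bits that
    // start at bit N.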
    if (shlBits > 0) {
      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, shlBits);
      srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
    } else if (shlBits < 0) {
      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, -1 * shlBits);
      srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
    }
  } else {
    Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, n);
    srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
  }

  Value permutedVectorIdx =
      b.create<arith::XOrIOp>(loc, indices[tgtDim], srcBits);
  return permutedVectorIdx;
}

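/// Replaces `indices[tgtDim]` with its permuted value; all other indices are
/// left untouched.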
static void transformIndices(OpBuilder &builder, Location loc,
                             SmallVector<Value, 4> &indices,
                             MemRefType memrefTy, int64_t srcDim,
                             int64_t tgtDim) {
  indices[tgtDim] =
      permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
}

/// Return all operations within `parentOp` that read from or write to
/// `shmMemRef`.
static LogicalResult
getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
                      SmallVector<Operation *, 16> &readOps,
                      SmallVector<Operation *, 16> &writeOps) {
  parentOp->walk([&](Operation *op) {
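    // Ops that do not implement the memory-effect interface are ignored.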
    MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
    if (!iface)
      return;
    std::optional<MemoryEffects::EffectInstance> effect =
        iface.getEffectOnValue<MemoryEffects::Read>(shmMemRef);
    if (effect) {
      readOps.push_back(op);
      return;
    }
    effect = iface.getEffectOnValue<MemoryEffects::Write>(shmMemRef);
    if (effect)
      writeOps.push_back(op);
  });

  // Restrict to a supported set of ops. We also require at least 2D access,
  // although this could be relaxed.
  if (llvm::any_of(readOps, [](Operation *op) {
        return !isa<memref::LoadOp, vector::LoadOp, nvgpu::LdMatrixOp>(op) ||
               getIndices(op).size() < 2;
      }))
    return failure();
  if (llvm::any_of(writeOps, [](Operation *op) {
        return !isa<memref::StoreOp, vector::StoreOp, nvgpu::DeviceAsyncCopyOp>(
                   op) ||
               getIndices(op).size() < 2;
      }))
    return failure();

  return success();
}

llvm::LogicalResult
mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                Value memrefValue) {
  auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
  if (!memRefType || !NVGPUDialect::hasSharedMemoryAddressSpace(memRefType))
    return failure();

  // 0-D memrefs are not supported.
  if (memRefType.getRank() == 0)
    return failure();

  // Abort if the given value has any sub-views; we do not do any alias
  // analysis.
  bool hasSubView = false;
  parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
  if (hasSubView)
    return failure();

  // Check if the permutation is actually necessary given the assumption of
  // 128b accesses: if dim[rank-1] is small enough that at least
  // `threadGroupSize` rows fit in a single 128B line, bail out below.
  const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
  const int64_t rowsPerLine =
      (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
      rowSize;
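  // 8 * lineSizeBytes / elementBitWidth is the number of elements in a 128B
  // line; dividing by the row size yields how many full rows share one line.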
  const int64_t threadGroupSize =
      1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
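  // With 128-bit (16-byte) accesses this is 128B / 16B = 8: the number of
  // such accesses needed to fill one shared memory line.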
  if (rowsPerLine >= threadGroupSize)
    return failure();

  // Get sets of operations within the function that read/write to shared
  // memory.
  SmallVector<Operation *, 16> shmReadOps;
  SmallVector<Operation *, 16> shmWriteOps;
  if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
                                   shmWriteOps)))
    return failure();

  if (shmReadOps.empty() || shmWriteOps.empty())
    return failure();

  OpBuilder builder(parentOp->getContext());

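  // Permute the last (fastest-varying) dim, keyed off the second-to-last dim.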
  int64_t tgtDim = memRefType.getRank() - 1;
  int64_t srcDim = memRefType.getRank() - 2;

  // Transform indices for the ops writing to shared memory.
  while (!shmWriteOps.empty()) {
    Operation *shmWriteOp = shmWriteOps.back();
    shmWriteOps.pop_back();
    builder.setInsertionPoint(shmWriteOp);

    auto indices = getIndices(shmWriteOp);
    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
    transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    setIndices(shmWriteOp, transformedIndices);
  }

  // Transform indices for the ops reading from shared memory.
  while (!shmReadOps.empty()) {
    Operation *shmReadOp = shmReadOps.back();
    shmReadOps.pop_back();
    builder.setInsertionPoint(shmReadOp);

    auto indices = getIndices(shmReadOp);
    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
    transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    setIndices(shmReadOp, transformedIndices);
  }

  return success();
}

namespace {
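/// Pass wrapper that applies `optimizeSharedMemoryReadsAndWrites` to every
/// shared memory allocation found under the target operation.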
class OptimizeSharedMemoryPass
    : public nvgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
public:
  OptimizeSharedMemoryPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    SmallVector<memref::AllocOp> shmAllocOps;
    op->walk([&](memref::AllocOp allocOp) {
      if (!NVGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
        return;
      shmAllocOps.push_back(allocOp);
    });
    for (auto allocOp : shmAllocOps) {
      if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
                                                    allocOp.getMemref())))
        return;
    }
  }
};
} // namespace

std::unique_ptr<Pass> mlir::nvgpu::createOptimizeSharedMemoryPass() {
  return std::make_unique<OptimizeSharedMemoryPass>();
}