//===- OptimizeSharedMemory.cpp - MLIR NVGPU pass implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements transforms to optimize accesses to shared memory.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/NVGPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"

namespace mlir {
namespace nvgpu {
#define GEN_PASS_DEF_OPTIMIZESHAREDMEMORY
#include "mlir/Dialect/NVGPU/Transforms/Passes.h.inc"
} // namespace nvgpu
} // namespace mlir

using namespace mlir;
using namespace mlir::nvgpu;

/// The size of a shared memory line according to NV documentation.
constexpr int64_t kSharedMemoryLineSizeBytes = 128;
/// We optimize for 128-bit accesses, but this can be made an argument in the
/// future.
constexpr int64_t kDefaultVectorSizeBits = 128;

/// Uses `srcIndexValue` to permute `tgtIndexValue` via
/// `result = xor(floordiv(srcIdxVal,permuteEveryN),
///               floordiv(tgtIdxVal,vectorSize))
///           + tgtIdxVal % vectorSize`
/// This is done using an optimized sequence of `arith` operations.
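/// For example, for a `memref<64x64xf16>` in shared memory, `permuteEveryN`
/// is 1, N is 3 and M is 6, so the emitted IR reduces to
/// `tgtIdx ^ ((srcIdx & 0b111) << 3)`, i.e. an xor-swizzle of the 128-bit
/// vector index within each row.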
static Value permuteVectorOffset(OpBuilder &b, Location loc,
                                 ArrayRef<Value> indices, MemRefType memrefTy,
                                 int64_t srcDim, int64_t tgtDim) {
  // Adjust the src index to change how often the permutation changes
  // if necessary.
  Value src = indices[srcDim];

  // We only want to permute every N iterations of the target dim where N is
  // max(1, sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
  const int64_t permuteEveryN = std::max<int64_t>(
      1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
                                        memrefTy.getElementTypeBitWidth()) /
                                       8));
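  // E.g. permuteEveryN is 1 for a row of 32 f32 elements (a full 128-byte
  // line) and 2 for a row of 16 f32 elements (half a line).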

  // clang-format off
  // Index bit representation (b0 = least significant bit) for dim(1)
  // of a `memref<?x?xDT>` is as follows:
  //   N := log2(128/elementSizeBits)
  //   M := log2(dimSize(1))
  // then
  //   bits[0:N] = sub-vector element offset
  //   bits[N:M] = vector index
  // clang-format on
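  // For instance, with f16 elements and dimSize(1) == 64: N == 3 and M == 6,
  // so bits[0:3] address one of the eight f16 values inside a 128-bit vector
  // and bits[3:6] select one of the eight vectors that make up the row.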
  int64_t n =
      llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
  int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));

  // Capture bits[0:(M-N)] of src by first creating a (M-N) mask.
  int64_t mask = (1LL << (m - n)) - 1;
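  // E.g. with m == 6 and n == 3 the mask is 0b111, i.e. the three least
  // significant bits of the source index.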
  if (permuteEveryN > 1)
    mask = mask << llvm::Log2_64(permuteEveryN);
  Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
  srcBits = b.create<arith::AndIOp>(loc, src, srcBits);

  // Use the src bits to permute the target bits b[N:M] containing the
  // vector offset.
  if (permuteEveryN > 1) {
    int64_t shlBits = n - llvm::Log2_64(permuteEveryN);
    if (shlBits > 0) {
      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, shlBits);
      srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
    } else if (shlBits < 0) {
      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, -1 * shlBits);
      srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
    }
  } else {
    Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, n);
    srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
  }
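  // In both cases the selected src bits now line up with the vector-index bits
  // [N:M] of the target index and can be xor'd in directly.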

  Value permutedVectorIdx =
      b.create<arith::XOrIOp>(loc, indices[tgtDim], srcBits);
  return permutedVectorIdx;
}

static void transformIndices(OpBuilder &builder, Location loc,
                             SmallVector<Value, 4> &indices,
                             MemRefType memrefTy, int64_t srcDim,
                             int64_t tgtDim) {
  indices[tgtDim] =
      permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
}

/// Return all operations within `parentOp` that read from or write to
/// `shmMemRef`.
static LogicalResult
getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
                      SmallVector<Operation *, 16> &readOps,
                      SmallVector<Operation *, 16> &writeOps) {
  parentOp->walk([&](Operation *op) {
    MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
    if (!iface)
      return;
    std::optional<MemoryEffects::EffectInstance> effect =
        iface.getEffectOnValue<MemoryEffects::Read>(shmMemRef);
    if (effect) {
      readOps.push_back(op);
      return;
    }
    effect = iface.getEffectOnValue<MemoryEffects::Write>(shmMemRef);
    if (effect)
      writeOps.push_back(op);
  });

  // Restrict to a supported set of ops. We also require at least 2D access,
  // although this could be relaxed.
  if (llvm::any_of(readOps, [](Operation *op) {
        return !isa<memref::LoadOp, vector::LoadOp, nvgpu::LdMatrixOp>(op) ||
               getIndices(op).size() < 2;
      }))
    return failure();
  if (llvm::any_of(writeOps, [](Operation *op) {
        return !isa<memref::StoreOp, vector::StoreOp, nvgpu::DeviceAsyncCopyOp>(
                   op) ||
               getIndices(op).size() < 2;
      }))
    return failure();

  return success();
}

mlir::LogicalResult
mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                Value memrefValue) {
  auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
  if (!memRefType || !NVGPUDialect::hasSharedMemoryAddressSpace(memRefType))
    return failure();

  // Abort if the given value has any sub-views; we do not do any alias
  // analysis.
  bool hasSubView = false;
  parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
  if (hasSubView)
    return failure();

  // Check whether the transformation is necessary given the assumption of
  // 128-bit accesses: skip it if dim[rank-1] is small enough that at least
  // `threadGroupSize` rows fit in a single 128-byte line.
  const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
  const int64_t rowsPerLine =
      (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
      rowSize;
  const int64_t threadGroupSize =
      1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
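  // With the default 128-bit vector size this is 1 << (7 - 4) == 8, so the
  // early return below triggers when eight or more rows share one line.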
  if (rowsPerLine >= threadGroupSize)
    return failure();

  // Get sets of operations within the function that read/write to shared
  // memory.
  SmallVector<Operation *, 16> shmReadOps;
  SmallVector<Operation *, 16> shmWriteOps;
  if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
                                   shmWriteOps)))
    return failure();

  if (shmReadOps.empty() || shmWriteOps.empty())
    return failure();

  OpBuilder builder(parentOp->getContext());

  int64_t tgtDim = memRefType.getRank() - 1;
  int64_t srcDim = memRefType.getRank() - 2;
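  // The last (fastest-varying) dimension is permuted using bits taken from the
  // second-to-last dimension.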

  // Transform indices for the ops writing to shared memory.
  while (!shmWriteOps.empty()) {
    Operation *shmWriteOp = shmWriteOps.back();
    shmWriteOps.pop_back();
    builder.setInsertionPoint(shmWriteOp);

    auto indices = getIndices(shmWriteOp);
    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
    transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    setIndices(shmWriteOp, transformedIndices);
  }

  // Transform indices for the ops reading from shared memory.
  while (!shmReadOps.empty()) {
    Operation *shmReadOp = shmReadOps.back();
    shmReadOps.pop_back();
    builder.setInsertionPoint(shmReadOp);

    auto indices = getIndices(shmReadOp);
    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
    transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    setIndices(shmReadOp, transformedIndices);
  }

  return success();
}

namespace {
class OptimizeSharedMemoryPass
    : public nvgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
public:
  OptimizeSharedMemoryPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    SmallVector<memref::AllocOp> shmAllocOps;
    op->walk([&](memref::AllocOp allocOp) {
      if (!NVGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
        return;
      shmAllocOps.push_back(allocOp);
    });
    for (auto allocOp : shmAllocOps) {
      if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
                                                    allocOp.getMemref())))
        return;
    }
  }
};
} // namespace

std::unique_ptr<Pass> mlir::nvgpu::createOptimizeSharedMemoryPass() {
  return std::make_unique<OptimizeSharedMemoryPass>();
}