//===- OptimizeSharedMemory.cpp - MLIR NVGPU pass implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements transforms to optimize accesses to shared memory.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/NVGPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"

namespace mlir {
namespace nvgpu {
#define GEN_PASS_DEF_OPTIMIZESHAREDMEMORY
#include "mlir/Dialect/NVGPU/Transforms/Passes.h.inc"
} // namespace nvgpu
} // namespace mlir

using namespace mlir;
using namespace mlir::nvgpu;

/// The size of a shared memory line according to NVIDIA documentation.
constexpr int64_t kSharedMemoryLineSizeBytes = 128;
/// We optimize for 128-bit accesses, but this can be made an argument in the
/// future.
constexpr int64_t kDefaultVectorSizeBits = 128;

/// Uses `srcIndexValue` to permute `tgtIndexValue` via
/// `result = xor(floordiv(srcIdxVal, permuteEveryN),
///               floordiv(tgtIdxVal, vectorSize)) * vectorSize
///          + tgtIdxVal % vectorSize`
/// This is done using an optimized sequence of `arith` operations.
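/// E.g., assuming a `memref<?x64xf16>` in shared memory (so vectorSize =
/// 128/16 = 8 elements and permuteEveryN = 1), indices (srcIdx = 2,
/// tgtIdx = 5) give xor(2, floordiv(5, 8)) * 8 + 5 % 8 = 2 * 8 + 5 = 21.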
static Value permuteVectorOffset(OpBuilder &b, Location loc,
                                 ArrayRef<Value> indices, MemRefType memrefTy,
                                 int64_t srcDim, int64_t tgtDim) {
  // Adjust the src index to change how often the permutation changes
  // if necessary.
  Value src = indices[srcDim];

  // We only want to permute every N iterations of the target dim where N is
  // max(1, sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
  const int64_t permuteEveryN = std::max<int64_t>(
      1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
                                        memrefTy.getElementTypeBitWidth()) /
                                       8));
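  // For instance, assuming a `memref<?x32xf16>`: a row is 32 * 2 = 64 bytes,
  // so permuteEveryN = max(1, 128 / 64) = 2; a full 128-byte row gives 1.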

  // clang-format off
  // Index bit representation (b0 = least significant bit) for dim(1)
  // of a `memref<?x?xDT>` is as follows:
  // N := log2(128/elementSizeBits)
  // M := log2(dimSize(1))
  // then
  // bits[0:N] = sub-vector element offset
  // bits[N:M] = vector index
  // clang-format on
  int64_t n =
      llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
  int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));
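  // E.g., assuming f16 elements and dimSize(tgtDim) = 64: n = log2(128 / 16)
  // = 3 (eight elements per 128-bit vector) and m = log2(64) = 6.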

  // Capture bits[0:(M-N)] of src by first creating an (M-N)-bit mask.
  int64_t mask = (1LL << (m - n)) - 1;
  if (permuteEveryN > 1)
    mask = mask << llvm::Log2_64(permuteEveryN);
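  // E.g., m = 6, n = 3, permuteEveryN = 1 gives mask = 0b111, while m = 5,
  // n = 3, permuteEveryN = 2 gives mask = 0b11 << 1 = 0b110.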
  Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
  srcBits = b.create<arith::AndIOp>(loc, src, srcBits);

  // Use the src bits to permute the target bits b[N:M] containing the
  // vector offset.
  if (permuteEveryN > 1) {
    int64_t shlBits = n - llvm::Log2_64(permuteEveryN);
    if (shlBits > 0) {
      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, shlBits);
      srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
    } else if (shlBits < 0) {
      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, -1 * shlBits);
      srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
    }
  } else {
    Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, n);
    srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
  }

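  // XOR the shifted source bits into the target's vector-index bits
  // (bits[N:M]); the sub-vector offset in bits[0:N] is left untouched.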
  Value permutedVectorIdx =
      b.create<arith::XOrIOp>(loc, indices[tgtDim], srcBits);
  return permutedVectorIdx;
}

static void transformIndices(OpBuilder &builder, Location loc,
                             SmallVector<Value, 4> &indices,
                             MemRefType memrefTy, int64_t srcDim,
                             int64_t tgtDim) {
  indices[tgtDim] =
      permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
}

/// Return all operations within `parentOp` that read from or write to
/// `shmMemRef`.
static LogicalResult
getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
                      SmallVector<Operation *, 16> &readOps,
                      SmallVector<Operation *, 16> &writeOps) {
  parentOp->walk([&](Operation *op) {
    MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
    if (!iface)
      return;
    std::optional<MemoryEffects::EffectInstance> effect =
        iface.getEffectOnValue<MemoryEffects::Read>(shmMemRef);
    if (effect) {
      readOps.push_back(op);
      return;
    }
    effect = iface.getEffectOnValue<MemoryEffects::Write>(shmMemRef);
    if (effect)
      writeOps.push_back(op);
  });

  // Restrict to a supported set of ops. We also require at least 2D access,
  // although this could be relaxed.
  if (llvm::any_of(readOps, [](Operation *op) {
        return !isa<memref::LoadOp, vector::LoadOp, nvgpu::LdMatrixOp>(op) ||
               getIndices(op).size() < 2;
      }))
    return failure();
  if (llvm::any_of(writeOps, [](Operation *op) {
        return !isa<memref::StoreOp, vector::StoreOp, nvgpu::DeviceAsyncCopyOp>(
                   op) ||
               getIndices(op).size() < 2;
      }))
    return failure();

  return success();
}

mlir::LogicalResult
mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                Value memrefValue) {
  auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
  if (!memRefType || !NVGPUDialect::hasSharedMemoryAddressSpace(memRefType))
    return failure();

  // Abort if the given value has any sub-views; we do not do any alias
  // analysis.
  bool hasSubView = false;
  parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
  if (hasSubView)
    return failure();

  // Check whether the transform is necessary given the assumption of 128-bit
  // accesses: skip if dim[rank-1] is small enough that eight or more rows fit
  // in a single 128-byte line.
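  // E.g., for f16, a 128-byte line holds 8 * 128 / 16 = 64 elements: a
  // 64-element row gives rowsPerLine = 1 (proceed), whereas an 8-element row
  // gives rowsPerLine = 8 >= threadGroupSize = 8 (abort).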
  const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
  const int64_t rowsPerLine =
      (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
      rowSize;
  const int64_t threadGroupSize =
      1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
  if (rowsPerLine >= threadGroupSize)
    return failure();

  // Get sets of operations within the function that read/write to shared
  // memory.
  SmallVector<Operation *, 16> shmReadOps;
  SmallVector<Operation *, 16> shmWriteOps;
  if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
                                   shmWriteOps)))
    return failure();

  if (shmReadOps.empty() || shmWriteOps.empty())
    return failure();

  OpBuilder builder(parentOp->getContext());

  int64_t tgtDim = memRefType.getRank() - 1;
  int64_t srcDim = memRefType.getRank() - 2;
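  // The innermost dim's index is permuted with bits taken from the
  // second-innermost dim's index, spreading accesses that would otherwise
  // contend on the same shared memory banks.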

  // Transform indices for the ops writing to shared memory.
  while (!shmWriteOps.empty()) {
    Operation *shmWriteOp = shmWriteOps.back();
    shmWriteOps.pop_back();
    builder.setInsertionPoint(shmWriteOp);

    auto indices = getIndices(shmWriteOp);
    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
    transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    setIndices(shmWriteOp, transformedIndices);
  }

  // Transform indices for the ops reading from shared memory.
  while (!shmReadOps.empty()) {
    Operation *shmReadOp = shmReadOps.back();
    shmReadOps.pop_back();
    builder.setInsertionPoint(shmReadOp);

    auto indices = getIndices(shmReadOp);
    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
    transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    setIndices(shmReadOp, transformedIndices);
  }

  return success();
}

namespace {
class OptimizeSharedMemoryPass
    : public nvgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
public:
  OptimizeSharedMemoryPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    SmallVector<memref::AllocOp> shmAllocOps;
    op->walk([&](memref::AllocOp allocOp) {
      if (!NVGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
        return;
      shmAllocOps.push_back(allocOp);
    });
    for (auto allocOp : shmAllocOps) {
      if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
                                                    allocOp.getMemref())))
        return;
    }
  }
};
} // namespace

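// The pass is normally driven from a pass pipeline; assuming the usual
// tablegen registration of this pass in the NVGPU dialect, that would be e.g.:
//   mlir-opt --nvgpu-optimize-shared-memory input.mlir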
std::unique_ptr<Pass> mlir::nvgpu::createOptimizeSharedMemoryPass() {
  return std::make_unique<OptimizeSharedMemoryPass>();
}
