1//===- OptimizeSharedMemory.cpp - MLIR NVGPU pass implementation ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements transforms to optimize accesses to shared memory.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
18#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
19#include "mlir/Dialect/NVGPU/Transforms/Utils.h"
20#include "mlir/Dialect/Vector/IR/VectorOps.h"
21#include "mlir/Interfaces/SideEffectInterfaces.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/Support/MathExtras.h"
24
25namespace mlir {
26namespace nvgpu {
27#define GEN_PASS_DEF_OPTIMIZESHAREDMEMORY
28#include "mlir/Dialect/NVGPU/Transforms/Passes.h.inc"
29} // namespace nvgpu
30} // namespace mlir
31
32using namespace mlir;
33using namespace mlir::nvgpu;
34
/// Width of one shared-memory line in bytes, per the NVIDIA documentation.
constexpr int64_t kSharedMemoryLineSizeBytes = 128;
/// Accesses are assumed to be 128 bits wide; this could be promoted to a
/// pass option in the future.
constexpr int64_t kDefaultVectorSizeBits = 128;
40
41/// Uses `srcIndexValue` to permute `tgtIndexValue` via
42/// `result = xor(floordiv(srcIdxVal,permuteEveryN),
43/// floordiv(tgtIdxVal,vectorSize)))
44/// + tgtIdxVal % vectorSize`
45/// This is done using an optimized sequence of `arith` operations.
46static Value permuteVectorOffset(OpBuilder &b, Location loc,
47 ArrayRef<Value> indices, MemRefType memrefTy,
48 int64_t srcDim, int64_t tgtDim) {
49 // Adjust the src index to change how often the permutation changes
50 // if necessary.
51 Value src = indices[srcDim];
52
53 // We only want to permute every N iterations of the target dim where N is
54 // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
55 const int64_t permuteEveryN = std::max<int64_t>(
56 a: 1, b: kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(idx: tgtDim) *
57 memrefTy.getElementTypeBitWidth()) /
58 8));
59
60 // clang-format off
61 // Index bit representation (b0 = least significant bit) for dim(1)
62 // of a `memref<?x?xDT>` is as follows:
63 // N := log2(128/elementSizeBits)
64 // M := log2(dimSize(1))
65 // then
66 // bits[0:N] = sub-vector element offset
67 // bits[N:M] = vector index
68 // clang-format on
69 int64_t n =
70 llvm::Log2_64(Value: kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
71 int64_t m = llvm::Log2_64(Value: memrefTy.getDimSize(idx: tgtDim));
72
73 // Capture bits[0:(M-N)] of src by first creating a (M-N) mask.
74 int64_t mask = (1LL << (m - n)) - 1;
75 if (permuteEveryN > 1)
76 mask = mask << llvm::Log2_64(Value: permuteEveryN);
77 Value srcBits = b.create<arith::ConstantIndexOp>(location: loc, args&: mask);
78 srcBits = b.create<arith::AndIOp>(location: loc, args&: src, args&: srcBits);
79
80 // Use the src bits to permute the target bits b[N:M] containing the
81 // vector offset.
82 if (permuteEveryN > 1) {
83 int64_t shlBits = n - llvm::Log2_64(Value: permuteEveryN);
84 if (shlBits > 0) {
85 Value finalShiftVal = b.create<arith::ConstantIndexOp>(location: loc, args&: shlBits);
86 srcBits = b.createOrFold<arith::ShLIOp>(location: loc, args&: srcBits, args&: finalShiftVal);
87 } else if (shlBits < 0) {
88 Value finalShiftVal = b.create<arith::ConstantIndexOp>(location: loc, args: -1 * shlBits);
89 srcBits = b.createOrFold<arith::ShRUIOp>(location: loc, args&: srcBits, args&: finalShiftVal);
90 }
91 } else {
92 Value finalShiftVal = b.create<arith::ConstantIndexOp>(location: loc, args&: n);
93 srcBits = b.createOrFold<arith::ShLIOp>(location: loc, args&: srcBits, args&: finalShiftVal);
94 }
95
96 Value permutedVectorIdx =
97 b.create<arith::XOrIOp>(location: loc, args: indices[tgtDim], args&: srcBits);
98 return permutedVectorIdx;
99}
100
101static void transformIndices(OpBuilder &builder, Location loc,
102 SmallVector<Value, 4> &indices,
103 MemRefType memrefTy, int64_t srcDim,
104 int64_t tgtDim) {
105 indices[tgtDim] =
106 permuteVectorOffset(b&: builder, loc, indices, memrefTy, srcDim, tgtDim);
107}
108
109/// Return all operations within `parentOp` that read from or write to
110/// `shmMemRef`.
111static LogicalResult
112getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
113 SmallVector<Operation *, 16> &readOps,
114 SmallVector<Operation *, 16> &writeOps) {
115 parentOp->walk(callback: [&](Operation *op) {
116 MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(Val: op);
117 if (!iface)
118 return;
119 std::optional<MemoryEffects::EffectInstance> effect =
120 iface.getEffectOnValue<MemoryEffects::Read>(value: shmMemRef);
121 if (effect) {
122 readOps.push_back(Elt: op);
123 return;
124 }
125 effect = iface.getEffectOnValue<MemoryEffects::Write>(value: shmMemRef);
126 if (effect)
127 writeOps.push_back(Elt: op);
128 });
129
130 // Restrict to a supported set of ops. We also require at least 2D access,
131 // although this could be relaxed.
132 if (llvm::any_of(Range&: readOps, P: [](Operation *op) {
133 return !isa<memref::LoadOp, vector::LoadOp, nvgpu::LdMatrixOp>(Val: op) ||
134 getIndices(op).size() < 2;
135 }))
136 return failure();
137 if (llvm::any_of(Range&: writeOps, P: [](Operation *op) {
138 return !isa<memref::StoreOp, vector::StoreOp, nvgpu::DeviceAsyncCopyOp>(
139 Val: op) ||
140 getIndices(op).size() < 2;
141 }))
142 return failure();
143
144 return success();
145}
146
147llvm::LogicalResult
148mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
149 Value memrefValue) {
150 auto memRefType = dyn_cast<MemRefType>(Val: memrefValue.getType());
151 if (!memRefType || !NVGPUDialect::hasSharedMemoryAddressSpace(type: memRefType))
152 return failure();
153
154 // Not support 0D MemRefs.
155 if (memRefType.getRank() == 0)
156 return failure();
157
158 // Abort if the given value has any sub-views; we do not do any alias
159 // analysis.
160 bool hasSubView = false;
161 parentOp->walk(callback: [&](memref::SubViewOp subView) { hasSubView = true; });
162 if (hasSubView)
163 return failure();
164
165 // Check if this is necessary given the assumption of 128b accesses:
166 // If dim[rank-1] is small enough to fit 8 rows in a 128B line.
167 const int64_t rowSize = memRefType.getDimSize(idx: memRefType.getRank() - 1);
168 const int64_t rowsPerLine =
169 (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
170 rowSize;
171 const int64_t threadGroupSize =
172 1LL << (7 - llvm::Log2_64(Value: kDefaultVectorSizeBits / 8));
173 if (rowsPerLine >= threadGroupSize)
174 return failure();
175
176 // Get sets of operations within the function that read/write to shared
177 // memory.
178 SmallVector<Operation *, 16> shmReadOps;
179 SmallVector<Operation *, 16> shmWriteOps;
180 if (failed(Result: getShmReadAndWriteOps(parentOp, shmMemRef: memrefValue, readOps&: shmReadOps,
181 writeOps&: shmWriteOps)))
182 return failure();
183
184 if (shmReadOps.empty() || shmWriteOps.empty())
185 return failure();
186
187 OpBuilder builder(parentOp->getContext());
188
189 int64_t tgtDim = memRefType.getRank() - 1;
190 int64_t srcDim = memRefType.getRank() - 2;
191
192 // Transform indices for the ops writing to shared memory.
193 while (!shmWriteOps.empty()) {
194 Operation *shmWriteOp = shmWriteOps.back();
195 shmWriteOps.pop_back();
196 builder.setInsertionPoint(shmWriteOp);
197
198 auto indices = getIndices(op: shmWriteOp);
199 SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
200 transformIndices(builder, loc: shmWriteOp->getLoc(), indices&: transformedIndices,
201 memrefTy: memRefType, srcDim, tgtDim);
202 setIndices(op: shmWriteOp, indices: transformedIndices);
203 }
204
205 // Transform indices for the ops reading from shared memory.
206 while (!shmReadOps.empty()) {
207 Operation *shmReadOp = shmReadOps.back();
208 shmReadOps.pop_back();
209 builder.setInsertionPoint(shmReadOp);
210
211 auto indices = getIndices(op: shmReadOp);
212 SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
213 transformIndices(builder, loc: shmReadOp->getLoc(), indices&: transformedIndices,
214 memrefTy: memRefType, srcDim, tgtDim);
215 setIndices(op: shmReadOp, indices: transformedIndices);
216 }
217
218 return success();
219}
220
221namespace {
222class OptimizeSharedMemoryPass
223 : public nvgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
224public:
225 OptimizeSharedMemoryPass() = default;
226
227 void runOnOperation() override {
228 Operation *op = getOperation();
229 SmallVector<memref::AllocOp> shmAllocOps;
230 op->walk(callback: [&](memref::AllocOp allocOp) {
231 if (!NVGPUDialect::hasSharedMemoryAddressSpace(type: allocOp.getType()))
232 return;
233 shmAllocOps.push_back(Elt: allocOp);
234 });
235 for (auto allocOp : shmAllocOps) {
236 if (failed(Result: optimizeSharedMemoryReadsAndWrites(parentOp: getOperation(),
237 memrefValue: allocOp.getMemref())))
238 return;
239 }
240 }
241};
242} // namespace
243
244std::unique_ptr<Pass> mlir::nvgpu::createOptimizeSharedMemoryPass() {
245 return std::make_unique<OptimizeSharedMemoryPass>();
246}
247

// Source: mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp