//===- CreateAsyncGroups.cpp - Create async device copies -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

/// Return "true" if the given vector transfer op is contiguous and suitable
/// for replacement with an async copy.
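///
/// For illustration only (the shapes and values below are made up): a 1-D
/// transfer with a minor identity permutation map, in-bounds access, and a
/// unit-stride innermost memref dimension qualifies, e.g.
///
///   %v = vector.transfer_read %src[%i, %j], %pad {in_bounds = [true]}
///       : memref<128x64xf32>, vector<4xf32>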
template <typename OpTy>
static bool isContiguousXferOp(OpTy op) {
  return op.getPermutationMap().isMinorIdentity() && op.isDimInBounds(0) &&
         op.hasPureBufferSemantics() &&
         isLastMemrefDimUnitStride(
             cast<MemRefType>(nvgpu::getMemrefOperand(op).getType()));
}

/// Return "true" if the given op is a contiguous and suitable
/// vector.transfer_write or vector.store op.
static bool isContiguousStore(Operation *write) {
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(write))
    return isContiguousXferOp(transferWrite) && !transferWrite.getMask();
  // vector.store ops are always contiguous.
  return isa<vector::StoreOp>(write);
}

/// Return "true" if the given op is a contiguous and suitable
/// vector.transfer_read or vector.load op.
static bool isContiguousRead(Operation *read) {
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(read))
    return isContiguousXferOp(transferRead);
  // vector.load ops are always contiguous.
  return isa<vector::LoadOp>(read);
}

namespace {
/// A vector.create_mask op and extract position.
struct TransferMask {
  vector::CreateMaskOp createMaskOp;
  SmallVector<int64_t> extractPosition;
};
} // namespace

/// If the given vector load op has a mask that is defined by a
/// vector.create_mask op (directly or behind a vector.extract), return that
/// op together with the extract position. Return an empty TransferMask if
/// there is no mask, and failure if the mask is unsupported.
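///
/// The two supported mask patterns look roughly like this (shapes are
/// illustrative):
///
///   // Case 1: the mask is produced directly by vector.create_mask.
///   %m = vector.create_mask %sz : vector<4xi1>
///
///   // Case 2: the mask is a slice of a higher-dimensional create_mask.
///   %m2d = vector.create_mask %sz0, %sz1 : vector<2x4xi1>
///   %m   = vector.extract %m2d[0] : vector<4xi1> from vector<2x4xi1>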
static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
  auto transferRead = dyn_cast<vector::TransferReadOp>(loadOp);
  if (!transferRead || !transferRead.getMask())
    return TransferMask{{}, {}};
  assert(transferRead.getMask().getType().getRank() == 1 &&
         "expected 1-D mask");

  // Case 1: Mask is the result of a vector.create_mask.
  if (auto maskOp =
          transferRead.getMask().getDefiningOp<vector::CreateMaskOp>())
    return TransferMask{maskOp, {}};

  // Case 2: Mask is the result of a vector.extract(vector.create_mask).
  if (auto extractOp =
          transferRead.getMask().getDefiningOp<vector::ExtractOp>())
    if (auto maskOp =
            extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
      return TransferMask{maskOp,
                          SmallVector<int64_t>(extractOp.getStaticPosition())};

  // All other cases: not supported.
  return failure();
}

/// Build an SSA value that represents the number of read elements.
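///
/// For the masked case with a vector.extract, the helper emits IR along these
/// lines (a sketch; %pos_k stand for the constant extract positions and %sz_k
/// for the create_mask operands):
///
///   %c0   = arith.constant 0 : index
///   %cmp0 = arith.cmpi slt, %pos_0, %sz_0 : index
///   ...
///   %n    = arith.select %cond, %sz_last, %c0 : index
///
/// E.g., with create_mask %sz0, %sz1 and extract position [2], the result is
/// select(2 < %sz0, %sz1, 0).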
static Value buildNumReadElements(OpBuilder &b, Location loc,
                                  Operation *readOp) {
  FailureOr<TransferMask> transferMask = getMaskOp(readOp);
  assert(succeeded(transferMask) && "invalid transfer mask");

  // No mask => no num_read_elements.
  if (!transferMask->createMaskOp)
    return Value();

  // No extract: return size of "ones" segment in the mask.
  if (transferMask->extractPosition.empty()) {
    assert(transferMask->createMaskOp.getNumOperands() == 1 &&
           "expected single operand");
    return transferMask->createMaskOp.getOperand(0);
  }

  // vector.extract(vector.create_mask).
  // If extract_pos < num_ones, take number of elements from the least
  // significant dimension. (Do this for all dimensions and bit-AND the
  // conditions.)
  assert(transferMask->createMaskOp.getVectorType().getRank() -
                 transferMask->extractPosition.size() ==
             1 &&
         "expected N-D -> (N-1)-D extract");
  Value cond;
  // Note: There is one more `sz` than `pos`. The loop ends with the last
  // `pos`.
  for (auto [pos, sz] : llvm::zip(transferMask->extractPosition,
                                  transferMask->createMaskOp->getOperands())) {
    Value cmp =
        b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                b.create<arith::ConstantIndexOp>(loc, pos), sz);
    if (!cond) {
      cond = cmp;
      continue;
    }
    cond = b.create<arith::AndIOp>(loc, cmp, cond);
  }
  return b.create<arith::SelectOp>(
      loc, cond, transferMask->createMaskOp->getOperands().back(),
      b.create<arith::ConstantIndexOp>(loc, 0));
}

/// Return "true" if a copy with the given memref and vector types can be
/// lowered to a supported async copy (cp.async).
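///
/// As a concrete example of the size check below (illustrative): a
/// vector<4xf32> copy is 4 * 32 bits = 16 bytes and is accepted, while a
/// vector<3xf32> copy is 12 bytes and is rejected.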
static bool resultsInSupportedAsyncCopy(MemRefType memrefType,
                                        VectorType vecType) {
  assert(vecType.getRank() == 1 && "expected 1-D vector");
  constexpr int64_t kSupportedCpAsyncAlignmentsInBytes[3] = {4, 8, 16};

  // Condition 1: the copy size must be supported.
  bool supportedCopySize = false;
  int64_t numElements = vecType.getNumElements();
  Type elementType = vecType.getElementType();
  for (int64_t alignmentInBytes : kSupportedCpAsyncAlignmentsInBytes) {
    if (alignmentInBytes * 8 ==
        numElements * elementType.getIntOrFloatBitWidth()) {
      supportedCopySize = true;
      break;
    }
  }
  if (!supportedCopySize)
    return false;

  // TODO: Condition 2: the alignments must be supported. For cp.async the
  // NVIDIA doc (section 6.4.1) says: "The address must be naturally aligned to
  // a multiple of the access size. If an address is not properly aligned, the
  // resulting behavior is undefined.".
  return true;
}

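// The rewrite implemented below, sketched on a made-up example (operand
// types, indices, and the address space annotation are illustrative, not
// taken from a real test case):
//
//   %v = vector.transfer_read %global[%i], %pad {in_bounds = [true]}
//       : memref<1024xf32>, vector<4xf32>
//   vector.transfer_write %v, %shared[%j] {in_bounds = [true]}
//       : vector<4xf32>, memref<128xf32, 3>
//
// becomes, approximately:
//
//   %tok = nvgpu.device_async_copy %global[%i], %shared[%j], 4
//       : memref<1024xf32> to memref<128xf32, 3>
//   %grp = nvgpu.device_async_create_group %tok
//   nvgpu.device_async_wait %grp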
void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op,
                              bool bypassL1) {
  llvm::SmallSetVector<Operation *, 16> copyToSharedMem;

  // Look for all the copies that can be converted to async copy ops.
  op->walk([&](Operation *writeOp) {
    // Look for a contiguous 1D vector store into shared memory.
    if (!isContiguousStore(writeOp))
      return;
    Value vectorVal = nvgpu::getValueStored(writeOp);
    if (cast<VectorType>(vectorVal.getType()).getRank() != 1)
      return;
    Value storeBase = nvgpu::getMemrefOperand(writeOp);
    if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
            cast<MemRefType>(storeBase.getType())))
      return;

    // The stored vector must originate from a contiguous 1D vector load.
    Operation *readOp = vectorVal.getDefiningOp();
    if (readOp == nullptr || !isContiguousRead(readOp))
      return;
    Value loadBase = nvgpu::getMemrefOperand(readOp);
    // Should be reading from global memory (not shared memory).
    if (nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
            cast<MemRefType>(loadBase.getType())))
      return;

    // Look for compatible mask and padding.
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(readOp)) {
      if (Value mask = transferRead.getMask()) {
        // The async copy zero-fills the unread elements, so the padding of a
        // masked read must be the constant 0.
        if (getConstantIntValue(transferRead.getPadding()) !=
            static_cast<int64_t>(0))
          return;
        if (failed(getMaskOp(readOp)))
          return;
      }
    }

    // Check whether both accesses are supported before we emit: this is
    // necessary to ensure the correctness of DeviceAsyncCopyOp.
    VectorType vecType = cast<VectorType>(vectorVal.getType());

    if (!resultsInSupportedAsyncCopy(cast<MemRefType>(loadBase.getType()),
                                     vecType) ||
        !resultsInSupportedAsyncCopy(cast<MemRefType>(storeBase.getType()),
                                     vecType))
      return;

    copyToSharedMem.insert(writeOp);
    return;
  });

  while (!copyToSharedMem.empty()) {
    // Start a group with the first write.
    SmallVector<Operation *> group;
    Operation *writeOp = *copyToSharedMem.begin();
    copyToSharedMem.remove(writeOp);
    group.push_back(writeOp);
    Operation *nextNode = writeOp;

    // Look in the next nodes for more copies to add to the same group.
    while ((nextNode = nextNode->getNextNode())) {
      // Ignore ops without side effects.
      auto memInterface = dyn_cast<MemoryEffectOpInterface>(nextNode);
      if (memInterface && memInterface.hasNoEffect() &&
          !nextNode->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
        continue;
      // Ignore reads from a different address space: only reads from shared
      // memory can alias the destinations of the pending copies.
      if (isa<vector::TransferReadOp, vector::LoadOp>(nextNode)) {
        Operation *readOp = nextNode;
        Value memrefOperand = nvgpu::getMemrefOperand(readOp);
        if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
                cast<MemRefType>(memrefOperand.getType()))) {
          continue;
        }
      }
      if (copyToSharedMem.count(nextNode)) {
        // Found another copy, add it to the group.
        copyToSharedMem.remove(nextNode);
        group.push_back(nextNode);
        continue;
      }
      // If the op is anything else, stop accumulating ops into the group.
      break;
    }
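    // For example (illustrative), two back-to-back global->shared copies with
    // only side-effect-free arithmetic between them end up in the same group
    // and are then covered by a single create_group / wait pair.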

    // Emit the group.
    SmallVector<Value> tokens;
    for (Operation *writeOp : group) {
      rewriter.setInsertionPoint(writeOp);
      Value vectorVal = nvgpu::getValueStored(writeOp);
      auto vectorType = cast<VectorType>(vectorVal.getType());
      int64_t numElements = vectorType.getNumElements();
      Operation *readOp = vectorVal.getDefiningOp();
      Value storeBase = nvgpu::getMemrefOperand(writeOp);
      Value loadBase = nvgpu::getMemrefOperand(readOp);
      Value numReadElements =
          buildNumReadElements(rewriter, writeOp->getLoc(), readOp);
      auto dstMemref = cast<MemRefType>(storeBase.getType());
      int64_t sizeInBytes =
          (dstMemref.getElementTypeBitWidth() * numElements) / 8;
      // Bypassing L1 is only possible for 16-byte transfers.
      Value token = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
          writeOp->getLoc(),
          nvgpu::DeviceAsyncTokenType::get(op->getContext()),
          /*dst=*/storeBase, /*dstIndices=*/nvgpu::getIndices(writeOp),
          /*src=*/loadBase,
          /*srcIndices=*/nvgpu::getIndices(readOp),
          /*dstElements=*/rewriter.getIndexAttr(numElements),
          /*srcElements=*/numReadElements,
          /*bypassL1=*/bypassL1 && sizeInBytes == 16 ? rewriter.getUnitAttr()
                                                     : UnitAttr());
      tokens.push_back(token);
    }

    // Create the group and wait for it right after.
    Value groupToken = rewriter.create<nvgpu::DeviceAsyncCreateGroupOp>(
        op->getLoc(), nvgpu::DeviceAsyncTokenType::get(op->getContext()),
        tokens);
    rewriter.create<nvgpu::DeviceAsyncWaitOp>(op->getLoc(), groupToken,
                                              nullptr);
    // Clean up old stores.
    for (Operation *writeOp : group)
      rewriter.eraseOp(writeOp);
  }
}