1 | //===- CreateAsyncGroups.cpp - Create async device copies -----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/NVGPU/Transforms/Transforms.h" |
10 | |
11 | #include "mlir/Dialect/Arith/IR/Arith.h" |
12 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
13 | #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" |
14 | #include "mlir/Dialect/NVGPU/Transforms/Utils.h" |
15 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
16 | #include "mlir/IR/BuiltinAttributes.h" |
17 | #include "mlir/IR/BuiltinTypes.h" |
18 | |
19 | using namespace mlir; |
20 | |
21 | /// Return "true" if the given vector transfer op is contiguous and suitable |
22 | /// for replacement with an async copy. |
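/// For example (types and SSA names are illustrative), a read such as
///
///   %v = vector.transfer_read %src[%i, %j], %cst {in_bounds = [true]}
///       : memref<128x64xf32>, vector<4xf32>
///
/// qualifies: the permutation map is a minor identity, dim 0 is in-bounds,
/// the source has buffer semantics, and its innermost dimension has unit
/// stride.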
23 | template <typename OpTy> |
24 | static bool isContiguousXferOp(OpTy op) { |
25 | return op.getPermutationMap().isMinorIdentity() && op.isDimInBounds(0) && |
26 | op.hasPureBufferSemantics() && |
27 | isLastMemrefDimUnitStride( |
28 | cast<MemRefType>(nvgpu::getMemrefOperand(op).getType())); |
29 | } |
30 | |
31 | /// Return "true" if the given op is a contiguous and suitable |
32 | /// vector.transfer_write or vector.store op. |
33 | static bool isContiguousStore(Operation *write) { |
34 | if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(write)) |
35 | return isContiguousXferOp(transferWrite) && !transferWrite.getMask(); |
  // vector.store ops are always contiguous.
37 | return isa<vector::StoreOp>(write); |
38 | } |
39 | |
40 | /// Return "true" if the given op is a contiguous and suitable |
41 | /// vector.transfer_read or vector.load op. |
42 | static bool isContiguousRead(Operation *read) { |
43 | if (auto transferRead = dyn_cast<vector::TransferReadOp>(read)) |
44 | return isContiguousXferOp(transferRead); |
  // vector.load ops are always contiguous.
46 | return isa<vector::LoadOp>(read); |
47 | } |
48 | |
49 | namespace { |
50 | /// A vector.create_mask op and extract position. |
51 | struct TransferMask { |
52 | vector::CreateMaskOp createMaskOp; |
  SmallVector<int64_t> extractPosition;
54 | }; |
55 | } // namespace |
56 | |
57 | /// If the given vector load op has a mask that is defined by |
58 | /// vector.create_mask, return that op. |
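/// Two mask forms are recognized (shapes and SSA names are illustrative):
///
///   // Case 1: the mask comes directly from vector.create_mask.
///   %mask = vector.create_mask %n : vector<4xi1>
///
///   // Case 2: the mask is a slice of a higher-rank vector.create_mask.
///   %m = vector.create_mask %d0, %d1 : vector<2x4xi1>
///   %mask = vector.extract %m[0] : vector<4xi1> from vector<2x4xi1>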
59 | static FailureOr<TransferMask> getMaskOp(Operation *loadOp) { |
60 | auto transferRead = dyn_cast<vector::TransferReadOp>(loadOp); |
61 | if (!transferRead || !transferRead.getMask()) |
62 | return TransferMask{{}, {}}; |
63 | assert(transferRead.getMask().getType().getRank() == 1 && |
64 | "expected 1-D mask" ); |
65 | |
66 | // Case 1: Mask is the result of a vector.create_mask. |
67 | if (auto maskOp = |
68 | transferRead.getMask().getDefiningOp<vector::CreateMaskOp>()) |
69 | return TransferMask{maskOp, {}}; |
70 | |
71 | // Case 2: Mask is the result of a vector.extract(vector.create_mask). |
72 | if (auto extractOp = |
73 | transferRead.getMask().getDefiningOp<vector::ExtractOp>()) |
74 | if (auto maskOp = |
75 | extractOp.getVector().getDefiningOp<vector::CreateMaskOp>()) |
76 | return TransferMask{maskOp, |
77 | SmallVector<int64_t>(extractOp.getStaticPosition())}; |
78 | |
79 | // All other cases: not supported. |
80 | return failure(); |
81 | } |
82 | |
83 | /// Build an SSA value that represents the number of read elements. |
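/// For the vector.extract(vector.create_mask) case, the IR built here is
/// roughly the following (SSA names are illustrative; shown for a 2-D mask
/// with a single extract position):
///
///   %pos  = arith.constant 0 : index   // the static extract position
///   %cond = arith.cmpi slt, %pos, %d0 : index
///   %zero = arith.constant 0 : index
///   %n    = arith.select %cond, %d1, %zero : index
///
/// where %d0 and %d1 are the operands of the originating vector.create_mask.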
84 | static Value buildNumReadElements(OpBuilder &b, Location loc, |
85 | Operation *readOp) { |
86 | FailureOr<TransferMask> transferMask = getMaskOp(readOp); |
87 | assert(succeeded(transferMask) && "invalid transfer mask" ); |
88 | |
89 | // No mask => no num_read_elements. |
90 | if (!transferMask->createMaskOp) |
91 | return Value(); |
92 | |
93 | // No extract: return size of "ones" segment in the mask. |
94 | if (transferMask->extractPosition.empty()) { |
95 | assert(transferMask->createMaskOp.getNumOperands() == 1 && |
96 | "expected single operand" ); |
97 | return transferMask->createMaskOp.getOperand(0); |
98 | } |
99 | |
100 | // vector.extract(vector.create_mask). |
101 | // If extract_pos < num_ones, take number of elements from the least |
102 | // significant dimension. (Do this for all dimensions and bit-AND the |
103 | // conditions.) |
104 | assert(transferMask->createMaskOp.getVectorType().getRank() - |
105 | transferMask->extractPosition.size() == |
106 | 1 && |
107 | "expected N-D -> (N-1)-D extract" ); |
108 | Value cond; |
  // Note: There is one more `sz` than `pos`; the loop ends at the last `pos`.
110 | for (auto [pos, sz] : llvm::zip(transferMask->extractPosition, |
111 | transferMask->createMaskOp->getOperands())) { |
112 | Value cmp = |
113 | b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, |
114 | b.create<arith::ConstantIndexOp>(loc, pos), sz); |
115 | if (!cond) { |
116 | cond = cmp; |
117 | continue; |
118 | } |
119 | cond = b.create<arith::AndIOp>(loc, cmp, cond); |
120 | } |
121 | return b.create<arith::SelectOp>( |
122 | loc, cond, transferMask->createMaskOp->getOperands().back(), |
123 | b.create<arith::ConstantIndexOp>(loc, 0)); |
124 | } |
125 | |
126 | /// Return "true" if the conversion to async copy is supported by "async copy". |
127 | static bool resultsInSupportedAsyncCopy(MemRefType memrefType, |
128 | VectorType vecType) { |
129 | assert(vecType.getRank() == 1 && "expected 1-D vector" ); |
130 | constexpr int64_t kSupportedCpAsyncAlignmentsInBytes[3] = {4, 8, 16}; |
131 | |
132 | // Condition 1: the copy size must be supported. |
133 | bool supportedCopySize = false; |
134 | int64_t numElements = vecType.getNumElements(); |
135 | Type elementType = vecType.getElementType(); |
136 | for (int64_t alignmentInBytes : kSupportedCpAsyncAlignmentsInBytes) { |
137 | if (alignmentInBytes * 8 == |
138 | numElements * elementType.getIntOrFloatBitWidth()) { |
139 | supportedCopySize = true; |
140 | break; |
141 | } |
142 | } |
143 | if (!supportedCopySize) |
144 | return false; |
145 | |
146 | // TODO: Condition 2: the alignments must be supported. For cp.async the |
147 | // NVIDIA doc (section 6.4.1) says: "The address must be naturally aligned to |
148 | // a multiple of the access size. If an address is not properly aligned, the |
149 | // resulting behavior is undefined.". |
150 | return true; |
151 | } |
152 | |
153 | void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op, |
154 | bool bypassL1) { |
155 | llvm::SmallSetVector<Operation *, 16> copyToSharedMem; |
156 | |
  // Look for all the copies that can be converted to async copy ops.
  op->walk([&](Operation *writeOp) {
    // Look for contiguous 1D vector store into shared memory.
    if (!isContiguousStore(writeOp))
      return;
    Value vectorVal = nvgpu::getValueStored(writeOp);
    if (cast<VectorType>(vectorVal.getType()).getRank() != 1)
      return;
    Value storeBase = nvgpu::getMemrefOperand(writeOp);
166 | if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace( |
167 | cast<MemRefType>(storeBase.getType()))) |
168 | return; |
169 | |
170 | // The stored vector must originate from a contiguous 1D vector load. |
171 | Operation *readOp = vectorVal.getDefiningOp(); |
    if (readOp == nullptr || !isContiguousRead(readOp))
      return;
    Value loadBase = nvgpu::getMemrefOperand(readOp);
175 | // Should be reading from global memory (not shared memory). |
176 | if (nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace( |
177 | cast<MemRefType>(loadBase.getType()))) |
178 | return; |
179 | |
180 | // Look for compatible mask and padding. |
181 | if (auto transferRead = dyn_cast<vector::TransferReadOp>(readOp)) { |
182 | if (Value mask = transferRead.getMask()) { |
        if (getConstantIntValue(transferRead.getPadding()) !=
            static_cast<int64_t>(0))
185 | return; |
186 | if (failed(getMaskOp(readOp))) |
187 | return; |
188 | } |
189 | } |
190 | |
191 | // Check whether both accesses are supported before we emit: this is |
192 | // necessary to ensure the correctness of DeviceAsyncCopyOp. |
193 | VectorType vecType = cast<VectorType>(vectorVal.getType()); |
194 | |
195 | if (!resultsInSupportedAsyncCopy(cast<MemRefType>(loadBase.getType()), |
196 | vecType) || |
197 | !resultsInSupportedAsyncCopy(cast<MemRefType>(storeBase.getType()), |
198 | vecType)) |
199 | return; |
200 | |
    copyToSharedMem.insert(writeOp);
202 | return; |
203 | }); |
204 | |
205 | while (!copyToSharedMem.empty()) { |
206 | // Start a group with the first write. |
207 | SmallVector<Operation *> group; |
208 | Operation *writeOp = *copyToSharedMem.begin(); |
    copyToSharedMem.remove(writeOp);
    group.push_back(writeOp);
211 | Operation *nextNode = writeOp; |
212 | |
213 | // Look in the next nodes for more copies to add to the same group. |
214 | while ((nextNode = nextNode->getNextNode())) { |
215 | // Ignore ops without side effects. |
216 | auto memInterface = dyn_cast<MemoryEffectOpInterface>(nextNode); |
217 | if (memInterface && memInterface.hasNoEffect() && |
218 | !nextNode->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) |
219 | continue; |
      // Ignore reads from a different address space.
      if (isa<vector::TransferReadOp, vector::LoadOp>(nextNode)) {
        Operation *readOp = nextNode;
        Value memrefOperand = nvgpu::getMemrefOperand(readOp);
224 | if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace( |
225 | cast<MemRefType>(memrefOperand.getType()))) { |
226 | continue; |
227 | } |
228 | } |
      if (copyToSharedMem.count(nextNode)) {
        // Found another copy, add it to the group.
        copyToSharedMem.remove(nextNode);
        group.push_back(nextNode);
233 | continue; |
234 | } |
      // If the op is anything else, stop accumulating ops into the group.
236 | break; |
237 | } |
238 | |
239 | // Emit the group. |
240 | SmallVector<Value> tokens; |
241 | for (Operation *writeOp : group) { |
242 | rewriter.setInsertionPoint(writeOp); |
      Value vectorVal = nvgpu::getValueStored(writeOp);
      auto vectorType = cast<VectorType>(vectorVal.getType());
      int64_t numElements = vectorType.getNumElements();
      Operation *readOp = vectorVal.getDefiningOp();
      Value storeBase = nvgpu::getMemrefOperand(writeOp);
      Value loadBase = nvgpu::getMemrefOperand(readOp);
      Value numReadElements =
          buildNumReadElements(rewriter, writeOp->getLoc(), readOp);
251 | auto dstMemref = cast<MemRefType>(storeBase.getType()); |
252 | int64_t sizeInBytes = |
253 | (dstMemref.getElementTypeBitWidth() * numElements) / 8; |
      // bypassL1 is only possible with 16-byte transfers.
255 | Value token = rewriter.create<nvgpu::DeviceAsyncCopyOp>( |
256 | writeOp->getLoc(), nvgpu::DeviceAsyncTokenType::get(op->getContext()), |
257 | /*dst=*/storeBase, /*dstIndices=*/nvgpu::getIndices(writeOp), |
258 | /*src=*/loadBase, |
259 | /*srcIndices=*/nvgpu::getIndices(readOp), |
260 | /*dstElements=*/rewriter.getIndexAttr(numElements), |
261 | /*srcElements=*/numReadElements, |
262 | /*bypassL1=*/bypassL1 && sizeInBytes == 16 ? rewriter.getUnitAttr() |
263 | : UnitAttr()); |
      tokens.push_back(token);
265 | } |
266 | |
267 | // Create the group and wait for it right after. |
268 | Value groupToken = rewriter.create<nvgpu::DeviceAsyncCreateGroupOp>( |
269 | op->getLoc(), nvgpu::DeviceAsyncTokenType::get(op->getContext()), |
270 | tokens); |
271 | rewriter.create<nvgpu::DeviceAsyncWaitOp>(op->getLoc(), groupToken, |
272 | nullptr); |
273 | // Clean up old stores. |
274 | for (Operation *writeOp : group) |
      rewriter.eraseOp(writeOp);
276 | } |
277 | } |
278 | |