1 | //===- NVGPUTransformOps.cpp - Implementation of NVGPU transform ops ------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h" |
10 | |
11 | #include "mlir/Analysis/SliceAnalysis.h" |
12 | #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" |
13 | #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
14 | #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h" |
15 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
17 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
18 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
19 | #include "mlir/Dialect/LLVMIR/NVVMDialect.h" |
20 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
21 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
22 | #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" |
23 | #include "mlir/Dialect/NVGPU/Transforms/Transforms.h" |
24 | #include "mlir/Dialect/SCF/IR/SCF.h" |
25 | #include "mlir/Dialect/SCF/Transforms/Transforms.h" |
26 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
27 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
28 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
29 | #include "mlir/IR/AffineExpr.h" |
30 | #include "mlir/IR/BuiltinTypes.h" |
31 | #include "mlir/IR/Value.h" |
32 | #include "llvm/ADT/ArrayRef.h" |
33 | |
34 | using namespace mlir; |
35 | using namespace mlir::linalg; |
36 | using namespace mlir::nvgpu; |
37 | using namespace mlir::NVVM; |
38 | using namespace mlir::transform; |
39 | |
40 | #define DEBUG_TYPE "nvgpu-transforms" |
41 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
42 | #define DBGSNL() (llvm::dbgs() << "\n") |
43 | #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n") |
44 | |
45 | //===----------------------------------------------------------------------===// |
46 | // Apply...ConversionPatternsOp |
47 | //===----------------------------------------------------------------------===// |
48 | |
49 | void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns( |
50 | TypeConverter &typeConverter, RewritePatternSet &patterns) { |
51 | auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter); |
  /// Device-side async tokens cannot be materialized in NVVM. We just
  /// convert them to a dummy i32 type in order to easily drop them during
  /// conversion.
55 | populateGpuMemorySpaceAttributeConversions( |
56 | llvmTypeConverter, [](gpu::AddressSpace space) -> unsigned { |
57 | switch (space) { |
58 | case gpu::AddressSpace::Global: |
59 | return static_cast<unsigned>( |
60 | NVVM::NVVMMemorySpace::kGlobalMemorySpace); |
61 | case gpu::AddressSpace::Workgroup: |
62 | return static_cast<unsigned>( |
63 | NVVM::NVVMMemorySpace::kSharedMemorySpace); |
64 | case gpu::AddressSpace::Private: |
65 | return 0; |
66 | } |
        llvm_unreachable("unknown address space enum value");
68 | return 0; |
69 | }); |
70 | llvmTypeConverter.addConversion( |
71 | [&](nvgpu::DeviceAsyncTokenType type) -> Type { |
72 | return llvmTypeConverter.convertType( |
73 | IntegerType::get(type.getContext(), 32)); |
74 | }); |
75 | llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type { |
76 | return llvmTypeConverter.convertType( |
77 | IntegerType::get(type.getContext(), 64)); |
78 | }); |
79 | llvmTypeConverter.addConversion( |
80 | [&](nvgpu::WarpgroupAccumulatorType type) -> Type { |
81 | Type elemType = type.getFragmented().getElementType(); |
82 | int64_t sizeM = type.getFragmented().getDimSize(0); |
83 | int64_t sizeN = type.getFragmented().getDimSize(1); |
84 | |
85 | unsigned numMembers; |
86 | if (elemType.isF32() || elemType.isInteger(32)) |
87 | numMembers = sizeN / 2; |
88 | else if (elemType.isF16()) |
89 | numMembers = sizeN / 4; |
90 | else |
          llvm_unreachable("unsupported type for warpgroup accumulator");
92 | |
93 | SmallVector<Type> innerStructBody; |
94 | for (unsigned i = 0; i < numMembers; i++) |
95 | innerStructBody.push_back(elemType); |
96 | auto innerStructType = LLVM::LLVMStructType::getLiteral( |
97 | type.getContext(), innerStructBody); |
98 | |
99 | SmallVector<Type> structBody; |
100 | for (int i = 0; i < sizeM; i += kWgmmaSizeM) |
101 | structBody.push_back(innerStructType); |
102 | |
103 | auto convertedType = |
104 | LLVM::LLVMStructType::getLiteral(type.getContext(), structBody); |
105 | return llvmTypeConverter.convertType(convertedType); |
106 | }); |
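  // Illustration of the warpgroup-accumulator mapping above (not exercised
  // here; it assumes kWgmmaSizeM == 64 as defined by the NVGPU dialect): a
  // fragmented vector<128x128xf32> accumulator gives numMembers = 128 / 2 = 64,
  // i.e. an inner struct of 64 f32 values, and the outer struct repeats that
  // inner struct 128 / 64 = 2 times.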
107 | llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type { |
108 | return llvmTypeConverter.convertType( |
109 | getMBarrierMemrefType(type.getContext(), type)); |
110 | }); |
111 | llvmTypeConverter.addConversion( |
112 | [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type { |
113 | return llvmTypeConverter.convertType( |
114 | IntegerType::get(type.getContext(), 64)); |
115 | }); |
116 | llvmTypeConverter.addConversion( |
117 | [&](nvgpu::TensorMapDescriptorType type) -> Type { |
118 | return LLVM::LLVMPointerType::get(type.getContext()); |
119 | }); |
120 | populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns); |
121 | } |
122 | |
123 | LogicalResult |
124 | transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter( |
125 | transform::TypeConverterBuilderOpInterface builder) { |
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
128 | return success(); |
129 | } |
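
// Sketch of how this pattern-descriptor op is meant to be used from a
// transform script. The mnemonics and attribute names below are assumptions
// based on the op definitions in NVGPUTransformOps.td and the transform
// dialect; consult those for the authoritative syntax:
//
//   transform.apply_conversion_patterns to %module {
//     transform.apply_conversion_patterns.nvgpu.nvgpu_to_nvvm
//   } with type_converter {
//     transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
//   } {legal_dialects = ["llvm", "nvvm"]} : !transform.any_op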
130 | |
131 | //===---------------------------------------------------------------------===// |
132 | // CreateAsyncGroupsOp |
133 | //===---------------------------------------------------------------------===// |
134 | |
135 | void transform::CreateAsyncGroupsOp::getEffects( |
136 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
137 | transform::consumesHandle(getTarget(), effects); |
138 | transform::producesHandle(getResult(), effects); |
139 | transform::modifiesPayload(effects); |
140 | } |
141 | |
142 | DiagnosedSilenceableFailure transform::CreateAsyncGroupsOp::applyToOne( |
143 | TransformRewriter &rewriter, Operation *target, |
144 | ApplyToEachResultList &results, TransformState &state) { |
145 | nvgpu::createAsyncGroups(rewriter, target, getBypassL1()); |
146 | results.push_back(target); |
147 | return DiagnosedSilenceableFailure::success(); |
148 | } |
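
// Minimal usage sketch from a transform script (the exact assembly format and
// handle types are assumptions; see NVGPUTransformOps.td):
//
//   %new = transform.nvgpu.create_async_groups %func {bypass_l1}
//     : (!transform.any_op) -> !transform.any_op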
149 | |
150 | //===----------------------------------------------------------------------===// |
151 | // PipelineSharedMemoryCopiesOp |
152 | //===----------------------------------------------------------------------===// |
153 | |
154 | /// Returns true if the given type has the default memory space. |
155 | static bool hasDefaultMemorySpace(BaseMemRefType type) { |
156 | return !type.getMemorySpace() || type.getMemorySpaceAsInt() == 0; |
157 | } |
158 | |
159 | /// Returns true if the given type has the shared (workgroup) memory space. |
160 | static bool hasSharedMemorySpace(BaseMemRefType type) { |
161 | auto space = |
162 | dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace()); |
163 | return space && |
164 | space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace(); |
165 | } |
166 | |
167 | /// Returns the value produced by a load from the default memory space. Returns |
168 | /// null if the operation is not such a load. |
169 | static Value getValueLoadedFromGlobal(Operation *op) { |
170 | // TODO: consider an interface or leveraging the memory effects interface. |
171 | auto load = dyn_cast<vector::TransferReadOp>(op); |
172 | if (!load) |
173 | return nullptr; |
174 | |
175 | auto loadType = dyn_cast<MemRefType>(load.getSource().getType()); |
176 | if (!loadType || !hasDefaultMemorySpace(loadType)) |
177 | return nullptr; |
178 | return load; |
179 | } |
180 | |
/// Returns true if the operation is storing the given value into the shared
/// (workgroup) memory space.
static bool isStoreToShared(Operation *op, Value v) {
  // TODO: consider an interface or leveraging the memory effects interface.
  auto store = dyn_cast<vector::TransferWriteOp>(op);
  if (!store || store.getVector() != v)
    return false;

  auto storeType = dyn_cast<MemRefType>(store.getSource().getType());
  return storeType && hasSharedMemorySpace(storeType);
190 | } |
191 | |
/// Returns true if the operation is a load from the default memory space whose
/// result is only stored into the shared memory space.
static bool isLoadFromGlobalStoredToShared(Operation *op) {
  Value loaded = getValueLoadedFromGlobal(op);
  if (!loaded || !loaded.hasOneUse())
    return false;

  return isStoreToShared(*loaded.getUsers().begin(), loaded);
200 | } |
201 | |
202 | /// Populate `ops` with the set of operations that belong to the stage 0 of the |
203 | /// pipelined version of the given loop when pipelining copies to shared memory. |
204 | /// Specifically, this collects: |
205 | /// |
206 | /// 1. all loads from global memory, both sync and async; |
207 | /// 2. the barriers for async loads. |
208 | /// |
209 | /// In particular, barriers are omitted if they do not dominate at least one |
210 | /// async load for which there is not yet a barrier. |
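///
/// As a schematic illustration (assuming the copies already target shared
/// memory), for a loop body of the form
///
///   %v = vector.transfer_read %global[...]   // only used by the write below
///   vector.transfer_write %v, %shared[...]
///   gpu.barrier
///   %t = nvgpu.device_async_copy %global[...], %shared[...]
///   ... compute ...
///
/// the transfer_read, the barrier, and the device_async_copy end up in `ops`,
/// while the compute is left for a later pipeline stage.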
211 | static LogicalResult |
212 | collectStage0PipeliningOps(scf::ForOp forOp, |
213 | llvm::SmallPtrSet<Operation *, 16> &ops) { |
214 | |
215 | llvm::SmallPtrSet<Operation *, 4> barriers; |
216 | for (Operation &op : *forOp.getBody()) { |
217 | // Bail on nested ops for now. |
218 | if (op.getNumRegions() > 0) |
219 | return failure(); |
220 | |
221 | if (isa<gpu::BarrierOp>(op)) { |
222 | barriers.insert(&op); |
223 | continue; |
224 | } |
225 | |
226 | if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) { |
227 | ops.insert(&op); |
228 | ops.insert(std::make_move_iterator(barriers.begin()), |
229 | std::make_move_iterator(barriers.end())); |
      assert(barriers.empty() &&
             "expected to have moved the barriers into another set");
232 | continue; |
233 | } |
234 | |
235 | if (isLoadFromGlobalStoredToShared(&op)) { |
236 | ops.insert(&op); |
237 | continue; |
238 | } |
239 | } |
240 | |
241 | return success(); |
242 | } |
243 | |
244 | /// Hook for the loop pipeliner that sets the "num groups in flight" attribute |
245 | /// of async wait operations corresponding to pipelined shared memory copies. |
246 | // TODO: this currently assumes that there are no groups that could be in flight |
247 | // in the existing code. |
248 | static void |
249 | setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op, |
250 | scf::PipeliningOption::PipelinerPart part, |
251 | unsigned iteration, unsigned depth) { |
252 | // Based on the order of copies within the loop we need to set the number |
253 | // of copies in flight, unless it is already set. |
254 | auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op); |
255 | if (!waitOp || waitOp.getNumGroups()) |
256 | return; |
257 | |
258 | int numGroupInFlight = 0; |
259 | if (part == scf::PipeliningOption::PipelinerPart::Kernel || |
260 | part == scf::PipeliningOption::PipelinerPart::Prologue) { |
261 | numGroupInFlight = depth - 1; |
262 | } else { |
263 | // By construction there should be no wait op in the prologue as all the |
264 | // wait should be in the last stage. |
265 | assert(part == scf::PipeliningOption::PipelinerPart::Epilogue); |
266 | // Based on the schedule we pick we know how many groups are in flight for |
267 | // each iteration of the epilogue. |
268 | numGroupInFlight = depth - 1 - iteration; |
269 | } |
270 | waitOp.setNumGroups(numGroupInFlight); |
271 | } |
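
// Worked example of the formula above: with depth == 3, a wait in the kernel
// gets numGroups == 2 (depth - 1); in the epilogue, iteration i gets
// depth - 1 - i, i.e. 2 for iteration 0, then 1, and so on, progressively
// draining the in-flight copy groups.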
272 | |
273 | /// Hook for the loop pipeliner that populates `ops` with the stage information |
274 | /// as follows: |
275 | /// |
276 | /// - operations in `stage0Ops` (typically loads from global memory and |
277 | /// related barriers) are at stage 0; |
278 | /// - operations in the backward slice of any stage0Ops are all at stage 0; |
279 | /// - other operations are at stage `depth`; |
280 | /// - the internal order of the pipelined loop has ops at stage `depth` first, |
281 | /// then those at stage 0, with relative order within each group preserved. |
282 | /// |
283 | static void getPipelineStages( |
284 | scf::ForOp forOp, |
285 | std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages, |
286 | unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) { |
287 | SetVector<Operation *> dependencies; |
288 | BackwardSliceOptions options([&](Operation *visited) { |
289 | return visited->getBlock() == forOp.getBody(); |
290 | }); |
291 | options.inclusive = true; |
292 | for (Operation &op : forOp.getBody()->getOperations()) { |
293 | if (stage0Ops.contains(&op)) |
294 | getBackwardSlice(&op, &dependencies, options); |
295 | } |
296 | |
297 | for (Operation &op : forOp.getBody()->getOperations()) { |
298 | if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op)) |
299 | opsWithPipelineStages.emplace_back(&op, depth); |
300 | } |
301 | for (Operation &op : forOp.getBody()->getOperations()) { |
302 | if (dependencies.contains(&op)) |
303 | opsWithPipelineStages.emplace_back(&op, 0); |
304 | } |
305 | } |
306 | |
307 | /// Hook for the loop pipeliner. Replaces op with a predicated version and |
308 | /// returns the resulting operation. Returns the original op if the predication |
309 | /// isn't necessary for the given op. Returns null if predication is needed but |
310 | /// not supported. |
311 | static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter, |
312 | Operation *op, Value predicate) { |
313 | // Some operations may be fine to execute "speculatively" more times than the |
314 | // original number of iterations, in particular side-effect free operations |
315 | // and barriers, even if they cannot be predicated. |
316 | if (isMemoryEffectFree(op) || |
317 | isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp, |
318 | nvgpu::DeviceAsyncWaitOp>(op)) { |
319 | return op; |
320 | } |
321 | |
322 | // Otherwise, only async copies can currently be predicated. |
323 | auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op); |
324 | if (!asyncCopyOp) |
325 | return nullptr; |
326 | |
327 | // Create srcElement Value based on `predicate`. The next lines generate |
328 | // the following code: |
329 | // |
330 | // srcElement = (pred) ? prevSrcElements : 0; |
331 | // |
332 | Location loc = asyncCopyOp->getLoc(); |
333 | Value dstElements = |
334 | rewriter.create<arith::ConstantOp>(loc, asyncCopyOp.getDstElementsAttr()); |
335 | Value originalSrcElement = |
336 | asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements; |
  Value c0Index = rewriter.create<arith::ConstantIndexOp>(loc, 0);
338 | auto srcElements = rewriter.create<arith::SelectOp>( |
339 | loc, predicate, originalSrcElement, c0Index); |
340 | auto asyncCopyZeroFillOp = rewriter.create<nvgpu::DeviceAsyncCopyOp>( |
341 | loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()), |
342 | asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(), |
343 | asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements, |
344 | UnitAttr()); |
345 | rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp); |
346 | return asyncCopyZeroFillOp; |
347 | } |
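
// Schematic of the rewrite above (pseudo-IR, not verbatim syntax): given a
// predicate %pred, the copy becomes
//
//   %c0  = arith.constant 0 : index
//   %src = arith.select %pred, %srcElements, %c0 : index
//   %tok = nvgpu.device_async_copy %global[...], %shared[...], <dst>, %src
//
// so that predicated-off iterations transfer zero elements (zero-filling the
// destination) instead of reading past the original trip count.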
348 | |
349 | /// Applies loop pipelining with the given depth to the given loop so that |
350 | /// copies into the shared memory are pipelined. Doesn't affect other loops. |
351 | /// Returns a pair containing the error state and the pipelined op, the latter |
352 | /// being null in case of any failure. The error state contains a definite error |
353 | /// if the IR has been modified and a silenceable error otherwise. |
354 | static std::tuple<DiagnosedSilenceableFailure, scf::ForOp> |
355 | pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth, |
356 | bool epiloguePeeling) { |
357 | llvm::SmallPtrSet<Operation *, 16> stage0Ops; |
358 | if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) { |
359 | return std::make_tuple( |
        emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"),
361 | scf::ForOp()); |
362 | } |
363 | if (stage0Ops.empty()) { |
364 | return std::make_tuple( |
        emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp());
366 | } |
367 | |
368 | scf::PipeliningOption options; |
369 | unsigned maxDepth = depth; |
370 | auto setAnnotation = [&](Operation *op, |
371 | scf::PipeliningOption::PipelinerPart part, |
372 | unsigned iteration) { |
    return setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth);
374 | }; |
375 | options.getScheduleFn = |
376 | [&](scf::ForOp schedulingFor, |
377 | std::vector<std::pair<Operation *, unsigned>> &ops) { |
378 | if (schedulingFor != forOp) |
379 | return; |
380 | return getPipelineStages(forOp, ops, maxDepth, stage0Ops); |
381 | }; |
382 | options.annotateFn = setAnnotation; |
383 | if (!epiloguePeeling) { |
384 | options.peelEpilogue = false; |
385 | options.predicateFn = replaceOpWithPredicatedOp; |
386 | } |
387 | |
388 | OpBuilder::InsertionGuard guard(rewriter); |
389 | rewriter.setInsertionPoint(forOp); |
390 | bool modifiedIR; |
391 | FailureOr<scf::ForOp> maybePipelined = |
392 | pipelineForLoop(rewriter, forOp, options, &modifiedIR); |
393 | if (succeeded(maybePipelined)) { |
394 | return std::make_tuple(DiagnosedSilenceableFailure::success(), |
395 | *maybePipelined); |
396 | } |
397 | return std::make_tuple( |
398 | modifiedIR |
399 | ? DiagnosedSilenceableFailure::definiteFailure() |
          : emitSilenceableFailure(forOp, "pipelining preconditions failed"),
401 | scf::ForOp()); |
402 | } |
403 | |
404 | DiagnosedSilenceableFailure PipelineSharedMemoryCopiesOp::applyToOne( |
405 | TransformRewriter &rewriter, scf::ForOp forOp, |
406 | ApplyToEachResultList &results, TransformState &state) { |
407 | auto [diag, pipelined] = pipelineForSharedCopies( |
408 | rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue()); |
409 | if (diag.succeeded()) { |
410 | results.push_back(pipelined); |
411 | return DiagnosedSilenceableFailure::success(); |
412 | } |
413 | if (diag.isDefiniteFailure()) { |
    auto diag = emitDefiniteFailure("irreversible pipelining failure");
    if (!getPeelEpilogue()) {
      diag.attachNote(forOp->getLoc()) << "couldn't predicate?";
      diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName();
418 | } |
419 | return diag; |
420 | } |
421 | |
422 | return std::move(diag); |
423 | } |
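
// Usage sketch from a transform script (the attribute spelling and handle
// types are assumptions; see NVGPUTransformOps.td for the authoritative
// assembly format):
//
//   %pipelined = transform.nvgpu.pipeline_shared_memory_copies %for
//     {depth = 2, peel_epilogue} : (!transform.any_op) -> !transform.any_op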
424 | |
425 | //===----------------------------------------------------------------------===// |
426 | // RewriteMatmulAsMmaSyncOp |
427 | //===----------------------------------------------------------------------===// |
428 | |
429 | /// Helper struct to encode a pair of row/column indexings in the form of |
430 | /// affine expressions. |
431 | struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> { |
432 | RowColIndexing(AffineExpr row, AffineExpr col) |
433 | : std::pair<AffineExpr, AffineExpr>(row, col) {} |
434 | |
435 | AffineExpr row() const { return first; }; |
436 | AffineExpr col() const { return second; }; |
437 | |
438 | void print(llvm::raw_ostream &os) const { |
439 | os << "- indexing: " << first << ", " << second; |
440 | } |
441 | }; |
442 | |
443 | /// Helper struct to provide a simple mapping from matmul operations to the |
444 | /// corresponding mma.sync operation. This is constrained to the case where the |
445 | /// matmul matches the mma.sync operation 1-1. |
446 | struct MmaSyncBuilder { |
447 | MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId) |
448 | : b(b), loc(loc), laneId(laneId) {} |
449 | |
450 | using IndexCalculator = |
451 | std::function<SmallVector<RowColIndexing>(MLIRContext *)>; |
452 | |
453 | /// Create the mma.sync operation corresponding to `linalgOp` along with all |
454 | /// the supporting load/store and vector operations. |
455 | FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp); |
456 | |
457 | private: |
458 | struct MmaSyncInfo { |
459 | std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns; |
460 | std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, SmallVector<int64_t>> |
461 | vectorShapes; |
462 | SmallVector<int64_t> mmaShape; |
463 | bool tf32Enabled; |
464 | }; |
465 | |
466 | /// Return the specific index calculator for the given `linalgOp` or failure |
467 | /// if the op is not supported. This is the toplevel switch that should just |
468 | /// be Tablegen'd in the future. |
469 | FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape, |
470 | TypeRange elementalTypes); |
471 | |
472 | //===--------------------------------------------------------------------===// |
473 | // Instruction-specific row, column indexing expression builders. |
474 | // These should all be declaratively specified via Tablegen in the future. |
475 | // The Tablegen specification should be as straightforward as possible to |
476 | // only model the existing size and type combinations. |
477 | //===--------------------------------------------------------------------===// |
478 | // |
479 | // TODO: Tablegen all this. |
480 | //===--------------------------------------------------------------------===// |
481 | // m16n8k4 tf32 case. |
482 | //===--------------------------------------------------------------------===// |
483 | /// From the NVIDIA doc: |
484 | /// groupID = %laneid >> 2 |
485 | /// threadIDInGroup = %laneid % 4 |
486 | /// row = groupID for a0 |
487 | /// groupID + 8 for a1 |
488 | /// col = threadIDInGroup |
489 | static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) { |
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
492 | AffineExpr threadIDInGroup = dim % 4; |
493 | return {RowColIndexing{groupID, threadIDInGroup}, |
494 | RowColIndexing{groupID + 8, threadIDInGroup}}; |
495 | } |
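  // Worked example for m16n8k4tf32Lhs above: for %laneid == 5, groupID == 1
  // and threadIDInGroup == 1, so a0 is read from (row 1, col 1) and a1 from
  // (row 9, col 1).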
496 | |
497 | /// From the NVIDIA doc: |
498 | /// groupID = %laneid >> 2 |
499 | /// threadIDInGroup = %laneid % 4 |
500 | /// row = threadIDInGroup |
501 | /// col = groupID |
502 | static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) { |
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
505 | AffineExpr threadIDInGroup = dim % 4; |
506 | return {RowColIndexing{threadIDInGroup, groupID}}; |
507 | } |
508 | |
509 | /// From the NVIDIA doc: |
510 | /// groupID = %laneid >> 2 |
511 | /// threadIDInGroup = %laneid % 4 |
512 | /// row = groupID for c0 and c1 |
513 | /// groupID + 8 for c2 and c3 |
514 | /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3} |
515 | static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) { |
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
518 | AffineExpr threadIDInGroup = dim % 4; |
519 | return {RowColIndexing{groupID, threadIDInGroup * 2 + 0}, |
520 | RowColIndexing{groupID, threadIDInGroup * 2 + 1}, |
521 | RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, |
522 | RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}}; |
523 | } |
524 | |
525 | //===--------------------------------------------------------------------===// |
526 | // m16n8k16 f16 case. |
527 | //===--------------------------------------------------------------------===// |
528 | /// From the NVIDIA doc: |
529 | /// groupID = %laneid >> 2 |
530 | /// threadIDInGroup = %laneid % 4 |
531 | /// |
532 | /// row = groupID for ai where 0 <= i < 2 || 4 <= i < 6 |
533 | /// groupID + 8 Otherwise |
534 | /// |
535 | /// col = (threadIDInGroup * 2) + (i & 0x1) for ai where i < 4 |
536 | /// (threadIDInGroup * 2) + (i & 0x1) + 8 for ai where i >= 4 |
537 | static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) { |
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
540 | AffineExpr threadIDInGroup = dim % 4; |
541 | // clang-format off |
542 | return { |
543 | RowColIndexing{groupID, threadIDInGroup * 2 + 0}, // i == 0 |
544 | RowColIndexing{groupID, threadIDInGroup * 2 + 1}, // i == 1 |
545 | RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2 |
546 | RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}, // i == 3 |
547 | RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8}, // i == 4 |
548 | RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8}, // i == 5 |
549 | RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8}, // i == 6 |
550 | RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8} // i == 7 |
551 | }; |
552 | // clang-format on |
553 | } |
554 | |
555 | /// From the NVIDIA doc: |
556 | /// groupID = %laneid >> 2 |
557 | /// threadIDInGroup = %laneid % 4 |
558 | /// |
559 | /// row = (threadIDInGroup * 2) + (i & 0x1) for bi where i < 2 |
560 | /// (threadIDInGroup * 2) + (i & 0x1) + 8 for bi where i >= 2 |
561 | /// |
562 | /// col = groupID |
563 | static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) { |
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
566 | AffineExpr threadIDInGroup = dim % 4; |
567 | // clang-format off |
568 | return { |
569 | RowColIndexing{threadIDInGroup * 2 + 0, groupID}, // i == 0 |
570 | RowColIndexing{threadIDInGroup * 2 + 1, groupID}, // i == 1 |
571 | RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID}, // i == 2 |
572 | RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID} // i == 3 |
573 | }; |
574 | // clang-format on |
575 | } |
576 | |
577 | /// From the NVIDIA doc: |
578 | /// groupID = %laneid >> 2 |
579 | /// threadIDInGroup = %laneid % 4 |
580 | /// |
581 | /// row = groupID for ci where i < 2 |
582 | /// groupID + 8 for ci where i >= 2 |
583 | /// |
584 | /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3} |
585 | static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) { |
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
588 | AffineExpr threadIDInGroup = dim % 4; |
589 | // clang-format off |
590 | return { |
591 | RowColIndexing{groupID, threadIDInGroup * 2 + 0}, // i == 0 |
592 | RowColIndexing{groupID, threadIDInGroup * 2 + 1}, // i == 1 |
593 | RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2 |
594 | RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1} // i == 3 |
595 | }; |
596 | // clang-format on |
597 | } |
598 | |
599 | //===--------------------------------------------------------------------===// |
600 | /// Helper functions to create customizable load and stores operations. The |
601 | /// specific shapes of each MMA instruction are passed via the |
602 | /// IndexCalculator callback. |
603 | //===--------------------------------------------------------------------===// |
604 | /// Build a list of memref.load operations indexed at `(row, col)` indices |
605 | /// that make sense for a particular MMA instruction and specified via the |
606 | /// IndexCalculator callback. |
607 | SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc, |
608 | OpFoldResult laneId, Value memref, |
609 | const IndexCalculator &indexFn); |
610 | |
611 | /// Perform a distributed load of a vector operand of `vectorShape` for a |
612 | /// particular MMA instruction whose `(row, col)` indices are specified via |
613 | /// the IndexCalculator callback. Each `laneId` loads the subportion of the |
614 | /// data that makes sense for the particular MMA operation. |
615 | /// The `vectorShape` matches existing NVGPU dialect op specification but |
616 | /// could also be flattened in the future if needed for simplification. |
617 | Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc, |
618 | OpFoldResult laneId, Value memref, |
619 | IndexCalculator indexFn, |
620 | ArrayRef<int64_t> vectorShape); |
621 | |
622 | /// Build a list of memref.store operations indexed at `(row, col)` indices |
623 | /// that make sense for a particular MMA instruction and specified via the |
624 | /// IndexCalculator callback. |
625 | SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc, |
626 | ValueRange toStore, |
627 | OpFoldResult laneId, Value memref, |
628 | const IndexCalculator &indexFn); |
629 | |
630 | /// Perform a distributed store of a vector operand of `vectorShape` for a |
631 | /// particular MMA instruction whose `(row, col)` indices are specified via |
632 | /// the IndexCalculator callback. Each `laneId` loads the subportion of the |
633 | /// data that makes sense for the particular MMA operation. |
634 | /// The `vectorShape` matches existing NVGPU dialect op specification but |
635 | /// could also be flattened in the future if needed for simplification. |
636 | SmallVector<Operation *> buildMmaSyncMemRefStoreOperand( |
637 | OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId, |
638 | Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape); |
639 | |
640 | OpBuilder &b; |
641 | Location loc; |
642 | OpFoldResult laneId; |
643 | }; |
644 | |
645 | //===--------------------------------------------------------------------===// |
646 | /// Helper functions to create customizable load and stores operations. The |
647 | /// specific shapes of each MMA instruction are passed via the |
648 | /// IndexCalculator callback. |
649 | //===--------------------------------------------------------------------===// |
650 | |
651 | template <typename ApplyFn, typename ReduceFn> |
652 | static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn, |
653 | ReduceFn reduceFn) { |
654 | VectorType vectorType = cast<VectorType>(vector.getType()); |
655 | auto vectorShape = vectorType.getShape(); |
656 | auto strides = computeStrides(vectorShape); |
657 | for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) { |
658 | auto indices = delinearize(idx, strides); |
659 | reduceFn(applyFn(vector, idx, indices), idx, indices); |
660 | } |
661 | } |
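
// For instance (purely illustrative), for a vector<2x2xf16> the strides are
// [2, 1] and the loop visits linear indices 0..3, which delinearize to the
// positions (0, 0), (0, 1), (1, 0) and (1, 1).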
662 | |
663 | SmallVector<Value> |
664 | MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc, |
665 | OpFoldResult laneId, Value memref, |
666 | const IndexCalculator &indexFn) { |
667 | auto aff = [&](AffineExpr e) { |
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
669 | }; |
670 | SmallVector<Value> res; |
671 | SmallVector<RowColIndexing> indexings = indexFn(b.getContext()); |
672 | for (auto indexing : indexings) { |
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    auto load = b.create<memref::LoadOp>(loc, memref, ValueRange{row, col});
    res.push_back(load);
677 | } |
678 | return res; |
679 | } |
680 | |
681 | Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand( |
682 | OpBuilder &b, Location loc, OpFoldResult laneId, Value memref, |
683 | IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) { |
  auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));

  Type elementType = getElementTypeOrSelf(memref.getType());
  auto vt = VectorType::get(vectorShape, elementType);
  Value res = b.create<vector::SplatOp>(loc, vt, loads[0]);
  foreachIndividualVectorElement(
      res,
691 | /*applyFn=*/ |
692 | [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) { |
693 | return loads[linearIdx]; |
694 | }, |
695 | /*reduceFn=*/ |
696 | [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) { |
697 | res = b.create<vector::InsertOp>(loc, v, res, indices); |
698 | }); |
699 | |
700 | return res; |
701 | } |
702 | |
703 | SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores( |
704 | OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId, |
705 | Value memref, const IndexCalculator &indexFn) { |
706 | auto aff = [&](AffineExpr e) { |
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Operation *> res;
  for (auto [indexing, val] :
       llvm::zip_equal(indexFn(b.getContext()), toStore)) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    Operation *store =
        b.create<memref::StoreOp>(loc, val, memref, ValueRange{row, col});
    res.push_back(store);
717 | } |
718 | return res; |
719 | } |
720 | |
721 | SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand( |
722 | OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId, |
723 | Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) { |
724 | SmallVector<Value> toStore; |
  toStore.reserve(32);
  foreachIndividualVectorElement(
      vectorToStore,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return b.create<vector::ExtractOp>(loc, vectorToStore, indices);
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        toStore.push_back(v);
      });
  return buildMemRefStores(b, loc, toStore, laneId, memref, std::move(indexFn));
737 | } |
738 | |
739 | static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, |
740 | SmallVector<int64_t>> |
741 | makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs, |
742 | ArrayRef<int64_t> res) { |
743 | SmallVector<int64_t> vlhs{lhs.begin(), lhs.end()}; |
744 | SmallVector<int64_t> vrhs{rhs.begin(), rhs.end()}; |
745 | SmallVector<int64_t> vres{res.begin(), res.end()}; |
  return std::make_tuple(vlhs, vrhs, vres);
747 | } |
748 | |
749 | FailureOr<MmaSyncBuilder::MmaSyncInfo> |
750 | MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape, |
751 | TypeRange elementalTypes) { |
752 | // TODO: Tablegen all this. |
753 | Type f16 = b.getF16Type(); |
754 | Type f32 = b.getF32Type(); |
755 | if (opShape == ArrayRef<int64_t>{16, 8, 4} && |
756 | elementalTypes == TypeRange{f32, f32, f32}) { |
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
                                       &MmaSyncBuilder::m16n8k4tf32Rhs,
                                       &MmaSyncBuilder::m16n8k4tf32Res),
                       makeVectorShapes({2, 1}, {1, 1}, {2, 2}),
                       SmallVector<int64_t>{opShape.begin(), opShape.end()},
                       /*tf32Enabled=*/true};
763 | } |
764 | // This is the version with f16 accumulation. |
765 | // TODO: version with f32 accumulation. |
766 | if (opShape == ArrayRef<int64_t>{16, 8, 16} && |
767 | elementalTypes == TypeRange{f16, f16, f16}) { |
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
                                       &MmaSyncBuilder::m16n8k16f16Rhs,
                                       &MmaSyncBuilder::m16n8k16f16Res),
                       makeVectorShapes({4, 2}, {2, 2}, {2, 2}),
                       SmallVector<int64_t>{opShape.begin(), opShape.end()},
                       /*tf32Enabled=*/false};
774 | } |
775 | return failure(); |
776 | } |
777 | |
778 | FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) { |
779 | Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get(); |
780 | Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get(); |
781 | Value resMemRef = linalgOp.getDpsInitOperand(0)->get(); |
  assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&
         "expected lhs to be a 2D memref");
  assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&
         "expected rhs to be a 2D memref");
  assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&
         "expected res to be a 2D memref");
788 | |
789 | int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0]; |
790 | int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1]; |
791 | int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1]; |
  Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());
  Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());
  Type resType = getElementTypeOrSelf(resMemRef.getType());
795 | |
796 | FailureOr<MmaSyncInfo> maybeInfo = |
      getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
  if (failed(maybeInfo))
799 | return failure(); |
800 | |
801 | MmaSyncInfo info = *maybeInfo; |
802 | auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns; |
803 | auto [lhsShape, rhsShape, resShape] = info.vectorShapes; |
  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
                                            lhsIndexFn, lhsShape);
  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
                                            rhsIndexFn, rhsShape);
  Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
                                            resIndexFn, resShape);
  res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape,
                                   info.tf32Enabled);
  buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
                                 resShape);
814 | return res.getDefiningOp(); |
815 | } |
816 | |
817 | DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne( |
818 | transform::TransformRewriter &rewriter, LinalgOp linalgOp, |
819 | transform::ApplyToEachResultList &results, |
820 | transform::TransformState &state) { |
821 | bool fail = true; |
822 | // TODO: more robust detection of matmulOp, with transposes etc. |
823 | if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) { |
824 | Location loc = linalgOp.getLoc(); |
825 | // TODO: more robust computation of laneId, for now assume a single warp. |
826 | Value laneId = rewriter.create<gpu::ThreadIdOp>( |
827 | loc, rewriter.getIndexType(), gpu::Dimension::x); |
828 | if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp))) |
829 | fail = false; |
830 | } |
831 | |
832 | if (fail) { |
833 | DiagnosedSilenceableFailure diag = emitSilenceableError() |
834 | << "unsupported target op: " << linalgOp; |
    diag.attachNote(linalgOp->getLoc()) << "target op";
836 | return diag; |
837 | } |
838 | |
839 | rewriter.eraseOp(linalgOp); |
840 | return DiagnosedSilenceableFailure::success(); |
841 | } |
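
// Usage sketch from a transform script (the handle type below is an
// assumption; see NVGPUTransformOps.td for the authoritative format):
//
//   transform.nvgpu.rewrite_matmul_as_mma_sync %matmul
//     : (!transform.any_op) -> ()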
842 | |
843 | //===----------------------------------------------------------------------===// |
844 | // Hopper builders. |
845 | //===----------------------------------------------------------------------===// |
846 | |
847 | /// Helper to create the base Hopper-specific operations that are reused in |
848 | /// various other places. |
849 | struct HopperBuilder { |
850 | HopperBuilder(RewriterBase &rewriter, Location loc) |
851 | : rewriter(rewriter), loc(loc) {} |
852 | |
853 | TypedValue<nvgpu::MBarrierGroupType> |
854 | buildAndInitBarrierInSharedMemory(OpFoldResult numThreads); |
855 | |
856 | /// Create tma descriptor op to initiate transfer from global to shared |
857 | /// memory. This must be done before the launch op, on the host. |
858 | TypedValue<nvgpu::TensorMapDescriptorType> |
859 | buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref, |
860 | gpu::LaunchOp launchOp); |
861 | |
862 | /// Build a tma load from global memory to shared memory using `barrier` to |
863 | /// synchronize. Return the number of bytes that will be transferred. |
864 | OpFoldResult |
865 | buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc, |
866 | TypedValue<MemRefType> sharedMemref, |
867 | TypedValue<nvgpu::MBarrierGroupType> barrier, |
868 | SmallVectorImpl<Operation *> &loadOps); |
869 | void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierGroupType> barrier, |
870 | ArrayRef<OpFoldResult> sizes); |
871 | |
872 | /// If threadIdx.x == 0 does TMA request + wait, else just wait. |
873 | /// Return the operation that performs the transfer on thread0. |
874 | // TODO: In the future, don't hardcode to thread 0 but elect a leader. |
875 | SmallVector<Operation *> buildPredicateLoadsOnThread0( |
876 | ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors, |
877 | ArrayRef<TypedValue<MemRefType>> sharedMemBuffers, |
878 | TypedValue<nvgpu::MBarrierGroupType> barrier); |
879 | |
880 | void buildTryWaitParity(TypedValue<nvgpu::MBarrierGroupType> barrier); |
881 | |
882 | RewriterBase &rewriter; |
883 | Location loc; |
884 | }; |
885 | |
886 | SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0( |
887 | ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors, |
888 | ArrayRef<TypedValue<MemRefType>> sharedMemBuffers, |
889 | TypedValue<nvgpu::MBarrierGroupType> barrier) { |
890 | SmallVector<Operation *> loadOps; |
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
892 | Value tidx = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x); |
893 | Value cond = |
894 | rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, tidx, zero); |
895 | // clang-format off |
896 | rewriter.create<scf::IfOp>( |
897 | /*location=*/loc, |
898 | /*conditional=*/cond, |
899 | /*thenBuilder=*/ |
900 | [&](OpBuilder &lb, Location loc) { |
901 | SmallVector<OpFoldResult> sizes; |
        sizes.reserve(globalDescriptors.size());
903 | for (auto [desc, shmem] : llvm::zip_equal( |
904 | globalDescriptors, sharedMemBuffers)) { |
905 | OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps); |
906 | sizes.push_back(sz); |
907 | } |
908 | // TODO: Note that cutlass predeclares the barrier arrive tx before the tma.async.load. |
909 | // This may or may not have perf implications. |
910 | buildBarrierArriveTx(barrier, sizes); |
911 | rewriter.create<scf::YieldOp>(loc); |
912 | }, |
913 | /*elseBuilder=*/ |
914 | [&](OpBuilder &lb, Location loc) { |
915 | // TODO: is this for no-thread divergence? |
916 | // Should we just yield the size and hoist? |
        buildBarrierArriveTx(barrier, getAsIndexOpFoldResult(rewriter.getContext(), 0));
918 | rewriter.create<scf::YieldOp>(loc); |
919 | }); |
920 | // clang-format on |
921 | return loadOps; |
922 | } |
923 | |
924 | static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) { |
925 | return gpu::AddressSpaceAttr::get( |
926 | b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace()); |
927 | // return b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace)); |
928 | } |
929 | |
930 | TypedValue<nvgpu::MBarrierGroupType> |
931 | HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) { |
  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
933 | Value barrier = rewriter.create<nvgpu::MBarrierCreateOp>( |
934 | loc, |
935 | nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace)); |
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
937 | rewriter.create<nvgpu::MBarrierInitOp>( |
938 | loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads), |
939 | zero, Value()); |
940 | rewriter.create<gpu::BarrierOp>(loc); |
  return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier);
942 | } |
943 | |
944 | TypedValue<nvgpu::TensorMapDescriptorType> |
945 | HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref, |
946 | gpu::LaunchOp launchOp) { |
947 | OpBuilder::InsertionGuard guard(rewriter); |
948 | rewriter.setInsertionPoint(launchOp); |
949 | Value unrankedMemRef = rewriter.create<memref::CastOp>( |
950 | loc, |
951 | UnrankedMemRefType::get(memref.getType().getElementType(), |
952 | memref.getType().getMemorySpace()), |
953 | memref); |
  SmallVector<OpFoldResult> mixedSizes =
      memref::getMixedSizes(rewriter, loc, memref);
  SmallVector<Value> sizes =
      getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);

  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
960 | Value desc = rewriter.create<nvgpu::TmaCreateDescriptorOp>( |
961 | loc, |
962 | nvgpu::TensorMapDescriptorType::get( |
963 | rewriter.getContext(), |
964 | MemRefType::Builder(memref.getType()) |
965 | .setMemorySpace(sharedMemorySpace), |
966 | TensorMapSwizzleKind::SWIZZLE_NONE, |
967 | TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO, |
968 | TensorMapInterleaveKind::INTERLEAVE_NONE), |
969 | unrankedMemRef, sizes); |
970 | return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc); |
971 | } |
972 | |
973 | OpFoldResult HopperBuilder::buildTmaAsyncLoad( |
974 | TypedValue<nvgpu::TensorMapDescriptorType> globalDesc, |
975 | TypedValue<MemRefType> sharedMemref, |
976 | TypedValue<nvgpu::MBarrierGroupType> barrier, |
977 | SmallVectorImpl<Operation *> &loadOps) { |
978 | MLIRContext *ctx = rewriter.getContext(); |
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Operation *loadOp = rewriter.create<nvgpu::TmaAsyncLoadOp>(
      loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero,
      Value(), Value());
  loadOps.push_back(loadOp);
  auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr prodExprInBytes =
      computeProduct(ctx, symbols) *
      (sharedMemref.getType().getElementTypeBitWidth() / 8);
  auto res = affine::makeComposedFoldedAffineApply(rewriter, loc,
                                                   prodExprInBytes, mixedSizes);
992 | return res; |
993 | } |
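
// Worked example of the size computation above: for a shared memref of type
// memref<64x128xf16, #gpu.address_space<workgroup>>, the folded affine product
// is 64 * 128 * (16 / 8) = 16384 bytes expected at the barrier.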
994 | |
995 | void HopperBuilder::buildBarrierArriveTx( |
996 | TypedValue<nvgpu::MBarrierGroupType> barrier, |
997 | ArrayRef<OpFoldResult> mixedSizes) { |
  assert(!mixedSizes.empty() && "expected non-empty sizes");
  MLIRContext *ctx = rewriter.getContext();
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr sumExpr = computeSum(ctx, symbols);
  OpFoldResult size =
      affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
  Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1007 | rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal, zero, |
1008 | Value()); |
1009 | } |
1010 | |
1011 | void HopperBuilder::buildTryWaitParity( |
1012 | TypedValue<nvgpu::MBarrierGroupType> barrier) { |
1013 | Type i1 = rewriter.getI1Type(); |
1014 | Value parity = rewriter.create<LLVM::ConstantOp>(loc, i1, 0); |
1015 | // 10M is an arbitrary, not too small or too big number to specify the number |
1016 | // of ticks before retry. |
1017 | // TODO: hoist this in a default dialect constant. |
  Value ticksBeforeRetry =
      rewriter.create<arith::ConstantIndexOp>(loc, 10000000);
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1021 | rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity, |
1022 | ticksBeforeRetry, zero); |
1023 | } |
1024 | |
1025 | //===----------------------------------------------------------------------===// |
1026 | // RewriteCopyAsTmaOp |
1027 | //===----------------------------------------------------------------------===// |
1028 | |
1029 | /// Helper to create the tma operations corresponding to `linalg::CopyOp`. |
1030 | struct CopyBuilder : public HopperBuilder { |
1031 | CopyBuilder(RewriterBase &rewriter, Location loc) |
1032 | : HopperBuilder(rewriter, loc) {} |
1033 | |
1034 | SmallVector<Operation *> rewrite(ArrayRef<Operation *> copyOps); |
1035 | }; |
1036 | |
1037 | SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) { |
1038 | MLIRContext *ctx = rewriter.getContext(); |
1039 | if (copyOps.empty()) |
1040 | return SmallVector<Operation *>(); |
1041 | |
1042 | auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>(); |
  assert(launchOp && "expected launch op");
1044 | |
1045 | // 1. Init a barrier object in shared memory. |
1046 | OpBuilder::InsertionGuard g(rewriter); |
1047 | rewriter.setInsertionPoint(copyOps.front()); |
1048 | AffineExpr bx, by, bz; |
  bindSymbols(ctx, bx, by, bz);
  AffineExpr prod = computeProduct(ctx, ArrayRef<AffineExpr>{bx, by, bz});
  OpFoldResult numThreads = affine::makeComposedFoldedAffineApply(
      rewriter, loc, prod,
      ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
                             launchOp.getBlockSizeZ()});
1055 | |
1056 | TypedValue<nvgpu::MBarrierGroupType> barrier = |
1057 | buildAndInitBarrierInSharedMemory(numThreads); |
1058 | |
1059 | SmallVector<TypedValue<MemRefType>> shmems; |
1060 | SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs; |
1061 | for (Operation *op : copyOps) { |
1062 | auto copyOp = cast<linalg::CopyOp>(op); |
1063 | auto inMemRef = |
1064 | cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get()); |
    assert(inMemRef.getType().getRank() == 2 &&
           "expected in to be a 2D memref");
1067 | |
1068 | // 2. Build global memory descriptor. |
1069 | TypedValue<nvgpu::TensorMapDescriptorType> globalDesc = |
1070 | buildGlobalMemRefDescriptor(inMemRef, launchOp); |
1071 | globalDescs.push_back(globalDesc); |
1072 | |
1073 | // 3. Shared memory and descriptor for the tmp array. |
1074 | auto shmem = |
1075 | cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get()); |
    shmems.push_back(shmem);
1077 | } |
1078 | |
1079 | // 4. Load in from global memory to shared memory using tma. |
1080 | OpBuilder::InsertionGuard g2(rewriter); |
1081 | rewriter.setInsertionPoint(copyOps.front()); |
1082 | SmallVector<Operation *> results = |
1083 | buildPredicateLoadsOnThread0(globalDescs, shmems, barrier); |
1084 | |
1085 | // 5. Spin-loop until data is ready. |
1086 | buildTryWaitParity(barrier); |
1087 | |
1088 | // 6. Erase the ops that have now been rewritten. |
1089 | for (Operation *op : copyOps) |
1090 | rewriter.eraseOp(op); |
1091 | |
1092 | return results; |
1093 | } |
1094 | |
1095 | DiagnosedSilenceableFailure |
1096 | transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter, |
1097 | transform::TransformResults &results, |
1098 | transform::TransformState &state) { |
1099 | auto payloadOps = state.getPayloadOps(getTarget()); |
1100 | gpu::LaunchOp commonLaunchOp; |
1101 | Operation *firstOp, *failingOp; |
1102 | if (llvm::any_of(payloadOps, [&](Operation *op) { |
1103 | if (!commonLaunchOp) { |
1104 | commonLaunchOp = op->getParentOfType<gpu::LaunchOp>(); |
1105 | firstOp = op; |
1106 | } |
1107 | auto fail = !op->getParentOfType<gpu::LaunchOp>() || |
1108 | commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() || |
1109 | !isa<linalg::CopyOp>(op); |
1110 | if (fail) |
1111 | failingOp = op; |
1112 | return fail; |
1113 | })) { |
1114 | DiagnosedSilenceableFailure diag = |
1115 | emitSilenceableError() |
1116 | << "target ops must be linalg::CopyOp nested under a common " |
1117 | "gpu.LaunchOp to be rewritten because the tma descriptors need to " |
1118 | "be created on the host.\nBut got: " |
1119 | << *firstOp << "\nand " << *failingOp; |
1120 | return diag; |
1121 | } |
1122 | |
1123 | // TODO: more robust detection of copy, with transposes etc. |
1124 | CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps)); |
1125 | |
1126 | return DiagnosedSilenceableFailure::success(); |
1127 | } |
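
// Usage sketch from a transform script (the handle type below is an
// assumption; see NVGPUTransformOps.td for the authoritative format):
//
//   transform.nvgpu.rewrite_copy_as_tma %copies : (!transform.any_op) -> ()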
1128 | |
1129 | //===----------------------------------------------------------------------===// |
1130 | // Transform op registration |
1131 | //===----------------------------------------------------------------------===// |
1132 | |
1133 | namespace { |
1134 | class NVGPUTransformDialectExtension |
1135 | : public transform::TransformDialectExtension< |
1136 | NVGPUTransformDialectExtension> { |
1137 | public: |
1138 | NVGPUTransformDialectExtension() { |
1139 | declareGeneratedDialect<arith::ArithDialect>(); |
1140 | declareGeneratedDialect<affine::AffineDialect>(); |
1141 | declareGeneratedDialect<nvgpu::NVGPUDialect>(); |
1142 | declareGeneratedDialect<NVVM::NVVMDialect>(); |
1143 | declareGeneratedDialect<vector::VectorDialect>(); |
1144 | registerTransformOps< |
1145 | #define GET_OP_LIST |
1146 | #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc" |
1147 | >(); |
1148 | } |
1149 | }; |
1150 | } // namespace |
1151 | |
1152 | #define GET_OP_CLASSES |
1153 | #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc" |
1154 | |
1155 | void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry ®istry) { |
1156 | registry.addExtensions<NVGPUTransformDialectExtension>(); |
1157 | } |
1158 | |