//===- NVGPUTransformOps.cpp - Implementation of NVGPU transform ops -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::nvgpu;
using namespace mlir::NVVM;
using namespace mlir::transform;

#define DEBUG_TYPE "nvgpu-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")

//===----------------------------------------------------------------------===//
// Apply...ConversionPatternsOp
//===----------------------------------------------------------------------===//

void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  /// device-side async tokens cannot be materialized in nvvm. We just
  /// convert them to a dummy i32 type in order to easily drop them during
  /// conversion.
  populateGpuMemorySpaceAttributeConversions(
      llvmTypeConverter, [](gpu::AddressSpace space) -> unsigned {
        switch (space) {
        case gpu::AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case gpu::AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case gpu::AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return 0;
      });
  llvmTypeConverter.addConversion(
      [&](nvgpu::DeviceAsyncTokenType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 32));
      });
  llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
    return llvmTypeConverter.convertType(
        IntegerType::get(type.getContext(), 64));
  });
  llvmTypeConverter.addConversion(
      [&](nvgpu::WarpgroupAccumulatorType type) -> Type {
        Type elemType = type.getFragmented().getElementType();
        int64_t sizeM = type.getFragmented().getDimSize(0);
        int64_t sizeN = type.getFragmented().getDimSize(1);

        unsigned numMembers;
        if (elemType.isF32() || elemType.isInteger(32))
          numMembers = sizeN / 2;
        else if (elemType.isF16())
          numMembers = sizeN / 4;
        else
          llvm_unreachable("unsupported type for warpgroup accumulator");

        SmallVector<Type> innerStructBody;
        for (unsigned i = 0; i < numMembers; i++)
          innerStructBody.push_back(elemType);
        auto innerStructType = LLVM::LLVMStructType::getLiteral(
            type.getContext(), innerStructBody);

        SmallVector<Type> structBody;
        for (int i = 0; i < sizeM; i += kWgmmaSizeM)
          structBody.push_back(innerStructType);

        auto convertedType =
            LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
        return llvmTypeConverter.convertType(convertedType);
      });
  llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
    return llvmTypeConverter.convertType(
        getMBarrierMemrefType(type.getContext(), type));
  });
  llvmTypeConverter.addConversion(
      [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 64));
      });
  llvmTypeConverter.addConversion(
      [&](nvgpu::TensorMapDescriptorType type) -> Type {
        return LLVM::LLVMPointerType::get(type.getContext());
      });
  populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns);
}

LogicalResult
transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

//===----------------------------------------------------------------------===//
// CreateAsyncGroupsOp
//===----------------------------------------------------------------------===//

void transform::CreateAsyncGroupsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::consumesHandle(getTargetMutable(), effects);
  transform::producesHandle(getOperation()->getOpResults(), effects);
  transform::modifiesPayload(effects);
}

DiagnosedSilenceableFailure transform::CreateAsyncGroupsOp::applyToOne(
    TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, TransformState &state) {
  nvgpu::createAsyncGroups(rewriter, target, getBypassL1());
  results.push_back(target);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// PipelineSharedMemoryCopiesOp
//===----------------------------------------------------------------------===//

/// Returns true if the given type has the default memory space.
static bool hasDefaultMemorySpace(BaseMemRefType type) {
  return !type.getMemorySpace() || type.getMemorySpaceAsInt() == 0;
}

/// Returns true if the given type has the shared (workgroup) memory space.
static bool hasSharedMemorySpace(BaseMemRefType type) {
  auto space =
      dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
  return space &&
         space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
}

/// Returns the value produced by a load from the default memory space. Returns
/// null if the operation is not such a load.
static Value getValueLoadedFromGlobal(Operation *op) {
  // TODO: consider an interface or leveraging the memory effects interface.
  auto load = dyn_cast<vector::TransferReadOp>(op);
  if (!load)
    return nullptr;

  auto loadType = dyn_cast<MemRefType>(load.getBase().getType());
  if (!loadType || !hasDefaultMemorySpace(loadType))
    return nullptr;
  return load;
}

/// Returns true if the operation is storing the given value into shared
/// memory.
static bool isStoreToShared(Operation *op, Value v) {
  // TODO: consider an interface or leveraging the memory effects interface.
  auto store = dyn_cast<vector::TransferWriteOp>(op);
  if (!store || store.getVector() != v)
    return false;

  auto storeType = dyn_cast<MemRefType>(store.getBase().getType());
  return storeType && hasSharedMemorySpace(storeType);
}

/// Returns true if the operation is a load from the default memory space the
/// result of which is only stored into the shared memory space.
static bool isLoadFromGlobalStoredToShared(Operation *op) {
  Value loaded = getValueLoadedFromGlobal(op);
  if (!loaded || !loaded.hasOneUse())
    return false;

  return isStoreToShared(*loaded.getUsers().begin(), loaded);
}

/// Populate `ops` with the set of operations that belong to the stage 0 of
/// the pipelined version of the given loop when pipelining copies to shared
/// memory. Specifically, this collects:
///
///   1. all loads from global memory, both sync and async;
///   2. the barriers for async loads.
///
/// In particular, barriers are omitted if they do not dominate at least one
/// async load for which there is not yet a barrier.
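///
/// Illustrative sketch (hypothetical IR, not taken from a test): given a loop
/// body of the form
///   gpu.barrier
///   nvgpu.device_async_copy ...            // async global -> shared copy
///   %v = vector.transfer_read %global ...  // global -> registers
///   vector.transfer_write %v, %shared ...  // registers -> shared
/// the async copy, the barrier that precedes it, and the transfer_read whose
/// only use is the store to shared memory all end up in `ops`.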
static LogicalResult
collectStage0PipeliningOps(scf::ForOp forOp,
                           llvm::SmallPtrSet<Operation *, 16> &ops) {

  llvm::SmallPtrSet<Operation *, 4> barriers;
  for (Operation &op : *forOp.getBody()) {
    // Bail on nested ops for now.
    if (op.getNumRegions() > 0)
      return failure();

    if (isa<gpu::BarrierOp>(op)) {
      barriers.insert(&op);
      continue;
    }

    if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) {
      ops.insert(&op);
      ops.insert(std::make_move_iterator(barriers.begin()),
                 std::make_move_iterator(barriers.end()));
      assert(barriers.empty() &&
             "expected to have moved the barriers into another set");
      continue;
    }

    if (isLoadFromGlobalStoredToShared(&op)) {
      ops.insert(&op);
      continue;
    }
  }

  return success();
}

/// Hook for the loop pipeliner that sets the "num groups in flight" attribute
/// of async wait operations corresponding to pipelined shared memory copies.
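///
/// Worked example (assuming a pipeline depth of 3 and waits whose `numGroups`
/// is not already set): waits in the prologue and in the kernel are given
/// 3 - 1 = 2 groups in flight, while epilogue waits are given
/// depth - 1 - iteration groups, i.e. 2, 1 and 0 for epilogue iterations
/// 0, 1 and 2.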
// TODO: this currently assumes that there are no groups that could be in
// flight in the existing code.
static void
setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration, unsigned depth) {
  // Based on the order of copies within the loop we need to set the number
  // of copies in flight, unless it is already set.
  auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op);
  if (!waitOp || waitOp.getNumGroups())
    return;

  int numGroupInFlight = 0;
  if (part == scf::PipeliningOption::PipelinerPart::Kernel ||
      part == scf::PipeliningOption::PipelinerPart::Prologue) {
    numGroupInFlight = depth - 1;
  } else {
    // By construction there should be no wait op in the prologue as all the
    // waits should be in the last stage.
    assert(part == scf::PipeliningOption::PipelinerPart::Epilogue);
    // Based on the schedule we pick we know how many groups are in flight for
    // each iteration of the epilogue.
    numGroupInFlight = depth - 1 - iteration;
  }
  waitOp.setNumGroups(numGroupInFlight);
}

/// Hook for the loop pipeliner that populates `ops` with the stage information
/// as follows:
///
///   - operations in `stage0Ops` (typically loads from global memory and
///     related barriers) are at stage 0;
///   - operations in the backward slice of any stage0Ops are all at stage 0;
///   - other operations are at stage `depth`;
///   - the internal order of the pipelined loop has ops at stage `depth`
///     first, then those at stage 0, with relative order within each group
///     preserved.
///
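/// Illustrative sketch: with `depth` == 3, the ops in `stage0Ops` and their
/// in-loop backward slices are assigned stage 0, while every other op except
/// the scf.yield terminator is assigned stage 3 and emitted before the
/// stage-0 ops in the pipelined loop body.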
static void getPipelineStages(
    scf::ForOp forOp,
    std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
    unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) {
  SetVector<Operation *> dependencies;
  BackwardSliceOptions options([&](Operation *visited) {
    return visited->getBlock() == forOp.getBody();
  });
  options.inclusive = true;
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (stage0Ops.contains(&op)) {
      LogicalResult result = getBackwardSlice(&op, &dependencies, options);
      assert(result.succeeded() && "expected a backward slice");
      (void)result;
    }
  }

  for (Operation &op : forOp.getBody()->getOperations()) {
    if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))
      opsWithPipelineStages.emplace_back(&op, depth);
  }
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (dependencies.contains(&op))
      opsWithPipelineStages.emplace_back(&op, 0);
  }
}

/// Hook for the loop pipeliner. Replaces op with a predicated version and
/// returns the resulting operation. Returns the original op if the predication
/// isn't necessary for the given op. Returns null if predication is needed but
/// not supported.
static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
                                            Operation *op, Value predicate) {
  // Some operations may be fine to execute "speculatively" more times than the
  // original number of iterations, in particular side-effect free operations
  // and barriers, even if they cannot be predicated.
  if (isMemoryEffectFree(op) ||
      isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp,
          nvgpu::DeviceAsyncWaitOp>(op)) {
    return op;
  }

  // Otherwise, only async copies can currently be predicated.
  auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op);
  if (!asyncCopyOp)
    return nullptr;

  // Create srcElement Value based on `predicate`. The next lines generate
  // the following code:
  //
  //   srcElement = (pred) ? prevSrcElements : 0;
  //
  Location loc = asyncCopyOp->getLoc();
  Value dstElements =
      rewriter.create<arith::ConstantOp>(loc, asyncCopyOp.getDstElementsAttr());
  Value originalSrcElement =
      asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
  Value c0Index = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto srcElements = rewriter.create<arith::SelectOp>(
      loc, predicate, originalSrcElement, c0Index);
  auto asyncCopyZeroFillOp = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
      loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
      asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
      asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
      UnitAttr());
  rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
  return asyncCopyZeroFillOp;
}

/// Applies loop pipelining with the given depth to the given loop so that
/// copies into the shared memory are pipelined. Doesn't affect other loops.
/// Returns a pair containing the error state and the pipelined op, the latter
/// being null in case of any failure. The error state contains a definite
/// error if the IR has been modified and a silenceable error otherwise.
static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,
                        bool epiloguePeeling) {
  llvm::SmallPtrSet<Operation *, 16> stage0Ops;
  if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"),
        scf::ForOp());
  }
  if (stage0Ops.empty()) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp());
  }

  scf::PipeliningOption options;
  unsigned maxDepth = depth;
  auto setAnnotation = [&](Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration) {
    return setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth);
  };
  options.getScheduleFn =
      [&](scf::ForOp schedulingFor,
          std::vector<std::pair<Operation *, unsigned>> &ops) {
        if (schedulingFor != forOp)
          return;
        return getPipelineStages(forOp, ops, maxDepth, stage0Ops);
      };
  options.annotateFn = setAnnotation;
  if (!epiloguePeeling) {
    options.peelEpilogue = false;
    options.predicateFn = replaceOpWithPredicatedOp;
  }

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forOp);
  bool modifiedIR;
  FailureOr<scf::ForOp> maybePipelined =
      pipelineForLoop(rewriter, forOp, options, &modifiedIR);
  if (succeeded(maybePipelined)) {
    return std::make_tuple(DiagnosedSilenceableFailure::success(),
                           *maybePipelined);
  }
  return std::make_tuple(
      modifiedIR
          ? DiagnosedSilenceableFailure::definiteFailure()
          : emitSilenceableFailure(forOp, "pipelining preconditions failed"),
      scf::ForOp());
}

DiagnosedSilenceableFailure PipelineSharedMemoryCopiesOp::applyToOne(
    TransformRewriter &rewriter, scf::ForOp forOp,
    ApplyToEachResultList &results, TransformState &state) {
  auto [diag, pipelined] = pipelineForSharedCopies(
      rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());
  if (diag.succeeded()) {
    results.push_back(pipelined);
    return DiagnosedSilenceableFailure::success();
  }
  if (diag.isDefiniteFailure()) {
    auto diag = emitDefiniteFailure("irreversible pipelining failure");
    if (!getPeelEpilogue()) {
      diag.attachNote(forOp->getLoc()) << "couldn't predicate?";
      diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName();
    }
    return diag;
  }

  return std::move(diag);
}

//===----------------------------------------------------------------------===//
// RewriteMatmulAsMmaSyncOp
//===----------------------------------------------------------------------===//

/// Helper struct to encode a pair of row/column indexings in the form of
/// affine expressions.
struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {
  RowColIndexing(AffineExpr row, AffineExpr col)
      : std::pair<AffineExpr, AffineExpr>(row, col) {}

  AffineExpr row() const { return first; };
  AffineExpr col() const { return second; };

  void print(llvm::raw_ostream &os) const {
    os << "- indexing: " << first << ", " << second;
  }
};

/// Helper struct to provide a simple mapping from matmul operations to the
/// corresponding mma.sync operation. This is constrained to the case where the
/// matmul matches the mma.sync operation 1-1.
struct MmaSyncBuilder {
  MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
      : b(b), loc(loc), laneId(laneId) {}

  using IndexCalculator =
      std::function<SmallVector<RowColIndexing>(MLIRContext *)>;

  /// Create the mma.sync operation corresponding to `linalgOp` along with all
  /// the supporting load/store and vector operations.
  FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);

private:
  struct MmaSyncInfo {
    std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
    std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
               SmallVector<int64_t>>
        vectorShapes;
    SmallVector<int64_t> mmaShape;
    bool tf32Enabled;
  };

  /// Return the specific index calculator for the given `linalgOp` or failure
  /// if the op is not supported. This is the toplevel switch that should just
  /// be Tablegen'd in the future.
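  ///
  /// For example (illustrative, derived from the cases handled below): a
  /// 16x8x4 matmul on f32 operands maps to the m16n8k4 tf32 variant with
  /// per-lane vector shapes 2x1 (LHS), 1x1 (RHS) and 2x2 (accumulator), and
  /// a 16x8x16 matmul on f16 operands maps to the m16n8k16 f16 variant.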
  FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,
                                             TypeRange elementalTypes);

  //===--------------------------------------------------------------------===//
  // Instruction-specific row, column indexing expression builders.
  // These should all be declaratively specified via Tablegen in the future.
  // The Tablegen specification should be as straightforward as possible to
  // only model the existing size and type combinations.
  //===--------------------------------------------------------------------===//
  //
  // TODO: Tablegen all this.
  //===--------------------------------------------------------------------===//
  // m16n8k4 tf32 case.
  //===--------------------------------------------------------------------===//
  /// From the NVIDIA doc:
  ///   groupID = %laneid >> 2
  ///   threadIDInGroup = %laneid % 4
  ///   row = groupID       for a0
  ///         groupID + 8   for a1
  ///   col = threadIDInGroup
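  ///
  /// Illustrative evaluation (not part of the NVIDIA doc): for %laneid = 5,
  /// groupID = 1 and threadIDInGroup = 1, so a0 comes from (row 1, col 1)
  /// and a1 from (row 9, col 1).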
  static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{groupID, threadIDInGroup},
            RowColIndexing{groupID + 8, threadIDInGroup}};
  }

  /// From the NVIDIA doc:
  ///   groupID = %laneid >> 2
  ///   threadIDInGroup = %laneid % 4
  ///   row = threadIDInGroup
  ///   col = groupID
  static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{threadIDInGroup, groupID}};
  }

  /// From the NVIDIA doc:
  ///   groupID = %laneid >> 2
  ///   threadIDInGroup = %laneid % 4
  ///   row = groupID       for c0 and c1
  ///         groupID + 8   for c2 and c3
  ///   col = (threadIDInGroup * 2) + (i & 0x1)   for ci where i = {0,..,3}
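  ///
  /// Illustrative evaluation (not part of the NVIDIA doc): for %laneid = 5,
  /// c0..c3 live at (1, 2), (1, 3), (9, 2) and (9, 3) respectively.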
  static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID, threadIDInGroup * 2 + 1},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
  }

  //===--------------------------------------------------------------------===//
  // m16n8k16 f16 case.
  //===--------------------------------------------------------------------===//
  /// From the NVIDIA doc:
  ///   groupID = %laneid >> 2
  ///   threadIDInGroup = %laneid % 4
  ///
  ///   row = groupID       for ai where 0 <= i < 2 || 4 <= i < 6
  ///         groupID + 8   otherwise
  ///
  ///   col = (threadIDInGroup * 2) + (i & 0x1)       for ai where i < 4
  ///         (threadIDInGroup * 2) + (i & 0x1) + 8   for ai where i >= 4
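  ///
  /// Illustrative evaluation (not part of the NVIDIA doc): for %laneid = 5,
  /// a0 comes from (row 1, col 2) and a4 from (row 1, col 10); a2 and a6
  /// read the same columns from rows shifted down by 8.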
  static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    // clang-format off
    return {
      RowColIndexing{groupID, threadIDInGroup * 2 + 0},         // i == 0
      RowColIndexing{groupID, threadIDInGroup * 2 + 1},         // i == 1
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},     // i == 2
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1},     // i == 3
      RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8},     // i == 4
      RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8},     // i == 5
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8}, // i == 6
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8}  // i == 7
    };
    // clang-format on
  }

  /// From the NVIDIA doc:
  ///   groupID = %laneid >> 2
  ///   threadIDInGroup = %laneid % 4
  ///
  ///   row = (threadIDInGroup * 2) + (i & 0x1)       for bi where i < 2
  ///         (threadIDInGroup * 2) + (i & 0x1) + 8   for bi where i >= 2
  ///
  ///   col = groupID
  static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    // clang-format off
    return {
      RowColIndexing{threadIDInGroup * 2 + 0, groupID},     // i == 0
      RowColIndexing{threadIDInGroup * 2 + 1, groupID},     // i == 1
      RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID}, // i == 2
      RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID}  // i == 3
    };
    // clang-format on
  }

  /// From the NVIDIA doc:
  ///   groupID = %laneid >> 2
  ///   threadIDInGroup = %laneid % 4
  ///
  ///   row = groupID       for ci where i < 2
  ///         groupID + 8   for ci where i >= 2
  ///
  ///   col = (threadIDInGroup * 2) + (i & 0x1)   for ci where i = {0,..,3}
  static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    // clang-format off
    return {
      RowColIndexing{groupID, threadIDInGroup * 2 + 0},     // i == 0
      RowColIndexing{groupID, threadIDInGroup * 2 + 1},     // i == 1
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}  // i == 3
    };
    // clang-format on
  }

  //===--------------------------------------------------------------------===//
  /// Helper functions to create customizable load and store operations. The
  /// specific shapes of each MMA instruction are passed via the
  /// IndexCalculator callback.
  //===--------------------------------------------------------------------===//

  /// Build a list of memref.load operations indexed at `(row, col)` indices
  /// that make sense for a particular MMA instruction and specified via the
  /// IndexCalculator callback.
  SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,
                                      OpFoldResult laneId, Value memref,
                                      const IndexCalculator &indexFn);

  /// Perform a distributed load of a vector operand of `vectorShape` for a
  /// particular MMA instruction whose `(row, col)` indices are specified via
  /// the IndexCalculator callback. Each `laneId` loads the subportion of the
  /// data that makes sense for the particular MMA operation.
  /// The `vectorShape` matches existing NVGPU dialect op specification but
  /// could also be flattened in the future if needed for simplification.
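  ///
  /// For instance (illustrative, following the m16n8k4 tf32 LHS mapping
  /// above), each lane loads the two f32 elements it owns and packs them
  /// into a vector<2x1xf32> before they are fed to nvgpu.mma.sync.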
  Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,
                                      OpFoldResult laneId, Value memref,
                                      IndexCalculator indexFn,
                                      ArrayRef<int64_t> vectorShape);

  /// Build a list of memref.store operations indexed at `(row, col)` indices
  /// that make sense for a particular MMA instruction and specified via the
  /// IndexCalculator callback.
  SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,
                                             ValueRange toStore,
                                             OpFoldResult laneId, Value memref,
                                             const IndexCalculator &indexFn);

  /// Perform a distributed store of a vector operand of `vectorShape` for a
  /// particular MMA instruction whose `(row, col)` indices are specified via
  /// the IndexCalculator callback. Each `laneId` stores the subportion of the
  /// data that makes sense for the particular MMA operation.
  /// The `vectorShape` matches existing NVGPU dialect op specification but
  /// could also be flattened in the future if needed for simplification.
  SmallVector<Operation *> buildMmaSyncMemRefStoreOperand(
      OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
      Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);

  OpBuilder &b;
  Location loc;
  OpFoldResult laneId;
};

//===--------------------------------------------------------------------===//
/// Helper functions to create customizable load and store operations. The
/// specific shapes of each MMA instruction are passed via the
/// IndexCalculator callback.
//===--------------------------------------------------------------------===//

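/// Apply `applyFn` to each scalar position of `vector` (in row-major order)
/// and feed the result, together with the linear and delinearized indices, to
/// `reduceFn`. Illustrative example (assuming a vector<2x2xf32> operand): the
/// strides are [2, 1] and the callbacks are invoked for linear indices
/// 0, 1, 2, 3, i.e. positions (0, 0), (0, 1), (1, 0) and (1, 1).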
template <typename ApplyFn, typename ReduceFn>
static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
                                           ReduceFn reduceFn) {
  VectorType vectorType = cast<VectorType>(vector.getType());
  auto vectorShape = vectorType.getShape();
  auto strides = computeStrides(vectorShape);
  for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
    auto indices = delinearize(idx, strides);
    reduceFn(applyFn(vector, idx, indices), idx, indices);
  }
}

SmallVector<Value>
MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
                                 OpFoldResult laneId, Value memref,
                                 const IndexCalculator &indexFn) {
  auto aff = [&](AffineExpr e) {
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Value> res;
  SmallVector<RowColIndexing> indexings = indexFn(b.getContext());
  for (auto indexing : indexings) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    auto load = b.create<memref::LoadOp>(loc, memref, ValueRange{row, col});
    res.push_back(load);
  }
  return res;
}

Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
    IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));

  Type elementType = getElementTypeOrSelf(memref.getType());
  auto vt = VectorType::get(vectorShape, elementType);
  Value res = b.create<vector::SplatOp>(loc, vt, loads[0]);
  foreachIndividualVectorElement(
      res,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return loads[linearIdx];
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        res = b.create<vector::InsertOp>(loc, v, res, indices);
      });

  return res;
}

SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores(
    OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId,
    Value memref, const IndexCalculator &indexFn) {
  auto aff = [&](AffineExpr e) {
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Operation *> res;
  for (auto [indexing, val] :
       llvm::zip_equal(indexFn(b.getContext()), toStore)) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    Operation *store =
        b.create<memref::StoreOp>(loc, val, memref, ValueRange{row, col});
    res.push_back(store);
  }
  return res;
}

SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
    OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
    Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  SmallVector<Value> toStore;
  toStore.reserve(32);
  foreachIndividualVectorElement(
      vectorToStore,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return b.create<vector::ExtractOp>(loc, vectorToStore, indices);
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        toStore.push_back(v);
      });
  return buildMemRefStores(b, loc, toStore, laneId, memref, std::move(indexFn));
}

static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
                  SmallVector<int64_t>>
makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
                 ArrayRef<int64_t> res) {
  SmallVector<int64_t> vlhs(lhs);
  SmallVector<int64_t> vrhs(rhs);
  SmallVector<int64_t> vres(res);
  return std::make_tuple(vlhs, vrhs, vres);
}

FailureOr<MmaSyncBuilder::MmaSyncInfo>
MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
                                    TypeRange elementalTypes) {
  // TODO: Tablegen all this.
  Type f16 = b.getF16Type();
  Type f32 = b.getF32Type();
  if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
      elementalTypes == TypeRange{f32, f32, f32}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
                                       &MmaSyncBuilder::m16n8k4tf32Rhs,
                                       &MmaSyncBuilder::m16n8k4tf32Res),
                       makeVectorShapes({2, 1}, {1, 1}, {2, 2}),
                       SmallVector<int64_t>{opShape},
                       /*tf32Enabled=*/true};
  }
  // This is the version with f16 accumulation.
  // TODO: version with f32 accumulation.
  if (opShape == ArrayRef<int64_t>{16, 8, 16} &&
      elementalTypes == TypeRange{f16, f16, f16}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
                                       &MmaSyncBuilder::m16n8k16f16Rhs,
                                       &MmaSyncBuilder::m16n8k16f16Res),
                       makeVectorShapes({4, 2}, {2, 2}, {2, 2}),
                       SmallVector<int64_t>{opShape},
                       /*tf32Enabled=*/false};
  }
  return failure();
}

FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
  Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
  Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
  Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
  assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&
         "expected lhs to be a 2D memref");
  assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&
         "expected rhs to be a 2D memref");
  assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&
         "expected res to be a 2D memref");

  int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
  int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
  int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];
  Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());
  Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());
  Type resType = getElementTypeOrSelf(resMemRef.getType());

  FailureOr<MmaSyncInfo> maybeInfo =
      getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
  if (failed(maybeInfo))
    return failure();

  MmaSyncInfo info = *maybeInfo;
  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
                                            lhsIndexFn, lhsShape);
  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
                                            rhsIndexFn, rhsShape);
  Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
                                            resIndexFn, resShape);
  res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape,
                                   info.tf32Enabled);
  buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
                                 resShape);
  return res.getDefiningOp();
}

DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
    transform::TransformRewriter &rewriter, LinalgOp linalgOp,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  bool fail = true;
  // TODO: more robust detection of matmulOp, with transposes etc.
  if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
    // Do not let matmul ops with extended semantics (user-defined indexing
    // maps) through this transform.
    if (linalgOp.hasUserDefinedMaps()) {
      return emitSilenceableError()
             << "only matmul ops with non-extended semantics are supported";
    }
    Location loc = linalgOp.getLoc();
    // TODO: more robust computation of laneId, for now assume a single warp.
    Value laneId = rewriter.create<gpu::ThreadIdOp>(
        loc, rewriter.getIndexType(), gpu::Dimension::x);
    if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
      fail = false;
  }

  if (fail) {
    DiagnosedSilenceableFailure diag = emitSilenceableError()
                                       << "unsupported target op: " << linalgOp;
    diag.attachNote(linalgOp->getLoc()) << "target op";
    return diag;
  }

  rewriter.eraseOp(linalgOp);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// Hopper builders.
//===----------------------------------------------------------------------===//

/// Helper to create the base Hopper-specific operations that are reused in
/// various other places.
struct HopperBuilder {
  HopperBuilder(RewriterBase &rewriter, Location loc)
      : rewriter(rewriter), loc(loc) {}

  TypedValue<nvgpu::MBarrierGroupType>
  buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);

  /// Create tma descriptor op to initiate transfer from global to shared
  /// memory. This must be done before the launch op, on the host.
  TypedValue<nvgpu::TensorMapDescriptorType>
  buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                              gpu::LaunchOp launchOp);

  /// Build a tma load from global memory to shared memory using `barrier` to
  /// synchronize. Return the number of bytes that will be transferred.
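  /// For example (illustrative): loading a memref<128x64xf16, #shared> tile
  /// reports 128 * 64 * 2 = 16384 bytes, which is later fed to the
  /// mbarrier.arrive.expect_tx op.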
  OpFoldResult
  buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
                    TypedValue<MemRefType> sharedMemref,
                    TypedValue<nvgpu::MBarrierGroupType> barrier,
                    SmallVectorImpl<Operation *> &loadOps);
  void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierGroupType> barrier,
                            ArrayRef<OpFoldResult> sizes);

  /// If threadIdx.x == 0, perform the TMA request and wait; otherwise just
  /// wait. Return the operations that perform the transfer on thread 0.
  // TODO: In the future, don't hardcode to thread 0 but elect a leader.
  SmallVector<Operation *> buildPredicateLoadsOnThread0(
      ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
      ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
      TypedValue<nvgpu::MBarrierGroupType> barrier);

  void buildTryWaitParity(TypedValue<nvgpu::MBarrierGroupType> barrier);

  RewriterBase &rewriter;
  Location loc;
};

SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
    ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
    ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
    TypedValue<nvgpu::MBarrierGroupType> barrier) {
  SmallVector<Operation *> loadOps;
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value tidx = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
  Value cond =
      rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, tidx, zero);
  // clang-format off
  rewriter.create<scf::IfOp>(
    /*location=*/loc,
    /*conditional=*/cond,
    /*thenBuilder=*/
    [&](OpBuilder &lb, Location loc) {
      SmallVector<OpFoldResult> sizes;
      sizes.reserve(globalDescriptors.size());
      for (auto [desc, shmem] : llvm::zip_equal(
              globalDescriptors, sharedMemBuffers)) {
        OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
        sizes.push_back(sz);
      }
      // TODO: Note that cutlass predeclares the barrier arrive tx before the
      // tma.async.load. This may or may not have perf implications.
      buildBarrierArriveTx(barrier, sizes);
      rewriter.create<scf::YieldOp>(loc);
    },
    /*elseBuilder=*/
    [&](OpBuilder &lb, Location loc) {
      // TODO: is this for no-thread divergence?
      // Should we just yield the size and hoist?
      buildBarrierArriveTx(barrier,
                           getAsIndexOpFoldResult(rewriter.getContext(), 0));
      rewriter.create<scf::YieldOp>(loc);
    });
  // clang-format on
  return loadOps;
}

static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
  return gpu::AddressSpaceAttr::get(
      b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
  // return b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace));
}

TypedValue<nvgpu::MBarrierGroupType>
HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
  Value barrier = rewriter.create<nvgpu::MBarrierCreateOp>(
      loc,
      nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierInitOp>(
      loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads),
      zero, Value());
  rewriter.create<gpu::BarrierOp>(loc);
  return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier);
}

TypedValue<nvgpu::TensorMapDescriptorType>
HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                                           gpu::LaunchOp launchOp) {
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(launchOp);
  Value unrankedMemRef = rewriter.create<memref::CastOp>(
      loc,
      UnrankedMemRefType::get(memref.getType().getElementType(),
                              memref.getType().getMemorySpace()),
      memref);
  SmallVector<OpFoldResult> mixedSizes =
      memref::getMixedSizes(rewriter, loc, memref);
  SmallVector<Value> sizes =
      getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);

  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
  Value desc = rewriter.create<nvgpu::TmaCreateDescriptorOp>(
      loc,
      nvgpu::TensorMapDescriptorType::get(
          rewriter.getContext(),
          MemRefType::Builder(memref.getType())
              .setMemorySpace(sharedMemorySpace),
          TensorMapSwizzleKind::SWIZZLE_NONE,
          TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
          TensorMapInterleaveKind::INTERLEAVE_NONE),
      unrankedMemRef, sizes);
  return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
}

OpFoldResult HopperBuilder::buildTmaAsyncLoad(
    TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
    TypedValue<MemRefType> sharedMemref,
    TypedValue<nvgpu::MBarrierGroupType> barrier,
    SmallVectorImpl<Operation *> &loadOps) {
  MLIRContext *ctx = rewriter.getContext();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Operation *loadOp = rewriter.create<nvgpu::TmaAsyncLoadOp>(
      loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero,
      Value(), Value());
  loadOps.push_back(loadOp);
  auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr prodExprInBytes =
      computeProduct(ctx, symbols) *
      (sharedMemref.getType().getElementTypeBitWidth() / 8);
  auto res = affine::makeComposedFoldedAffineApply(rewriter, loc,
                                                   prodExprInBytes, mixedSizes);
  return res;
}

void HopperBuilder::buildBarrierArriveTx(
    TypedValue<nvgpu::MBarrierGroupType> barrier,
    ArrayRef<OpFoldResult> mixedSizes) {
  assert(!mixedSizes.empty() && "expected non-empty sizes");
  MLIRContext *ctx = rewriter.getContext();
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr sumExpr = computeSum(ctx, symbols);
  OpFoldResult size =
      affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
  Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal, zero,
                                                   Value());
}

void HopperBuilder::buildTryWaitParity(
    TypedValue<nvgpu::MBarrierGroupType> barrier) {
  Type i1 = rewriter.getI1Type();
  Value parity = rewriter.create<LLVM::ConstantOp>(loc, i1, 0);
  // 10M is an arbitrary, not too small or too big number to specify the number
  // of ticks before retry.
  // TODO: hoist this in a default dialect constant.
  Value ticksBeforeRetry =
      rewriter.create<arith::ConstantIndexOp>(loc, 10000000);
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity,
                                                  ticksBeforeRetry, zero);
}

//===----------------------------------------------------------------------===//
// RewriteCopyAsTmaOp
//===----------------------------------------------------------------------===//

/// Helper to create the tma operations corresponding to `linalg::CopyOp`.
struct CopyBuilder : public HopperBuilder {
  CopyBuilder(RewriterBase &rewriter, Location loc)
      : HopperBuilder(rewriter, loc) {}

  SmallVector<Operation *> rewrite(ArrayRef<Operation *> copyOps);
};

SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
  MLIRContext *ctx = rewriter.getContext();
  if (copyOps.empty())
    return SmallVector<Operation *>();

  auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
  assert(launchOp && "expected launch op");

  // 1. Init a barrier object in shared memory.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(copyOps.front());
  AffineExpr bx, by, bz;
  bindSymbols(ctx, bx, by, bz);
  AffineExpr prod = computeProduct(ctx, ArrayRef<AffineExpr>{bx, by, bz});
  OpFoldResult numThreads = affine::makeComposedFoldedAffineApply(
      rewriter, loc, prod,
      ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
                             launchOp.getBlockSizeZ()});

  TypedValue<nvgpu::MBarrierGroupType> barrier =
      buildAndInitBarrierInSharedMemory(numThreads);

  SmallVector<TypedValue<MemRefType>> shmems;
  SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs;
  for (Operation *op : copyOps) {
    auto copyOp = cast<linalg::CopyOp>(op);
    auto inMemRef =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
    assert(inMemRef.getType().getRank() == 2 &&
           "expected in to be a 2D memref");

    // 2. Build global memory descriptor.
    TypedValue<nvgpu::TensorMapDescriptorType> globalDesc =
        buildGlobalMemRefDescriptor(inMemRef, launchOp);
    globalDescs.push_back(globalDesc);

    // 3. Shared memory and descriptor for the tmp array.
    auto shmem =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
    shmems.push_back(shmem);
  }

  // 4. Load in from global memory to shared memory using tma.
  OpBuilder::InsertionGuard g2(rewriter);
  rewriter.setInsertionPoint(copyOps.front());
  SmallVector<Operation *> results =
      buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);

  // 5. Spin-loop until data is ready.
  buildTryWaitParity(barrier);

  // 6. Erase the ops that have now been rewritten.
  for (Operation *op : copyOps)
    rewriter.eraseOp(op);

  return results;
}

DiagnosedSilenceableFailure
transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
                                     transform::TransformResults &results,
                                     transform::TransformState &state) {
  auto payloadOps = state.getPayloadOps(getTarget());
  gpu::LaunchOp commonLaunchOp;
  Operation *firstOp, *failingOp;
  if (llvm::any_of(payloadOps, [&](Operation *op) {
        if (!commonLaunchOp) {
          commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
          firstOp = op;
        }
        auto fail = !op->getParentOfType<gpu::LaunchOp>() ||
                    commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||
                    !isa<linalg::CopyOp>(op);
        if (fail)
          failingOp = op;
        return fail;
      })) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "target ops must be linalg::CopyOp nested under a common "
           "gpu.LaunchOp to be rewritten because the tma descriptors need to "
           "be created on the host.\nBut got: "
        << *firstOp << "\nand " << *failingOp;
    return diag;
  }

  // TODO: more robust detection of copy, with transposes etc.
  CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps));

  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
class NVGPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          NVGPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension)

  NVGPUTransformDialectExtension() {
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<affine::AffineDialect>();
    declareGeneratedDialect<nvgpu::NVGPUDialect>();
    declareGeneratedDialect<NVVM::NVVMDialect>();
    declareGeneratedDialect<vector::VectorDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"

void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<NVGPUTransformDialectExtension>();
}