//===- NVGPUTransformOps.cpp - Implementation of NVGPU transform ops ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::nvgpu;
using namespace mlir::NVVM;
using namespace mlir::transform;

#define DEBUG_TYPE "nvgpu-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")

//===----------------------------------------------------------------------===//
// Apply...ConversionPatternsOp
//===----------------------------------------------------------------------===//

void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  // Device-side async tokens cannot be materialized in NVVM. We just convert
  // them to a dummy i32 type in order to easily drop them during conversion.
  populateGpuMemorySpaceAttributeConversions(
      llvmTypeConverter, [](gpu::AddressSpace space) -> unsigned {
        switch (space) {
        case gpu::AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case gpu::AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case gpu::AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return 0;
      });
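  // With the mapping above, gpu.address_space<global> lowers to LLVM
  // addrspace(1) and gpu.address_space<workgroup> to addrspace(3) (the NVVM
  // memory-space numbering); private memory maps to the default address
  // space 0.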
  llvmTypeConverter.addConversion(
      [&](nvgpu::DeviceAsyncTokenType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 32));
      });
  llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
    return llvmTypeConverter.convertType(
        IntegerType::get(type.getContext(), 64));
  });
  llvmTypeConverter.addConversion(
      [&](nvgpu::WarpgroupAccumulatorType type) -> Type {
        Type elemType = type.getFragmented().getElementType();
        int64_t sizeM = type.getFragmented().getDimSize(0);
        int64_t sizeN = type.getFragmented().getDimSize(1);

        unsigned numMembers;
        if (elemType.isF32() || elemType.isInteger(32))
          numMembers = sizeN / 2;
        else if (elemType.isF16())
          numMembers = sizeN / 4;
        else
          llvm_unreachable("unsupported type for warpgroup accumulator");

        SmallVector<Type> innerStructBody;
        for (unsigned i = 0; i < numMembers; i++)
          innerStructBody.push_back(elemType);
        auto innerStructType = LLVM::LLVMStructType::getLiteral(
            type.getContext(), innerStructBody);

        SmallVector<Type> structBody;
        for (int i = 0; i < sizeM; i += kWgmmaSizeM)
          structBody.push_back(innerStructType);

        auto convertedType =
            LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
        return llvmTypeConverter.convertType(convertedType);
      });
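  // Illustrative example for the rule above (assuming kWgmmaSizeM is 64): a
  // fragmented vector<128x64xf32> accumulator yields numMembers = 64 / 2 = 32,
  // i.e. an inner struct of 32 f32 members, replicated 128 / 64 = 2 times in
  // the outer struct.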
  llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
    return llvmTypeConverter.convertType(
        getMBarrierMemrefType(type.getContext(), type));
  });
  llvmTypeConverter.addConversion(
      [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 64));
      });
  llvmTypeConverter.addConversion(
      [&](nvgpu::TensorMapDescriptorType type) -> Type {
        return LLVM::LLVMPointerType::get(type.getContext());
      });
  populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns);
}

LogicalResult
transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

//===---------------------------------------------------------------------===//
// CreateAsyncGroupsOp
//===---------------------------------------------------------------------===//

void transform::CreateAsyncGroupsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::consumesHandle(getTarget(), effects);
  transform::producesHandle(getResult(), effects);
  transform::modifiesPayload(effects);
}

DiagnosedSilenceableFailure transform::CreateAsyncGroupsOp::applyToOne(
    TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, TransformState &state) {
  nvgpu::createAsyncGroups(rewriter, target, getBypassL1());
  results.push_back(target);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// PipelineSharedMemoryCopiesOp
//===----------------------------------------------------------------------===//

/// Returns true if the given type has the default memory space.
static bool hasDefaultMemorySpace(BaseMemRefType type) {
  return !type.getMemorySpace() || type.getMemorySpaceAsInt() == 0;
}

/// Returns true if the given type has the shared (workgroup) memory space.
static bool hasSharedMemorySpace(BaseMemRefType type) {
  auto space =
      dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
  return space &&
         space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
}

/// Returns the value produced by a load from the default memory space. Returns
/// null if the operation is not such a load.
static Value getValueLoadedFromGlobal(Operation *op) {
  // TODO: consider an interface or leveraging the memory effects interface.
  auto load = dyn_cast<vector::TransferReadOp>(op);
  if (!load)
    return nullptr;

  auto loadType = dyn_cast<MemRefType>(load.getSource().getType());
  if (!loadType || !hasDefaultMemorySpace(loadType))
    return nullptr;
  return load;
}
/// Returns true if the operation is storing the given value into shared
/// memory.
static bool isStoreToShared(Operation *op, Value v) {
  // TODO: consider an interface or leveraging the memory effects interface.
  auto store = dyn_cast<vector::TransferWriteOp>(op);
  if (!store || store.getVector() != v)
    return false;

  auto storeType = dyn_cast<MemRefType>(store.getSource().getType());
  return storeType && hasSharedMemorySpace(storeType);
}

/// Returns true if the operation is a load from the default memory space the
/// result of which is only stored into the shared memory space.
static bool isLoadFromGlobalStoredToShared(Operation *op) {
  Value loaded = getValueLoadedFromGlobal(op);
  if (!loaded || !loaded.hasOneUse())
    return false;

  return isStoreToShared(*loaded.getUsers().begin(), loaded);
}

/// Populate `ops` with the set of operations that belong to the stage 0 of the
/// pipelined version of the given loop when pipelining copies to shared
/// memory. Specifically, this collects:
///
///   1. all loads from global memory, both sync and async;
///   2. the barriers for async loads.
///
/// In particular, barriers are omitted if they do not dominate at least one
/// async load for which there is not yet a barrier.
static LogicalResult
collectStage0PipeliningOps(scf::ForOp forOp,
                           llvm::SmallPtrSet<Operation *, 16> &ops) {

  llvm::SmallPtrSet<Operation *, 4> barriers;
  for (Operation &op : *forOp.getBody()) {
    // Bail on nested ops for now.
    if (op.getNumRegions() > 0)
      return failure();

    if (isa<gpu::BarrierOp>(op)) {
      barriers.insert(&op);
      continue;
    }

    if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) {
      ops.insert(&op);
      ops.insert(std::make_move_iterator(barriers.begin()),
                 std::make_move_iterator(barriers.end()));
      assert(barriers.empty() &&
             "expected to have moved the barriers into another set");
      continue;
    }

    if (isLoadFromGlobalStoredToShared(&op)) {
      ops.insert(&op);
      continue;
    }
  }

  return success();
}
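
// Illustrative example: for a loop body of the form
//   { gpu.barrier; nvgpu.device_async_copy; nvgpu.device_async_create_group;
//     compute ops ... }
// the function above collects the barrier, the async copy, and the group
// creation as stage-0 ops, while the compute ops are left for the pipeliner to
// place later.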

/// Hook for the loop pipeliner that sets the "num groups in flight" attribute
/// of async wait operations corresponding to pipelined shared memory copies.
// TODO: this currently assumes that there are no groups that could be in
// flight in the existing code.
static void
setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration, unsigned depth) {
  // Based on the order of copies within the loop we need to set the number
  // of copies in flight, unless it is already set.
  auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op);
  if (!waitOp || waitOp.getNumGroups())
    return;

  int numGroupInFlight = 0;
  if (part == scf::PipeliningOption::PipelinerPart::Kernel ||
      part == scf::PipeliningOption::PipelinerPart::Prologue) {
    numGroupInFlight = depth - 1;
  } else {
    // By construction there should be no wait op in the prologue as all the
    // waits should be in the last stage.
    assert(part == scf::PipeliningOption::PipelinerPart::Epilogue);
    // Based on the schedule we pick we know how many groups are in flight for
    // each iteration of the epilogue.
    numGroupInFlight = depth - 1 - iteration;
  }
  waitOp.setNumGroups(numGroupInFlight);
}
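
// For instance, with depth == 3 the hook above sets numGroups to 3 - 1 = 2 for
// waits in the prologue and kernel, and to 2 - i for epilogue iteration i, so
// the in-flight copy groups drain one per peeled iteration.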

/// Hook for the loop pipeliner that populates `ops` with the stage information
/// as follows:
///
///   - operations in `stage0Ops` (typically loads from global memory and
///     related barriers) are at stage 0;
///   - operations in the backward slice of any stage0Ops are all at stage 0;
///   - other operations are at stage `depth`;
///   - the internal order of the pipelined loop has ops at stage `depth` first,
///     then those at stage 0, with relative order within each group preserved.
///
static void getPipelineStages(
    scf::ForOp forOp,
    std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
    unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) {
  SetVector<Operation *> dependencies;
  BackwardSliceOptions options([&](Operation *visited) {
    return visited->getBlock() == forOp.getBody();
  });
  options.inclusive = true;
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (stage0Ops.contains(&op))
      getBackwardSlice(&op, &dependencies, options);
  }

  for (Operation &op : forOp.getBody()->getOperations()) {
    if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))
      opsWithPipelineStages.emplace_back(&op, depth);
  }
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (dependencies.contains(&op))
      opsWithPipelineStages.emplace_back(&op, 0);
  }
}
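
// For example, with depth == 1 and a loop body that issues an
// nvgpu.device_async_copy and nvgpu.device_async_create_group followed by a
// wait and compute, the copy, the group creation, and their index arithmetic
// are assigned stage 0, while the wait and the compute are assigned stage 1,
// so iteration i+1's copies are issued while iteration i computes.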

/// Hook for the loop pipeliner. Replaces op with a predicated version and
/// returns the resulting operation. Returns the original op if the predication
/// isn't necessary for the given op. Returns null if predication is needed but
/// not supported.
static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
                                            Operation *op, Value predicate) {
  // Some operations may be fine to execute "speculatively" more times than the
  // original number of iterations, in particular side-effect free operations
  // and barriers, even if they cannot be predicated.
  if (isMemoryEffectFree(op) ||
      isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp,
          nvgpu::DeviceAsyncWaitOp>(op)) {
    return op;
  }

  // Otherwise, only async copies can currently be predicated.
  auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op);
  if (!asyncCopyOp)
    return nullptr;

  // Create srcElement Value based on `predicate`. The next lines generate
  // the following code:
  //
  //   srcElement = (pred) ? prevSrcElements : 0;
  //
  Location loc = asyncCopyOp->getLoc();
  Value dstElements = rewriter.create<arith::ConstantOp>(
      loc, asyncCopyOp.getDstElementsAttr());
  Value originalSrcElement =
      asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
  Value c0Index = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto srcElements = rewriter.create<arith::SelectOp>(
      loc, predicate, originalSrcElement, c0Index);
  auto asyncCopyZeroFillOp = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
      loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
      asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
      asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
      UnitAttr());
  rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
  return asyncCopyZeroFillOp;
}

/// Applies loop pipelining with the given depth to the given loop so that
/// copies into the shared memory are pipelined. Doesn't affect other loops.
/// Returns a pair containing the error state and the pipelined op, the latter
/// being null in case of any failure. The error state contains a definite
/// error if the IR has been modified and a silenceable error otherwise.
static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,
                        bool epiloguePeeling) {
  llvm::SmallPtrSet<Operation *, 16> stage0Ops;
  if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"),
        scf::ForOp());
  }
  if (stage0Ops.empty()) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp());
  }

  scf::PipeliningOption options;
  unsigned maxDepth = depth;
  auto setAnnotation = [&](Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration) {
    return setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth);
  };
  options.getScheduleFn =
      [&](scf::ForOp schedulingFor,
          std::vector<std::pair<Operation *, unsigned>> &ops) {
        if (schedulingFor != forOp)
          return;
        return getPipelineStages(forOp, ops, maxDepth, stage0Ops);
      };
  options.annotateFn = setAnnotation;
  if (!epiloguePeeling) {
    options.peelEpilogue = false;
    options.predicateFn = replaceOpWithPredicatedOp;
  }

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forOp);
  bool modifiedIR;
  FailureOr<scf::ForOp> maybePipelined =
      pipelineForLoop(rewriter, forOp, options, &modifiedIR);
  if (succeeded(maybePipelined)) {
    return std::make_tuple(DiagnosedSilenceableFailure::success(),
                           *maybePipelined);
  }
  return std::make_tuple(
      modifiedIR
          ? DiagnosedSilenceableFailure::definiteFailure()
          : emitSilenceableFailure(forOp, "pipelining preconditions failed"),
      scf::ForOp());
}

DiagnosedSilenceableFailure PipelineSharedMemoryCopiesOp::applyToOne(
    TransformRewriter &rewriter, scf::ForOp forOp,
    ApplyToEachResultList &results, TransformState &state) {
  auto [diag, pipelined] = pipelineForSharedCopies(
      rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());
  if (diag.succeeded()) {
    results.push_back(pipelined);
    return DiagnosedSilenceableFailure::success();
  }
  if (diag.isDefiniteFailure()) {
    auto diag = emitDefiniteFailure("irreversible pipelining failure");
    if (!getPeelEpilogue()) {
      diag.attachNote(forOp->getLoc()) << "couldn't predicate?";
      diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName();
    }
    return diag;
  }

  return std::move(diag);
}
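
// Illustrative transform-dialect usage (assembly format is approximate; see
// NVGPUTransformOps.td and the dialect tests for the authoritative syntax):
//   %pipelined = transform.nvgpu.pipeline_shared_memory_copies
//       failures(suppress) %for {depth = 2}
//       : (!transform.op<"scf.for">) -> !transform.any_op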

//===----------------------------------------------------------------------===//
// RewriteMatmulAsMmaSyncOp
//===----------------------------------------------------------------------===//

/// Helper struct to encode a pair of row/column indexings in the form of
/// affine expressions.
struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {
  RowColIndexing(AffineExpr row, AffineExpr col)
      : std::pair<AffineExpr, AffineExpr>(row, col) {}

  AffineExpr row() const { return first; };
  AffineExpr col() const { return second; };

  void print(llvm::raw_ostream &os) const {
    os << "- indexing: " << first << ", " << second;
  }
};

/// Helper struct to provide a simple mapping from matmul operations to the
/// corresponding mma.sync operation. This is constrained to the case where the
/// matmul matches the mma.sync operation 1-1.
struct MmaSyncBuilder {
  MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
      : b(b), loc(loc), laneId(laneId) {}

  using IndexCalculator =
      std::function<SmallVector<RowColIndexing>(MLIRContext *)>;

  /// Create the mma.sync operation corresponding to `linalgOp` along with all
  /// the supporting load/store and vector operations.
  FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);

private:
  struct MmaSyncInfo {
    std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
    std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, SmallVector<int64_t>>
        vectorShapes;
    SmallVector<int64_t> mmaShape;
    bool tf32Enabled;
  };

  /// Return the specific index calculator for the given `linalgOp` or failure
  /// if the op is not supported. This is the toplevel switch that should just
  /// be Tablegen'd in the future.
  FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,
                                             TypeRange elementalTypes);

  //===--------------------------------------------------------------------===//
  // Instruction-specific row, column indexing expression builders.
  // These should all be declaratively specified via Tablegen in the future.
  // The Tablegen specification should be as straightforward as possible to
  // only model the existing size and type combinations.
  //===--------------------------------------------------------------------===//
  //
  // TODO: Tablegen all this.
  //===--------------------------------------------------------------------===//
  // m16n8k4 tf32 case.
  //===--------------------------------------------------------------------===//
  /// From the NVIDIA doc:
  ///   groupID          = %laneid >> 2
  ///   threadIDInGroup  = %laneid % 4
  ///   row = groupID        for a0
  ///         groupID + 8    for a1
  ///   col = threadIDInGroup
  static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{groupID, threadIDInGroup},
            RowColIndexing{groupID + 8, threadIDInGroup}};
  }
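
  // For example, for %laneid == 5 the indexing above gives groupID = 1 and
  // threadIDInGroup = 1, i.e. this lane holds a0 at (1, 1) and a1 at (9, 1)
  // of the 16x4 tf32 A tile.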

  /// From the NVIDIA doc:
  ///   groupID          = %laneid >> 2
  ///   threadIDInGroup  = %laneid % 4
  ///   row = threadIDInGroup
  ///   col = groupID
  static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{threadIDInGroup, groupID}};
  }

  /// From the NVIDIA doc:
  ///   groupID          = %laneid >> 2
  ///   threadIDInGroup  = %laneid % 4
  ///   row = groupID        for c0 and c1
  ///         groupID + 8    for c2 and c3
  ///   col = (threadIDInGroup * 2) + (i & 0x1)    for ci where i = {0,..,3}
  static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID, threadIDInGroup * 2 + 1},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
  }

  //===--------------------------------------------------------------------===//
  // m16n8k16 f16 case.
  //===--------------------------------------------------------------------===//
  /// From the NVIDIA doc:
  ///   groupID          = %laneid >> 2
  ///   threadIDInGroup  = %laneid % 4
  ///
  ///   row = groupID        for ai where 0 <= i < 2 || 4 <= i < 6
  ///         groupID + 8    otherwise
  ///
  ///   col = (threadIDInGroup * 2) + (i & 0x1)        for ai where i < 4
  ///         (threadIDInGroup * 2) + (i & 0x1) + 8    for ai where i >= 4
  static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    // clang-format off
    return {
      RowColIndexing{groupID, threadIDInGroup * 2 + 0},         // i == 0
      RowColIndexing{groupID, threadIDInGroup * 2 + 1},         // i == 1
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},     // i == 2
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1},     // i == 3
      RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8},     // i == 4
      RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8},     // i == 5
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8}, // i == 6
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8}  // i == 7
    };
    // clang-format on
  }

  /// From the NVIDIA doc:
  ///   groupID          = %laneid >> 2
  ///   threadIDInGroup  = %laneid % 4
  ///
  ///   row = (threadIDInGroup * 2) + (i & 0x1)        for bi where i < 2
  ///         (threadIDInGroup * 2) + (i & 0x1) + 8    for bi where i >= 2
  ///
  ///   col = groupID
  static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    // clang-format off
    return {
      RowColIndexing{threadIDInGroup * 2 + 0, groupID},     // i == 0
      RowColIndexing{threadIDInGroup * 2 + 1, groupID},     // i == 1
      RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID}, // i == 2
      RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID}  // i == 3
    };
    // clang-format on
  }

  /// From the NVIDIA doc:
  ///   groupID          = %laneid >> 2
  ///   threadIDInGroup  = %laneid % 4
  ///
  ///   row = groupID        for ci where i < 2
  ///         groupID + 8    for ci where i >= 2
  ///
  ///   col = (threadIDInGroup * 2) + (i & 0x1)    for ci where i = {0,..,3}
  static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    // clang-format off
    return {
      RowColIndexing{groupID, threadIDInGroup * 2 + 0},     // i == 0
      RowColIndexing{groupID, threadIDInGroup * 2 + 1},     // i == 1
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}  // i == 3
    };
    // clang-format on
  }

  //===--------------------------------------------------------------------===//
  /// Helper functions to create customizable load and store operations. The
  /// specific shapes of each MMA instruction are passed via the
  /// IndexCalculator callback.
  //===--------------------------------------------------------------------===//
  /// Build a list of memref.load operations indexed at `(row, col)` indices
  /// that make sense for a particular MMA instruction and specified via the
  /// IndexCalculator callback.
  SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,
                                      OpFoldResult laneId, Value memref,
                                      const IndexCalculator &indexFn);

  /// Perform a distributed load of a vector operand of `vectorShape` for a
  /// particular MMA instruction whose `(row, col)` indices are specified via
  /// the IndexCalculator callback. Each `laneId` loads the subportion of the
  /// data that makes sense for the particular MMA operation.
  /// The `vectorShape` matches existing NVGPU dialect op specification but
  /// could also be flattened in the future if needed for simplification.
  Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,
                                      OpFoldResult laneId, Value memref,
                                      IndexCalculator indexFn,
                                      ArrayRef<int64_t> vectorShape);

  /// Build a list of memref.store operations indexed at `(row, col)` indices
  /// that make sense for a particular MMA instruction and specified via the
  /// IndexCalculator callback.
  SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,
                                             ValueRange toStore,
                                             OpFoldResult laneId, Value memref,
                                             const IndexCalculator &indexFn);

  /// Perform a distributed store of a vector operand of `vectorShape` for a
  /// particular MMA instruction whose `(row, col)` indices are specified via
  /// the IndexCalculator callback. Each `laneId` stores the subportion of the
  /// data that makes sense for the particular MMA operation.
  /// The `vectorShape` matches existing NVGPU dialect op specification but
  /// could also be flattened in the future if needed for simplification.
  SmallVector<Operation *> buildMmaSyncMemRefStoreOperand(
      OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
      Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);

  OpBuilder &b;
  Location loc;
  OpFoldResult laneId;
};

//===--------------------------------------------------------------------===//
/// Helper functions to create customizable load and store operations. The
/// specific shapes of each MMA instruction are passed via the
/// IndexCalculator callback.
//===--------------------------------------------------------------------===//

template <typename ApplyFn, typename ReduceFn>
static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
                                           ReduceFn reduceFn) {
  VectorType vectorType = cast<VectorType>(vector.getType());
  auto vectorShape = vectorType.getShape();
  auto strides = computeStrides(vectorShape);
  for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
    auto indices = delinearize(idx, strides);
    reduceFn(applyFn(vector, idx, indices), idx, indices);
  }
}
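
// For a vector<2x2xT>, computeStrides returns {2, 1}, so the loop above visits
// linear indices 0..3, which delinearize to (0,0), (0,1), (1,0), (1,1); the
// result of applyFn at each position is handed to reduceFn together with the
// linear and multi-dimensional indices.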

SmallVector<Value>
MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
                                 OpFoldResult laneId, Value memref,
                                 const IndexCalculator &indexFn) {
  auto aff = [&](AffineExpr e) {
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Value> res;
  SmallVector<RowColIndexing> indexings = indexFn(b.getContext());
  for (auto indexing : indexings) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    auto load = b.create<memref::LoadOp>(loc, memref, ValueRange{row, col});
    res.push_back(load);
  }
  return res;
}

Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
    IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));

  Type elementType = getElementTypeOrSelf(memref.getType());
  auto vt = VectorType::get(vectorShape, elementType);
  Value res = b.create<vector::SplatOp>(loc, vt, loads[0]);
  foreachIndividualVectorElement(
      res,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return loads[linearIdx];
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        res = b.create<vector::InsertOp>(loc, v, res, indices);
      });

  return res;
}

SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores(
    OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId,
    Value memref, const IndexCalculator &indexFn) {
  auto aff = [&](AffineExpr e) {
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Operation *> res;
  for (auto [indexing, val] :
       llvm::zip_equal(indexFn(b.getContext()), toStore)) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    Operation *store =
        b.create<memref::StoreOp>(loc, val, memref, ValueRange{row, col});
    res.push_back(store);
  }
  return res;
}

SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
    OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
    Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  SmallVector<Value> toStore;
  toStore.reserve(32);
  foreachIndividualVectorElement(
      vectorToStore,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return b.create<vector::ExtractOp>(loc, vectorToStore, indices);
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        toStore.push_back(v);
      });
  return buildMemRefStores(b, loc, toStore, laneId, memref, std::move(indexFn));
}

static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
                  SmallVector<int64_t>>
makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
                 ArrayRef<int64_t> res) {
  SmallVector<int64_t> vlhs{lhs.begin(), lhs.end()};
  SmallVector<int64_t> vrhs{rhs.begin(), rhs.end()};
  SmallVector<int64_t> vres{res.begin(), res.end()};
  return std::make_tuple(vlhs, vrhs, vres);
}

FailureOr<MmaSyncBuilder::MmaSyncInfo>
MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
                                    TypeRange elementalTypes) {
  // TODO: Tablegen all this.
  Type f16 = b.getF16Type();
  Type f32 = b.getF32Type();
  if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
      elementalTypes == TypeRange{f32, f32, f32}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
                                       &MmaSyncBuilder::m16n8k4tf32Rhs,
                                       &MmaSyncBuilder::m16n8k4tf32Res),
                       makeVectorShapes({2, 1}, {1, 1}, {2, 2}),
                       SmallVector<int64_t>{opShape.begin(), opShape.end()},
                       /*tf32Enabled=*/true};
  }
  // This is the version with f16 accumulation.
  // TODO: version with f32 accumulation.
  if (opShape == ArrayRef<int64_t>{16, 8, 16} &&
      elementalTypes == TypeRange{f16, f16, f16}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
                                       &MmaSyncBuilder::m16n8k16f16Rhs,
                                       &MmaSyncBuilder::m16n8k16f16Res),
                       makeVectorShapes({4, 2}, {2, 2}, {2, 2}),
                       SmallVector<int64_t>{opShape.begin(), opShape.end()},
                       /*tf32Enabled=*/false};
  }
  return failure();
}

FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
  Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
  Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
  Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
  assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&
         "expected lhs to be a 2D memref");
  assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&
         "expected rhs to be a 2D memref");
  assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&
         "expected res to be a 2D memref");

  int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
  int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
  int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];
  Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());
  Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());
  Type resType = getElementTypeOrSelf(resMemRef.getType());

  FailureOr<MmaSyncInfo> maybeInfo =
      getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
  if (failed(maybeInfo))
    return failure();

  MmaSyncInfo info = *maybeInfo;
  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
                                            lhsIndexFn, lhsShape);
  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
                                            rhsIndexFn, rhsShape);
  Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
                                            resIndexFn, resShape);
  res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape,
                                   info.tf32Enabled);
  buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
                                 resShape);
  return res.getDefiningOp();
}
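
// For example, for a 16x8x16 f16 matmul (linalg.matmul on memref<16x16xf16>,
// memref<16x8xf16> and memref<16x8xf16>), each lane loads a vector<4x2xf16>
// LHS fragment, a vector<2x2xf16> RHS fragment, and a vector<2x2xf16>
// accumulator, feeds them to a single nvgpu.mma.sync with mmaShape
// [16, 8, 16], and scatters the resulting fragment back with per-lane stores.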

DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
    transform::TransformRewriter &rewriter, LinalgOp linalgOp,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  bool fail = true;
  // TODO: more robust detection of matmulOp, with transposes etc.
  if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
    Location loc = linalgOp.getLoc();
    // TODO: more robust computation of laneId, for now assume a single warp.
    Value laneId = rewriter.create<gpu::ThreadIdOp>(
        loc, rewriter.getIndexType(), gpu::Dimension::x);
    if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
      fail = false;
  }

  if (fail) {
    DiagnosedSilenceableFailure diag = emitSilenceableError()
                                       << "unsupported target op: " << linalgOp;
    diag.attachNote(linalgOp->getLoc()) << "target op";
    return diag;
  }

  rewriter.eraseOp(linalgOp);
  return DiagnosedSilenceableFailure::success();
}
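
// Illustrative transform-dialect usage (assembly format is approximate; see
// NVGPUTransformOps.td for the authoritative syntax):
//   transform.nvgpu.rewrite_matmul_as_mma_sync %matmul
//       : (!transform.any_op) -> ()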

//===----------------------------------------------------------------------===//
// Hopper builders.
//===----------------------------------------------------------------------===//

/// Helper to create the base Hopper-specific operations that are reused in
/// various other places.
struct HopperBuilder {
  HopperBuilder(RewriterBase &rewriter, Location loc)
      : rewriter(rewriter), loc(loc) {}

  TypedValue<nvgpu::MBarrierGroupType>
  buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);

  /// Create tma descriptor op to initiate transfer from global to shared
  /// memory. This must be done before the launch op, on the host.
  TypedValue<nvgpu::TensorMapDescriptorType>
  buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                              gpu::LaunchOp launchOp);

  /// Build a tma load from global memory to shared memory using `barrier` to
  /// synchronize. Return the number of bytes that will be transferred.
  OpFoldResult
  buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
                    TypedValue<MemRefType> sharedMemref,
                    TypedValue<nvgpu::MBarrierGroupType> barrier,
                    SmallVectorImpl<Operation *> &loadOps);
  void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierGroupType> barrier,
                            ArrayRef<OpFoldResult> sizes);

  /// If threadIdx.x == 0 does TMA request + wait, else just wait.
  /// Return the operations that perform the transfer on thread0.
  // TODO: In the future, don't hardcode to thread 0 but elect a leader.
  SmallVector<Operation *> buildPredicateLoadsOnThread0(
      ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
      ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
      TypedValue<nvgpu::MBarrierGroupType> barrier);

  void buildTryWaitParity(TypedValue<nvgpu::MBarrierGroupType> barrier);

  RewriterBase &rewriter;
  Location loc;
};

SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
    ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
    ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
    TypedValue<nvgpu::MBarrierGroupType> barrier) {
  SmallVector<Operation *> loadOps;
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value tidx = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
  Value cond =
      rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, tidx, zero);
  // clang-format off
  rewriter.create<scf::IfOp>(
    /*location=*/loc,
    /*conditional=*/cond,
    /*thenBuilder=*/
    [&](OpBuilder &lb, Location loc) {
      SmallVector<OpFoldResult> sizes;
      sizes.reserve(globalDescriptors.size());
      for (auto [desc, shmem] : llvm::zip_equal(
              globalDescriptors, sharedMemBuffers)) {
        OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
        sizes.push_back(sz);
      }
      // TODO: Note that cutlass predeclares the barrier arrive tx before the
      // tma.async.load. This may or may not have perf implications.
      buildBarrierArriveTx(barrier, sizes);
      rewriter.create<scf::YieldOp>(loc);
    },
    /*elseBuilder=*/
    [&](OpBuilder &lb, Location loc) {
      // TODO: is this for no-thread divergence?
      // Should we just yield the size and hoist?
      buildBarrierArriveTx(barrier,
                           getAsIndexOpFoldResult(rewriter.getContext(), 0));
      rewriter.create<scf::YieldOp>(loc);
    });
  // clang-format on
  return loadOps;
}
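
// The generated structure is roughly:
//   %cond = arith.cmpi eq, %tidx, %c0
//   scf.if %cond {
//     nvgpu.tma.async.load for each buffer
//     mbarrier arrive.expect_tx with the total byte count
//   } else {
//     mbarrier arrive.expect_tx with 0 bytes
//   }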

static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
  return gpu::AddressSpaceAttr::get(
      b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
  // return b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace));
}

TypedValue<nvgpu::MBarrierGroupType>
HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
  Value barrier = rewriter.create<nvgpu::MBarrierCreateOp>(
      loc,
      nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierInitOp>(
      loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads),
      zero, Value());
  rewriter.create<gpu::BarrierOp>(loc);
  return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier);
}

TypedValue<nvgpu::TensorMapDescriptorType>
HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                                           gpu::LaunchOp launchOp) {
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(launchOp);
  Value unrankedMemRef = rewriter.create<memref::CastOp>(
      loc,
      UnrankedMemRefType::get(memref.getType().getElementType(),
                              memref.getType().getMemorySpace()),
      memref);
  SmallVector<OpFoldResult> mixedSizes =
      memref::getMixedSizes(rewriter, loc, memref);
  SmallVector<Value> sizes =
      getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);

  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
  Value desc = rewriter.create<nvgpu::TmaCreateDescriptorOp>(
      loc,
      nvgpu::TensorMapDescriptorType::get(
          rewriter.getContext(),
          MemRefType::Builder(memref.getType())
              .setMemorySpace(sharedMemorySpace),
          TensorMapSwizzleKind::SWIZZLE_NONE,
          TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
          TensorMapInterleaveKind::INTERLEAVE_NONE),
      unrankedMemRef, sizes);
  return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
}

OpFoldResult HopperBuilder::buildTmaAsyncLoad(
    TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
    TypedValue<MemRefType> sharedMemref,
    TypedValue<nvgpu::MBarrierGroupType> barrier,
    SmallVectorImpl<Operation *> &loadOps) {
  MLIRContext *ctx = rewriter.getContext();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Operation *loadOp = rewriter.create<nvgpu::TmaAsyncLoadOp>(
      loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero,
      Value(), Value());
  loadOps.push_back(loadOp);
  auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr prodExprInBytes =
      computeProduct(ctx, symbols) *
      (sharedMemref.getType().getElementTypeBitWidth() / 8);
  auto res = affine::makeComposedFoldedAffineApply(rewriter, loc,
                                                   prodExprInBytes, mixedSizes);
  return res;
}
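
// For instance, for a shared buffer of type
// memref<64x32xf16, #gpu.address_space<workgroup>>, the returned
// expected-transaction size is 64 * 32 * (16 / 8) = 4096 bytes.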

void HopperBuilder::buildBarrierArriveTx(
    TypedValue<nvgpu::MBarrierGroupType> barrier,
    ArrayRef<OpFoldResult> mixedSizes) {
  assert(!mixedSizes.empty() && "expected non-empty sizes");
  MLIRContext *ctx = rewriter.getContext();
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr sumExpr = computeSum(ctx, symbols);
  OpFoldResult size =
      affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
  Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal, zero,
                                                   Value());
}

void HopperBuilder::buildTryWaitParity(
    TypedValue<nvgpu::MBarrierGroupType> barrier) {
  Type i1 = rewriter.getI1Type();
  Value parity = rewriter.create<LLVM::ConstantOp>(loc, i1, 0);
  // 10M is an arbitrary, not too small or too big number to specify the number
  // of ticks before retry.
  // TODO: hoist this in a default dialect constant.
  Value ticksBeforeRetry =
      rewriter.create<arith::ConstantIndexOp>(loc, 10000000);
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity,
                                                  ticksBeforeRetry, zero);
}

//===----------------------------------------------------------------------===//
// RewriteCopyAsTmaOp
//===----------------------------------------------------------------------===//

/// Helper to create the tma operations corresponding to `linalg::CopyOp`.
struct CopyBuilder : public HopperBuilder {
  CopyBuilder(RewriterBase &rewriter, Location loc)
      : HopperBuilder(rewriter, loc) {}

  SmallVector<Operation *> rewrite(ArrayRef<Operation *> copyOps);
};

SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
  MLIRContext *ctx = rewriter.getContext();
  if (copyOps.empty())
    return SmallVector<Operation *>();

  auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
  assert(launchOp && "expected launch op");

  // 1. Init a barrier object in shared memory.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(copyOps.front());
  AffineExpr bx, by, bz;
  bindSymbols(ctx, bx, by, bz);
  AffineExpr prod = computeProduct(ctx, ArrayRef<AffineExpr>{bx, by, bz});
  OpFoldResult numThreads = affine::makeComposedFoldedAffineApply(
      rewriter, loc, prod,
      ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
                             launchOp.getBlockSizeZ()});

  TypedValue<nvgpu::MBarrierGroupType> barrier =
      buildAndInitBarrierInSharedMemory(numThreads);

  SmallVector<TypedValue<MemRefType>> shmems;
  SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs;
  for (Operation *op : copyOps) {
    auto copyOp = cast<linalg::CopyOp>(op);
    auto inMemRef =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
    assert(inMemRef.getType().getRank() == 2 &&
           "expected in to be a 2D memref");

    // 2. Build global memory descriptor.
    TypedValue<nvgpu::TensorMapDescriptorType> globalDesc =
        buildGlobalMemRefDescriptor(inMemRef, launchOp);
    globalDescs.push_back(globalDesc);

    // 3. Shared memory and descriptor for the tmp array.
    auto shmem =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
    shmems.push_back(shmem);
  }

  // 4. Load in from global memory to shared memory using tma.
  OpBuilder::InsertionGuard g2(rewriter);
  rewriter.setInsertionPoint(copyOps.front());
  SmallVector<Operation *> results =
      buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);

  // 5. Spin-loop until data is ready.
  buildTryWaitParity(barrier);

  // 6. Erase the ops that have now been rewritten.
  for (Operation *op : copyOps)
    rewriter.eraseOp(op);

  return results;
}

DiagnosedSilenceableFailure
transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
                                     transform::TransformResults &results,
                                     transform::TransformState &state) {
  auto payloadOps = state.getPayloadOps(getTarget());
  gpu::LaunchOp commonLaunchOp;
  Operation *firstOp, *failingOp;
  if (llvm::any_of(payloadOps, [&](Operation *op) {
        if (!commonLaunchOp) {
          commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
          firstOp = op;
        }
        auto fail = !op->getParentOfType<gpu::LaunchOp>() ||
                    commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||
                    !isa<linalg::CopyOp>(op);
        if (fail)
          failingOp = op;
        return fail;
      })) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "target ops must be linalg::CopyOp nested under a common "
           "gpu.LaunchOp to be rewritten because the tma descriptors need to "
           "be created on the host.\nBut got: "
        << *firstOp << "\nand " << *failingOp;
    return diag;
  }

  // TODO: more robust detection of copy, with transposes etc.
  CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps));

  return DiagnosedSilenceableFailure::success();
}
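
// Illustrative transform-dialect usage (assembly format is approximate; see
// NVGPUTransformOps.td for the authoritative syntax):
//   transform.nvgpu.rewrite_copy_as_tma %copies : (!transform.any_op) -> ()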

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
class NVGPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          NVGPUTransformDialectExtension> {
public:
  NVGPUTransformDialectExtension() {
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<affine::AffineDialect>();
    declareGeneratedDialect<nvgpu::NVGPUDialect>();
    declareGeneratedDialect<NVVM::NVVMDialect>();
    declareGeneratedDialect<vector::VectorDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"

void mlir::nvgpu::registerTransformDialectExtension(
    DialectRegistry &registry) {
  registry.addExtensions<NVGPUTransformDialectExtension>();
}

source code of mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp