1//===- NVGPUTransformOps.cpp - Implementation of NVGPU transform ops ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
10
11#include "mlir/Analysis/SliceAnalysis.h"
12#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
13#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
15#include "mlir/Dialect/Affine/IR/AffineOps.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Arith/Utils/Utils.h"
18#include "mlir/Dialect/GPU/IR/GPUDialect.h"
19#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
20#include "mlir/Dialect/Linalg/IR/Linalg.h"
21#include "mlir/Dialect/MemRef/IR/MemRef.h"
22#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
23#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
24#include "mlir/Dialect/SCF/IR/SCF.h"
25#include "mlir/Dialect/SCF/Transforms/Transforms.h"
26#include "mlir/Dialect/Utils/IndexingUtils.h"
27#include "mlir/Dialect/Utils/StaticValueUtils.h"
28#include "mlir/Dialect/Vector/IR/VectorOps.h"
29#include "mlir/IR/AffineExpr.h"
30#include "mlir/IR/BuiltinTypes.h"
31#include "mlir/IR/Value.h"
32#include "llvm/ADT/ArrayRef.h"
33
34using namespace mlir;
35using namespace mlir::linalg;
36using namespace mlir::nvgpu;
37using namespace mlir::NVVM;
38using namespace mlir::transform;
39
40#define DEBUG_TYPE "nvgpu-transforms"
41#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
42#define DBGSNL() (llvm::dbgs() << "\n")
43#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
44
45//===----------------------------------------------------------------------===//
46// Apply...ConversionPatternsOp
47//===----------------------------------------------------------------------===//
48
49void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
50 TypeConverter &typeConverter, RewritePatternSet &patterns) {
51 auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
52 /// device-side async tokens cannot be materialized in nvvm. We just
53 /// convert them to a dummy i32 type in order to easily drop them during
54 /// conversion.
55 populateGpuMemorySpaceAttributeConversions(
56 llvmTypeConverter, [](gpu::AddressSpace space) -> unsigned {
57 switch (space) {
58 case gpu::AddressSpace::Global:
59 return static_cast<unsigned>(
60 NVVM::NVVMMemorySpace::kGlobalMemorySpace);
61 case gpu::AddressSpace::Workgroup:
62 return static_cast<unsigned>(
63 NVVM::NVVMMemorySpace::kSharedMemorySpace);
64 case gpu::AddressSpace::Private:
65 return 0;
66 }
67 llvm_unreachable("unknown address space enum value");
68 return 0;
69 });
70 llvmTypeConverter.addConversion(
71 [&](nvgpu::DeviceAsyncTokenType type) -> Type {
72 return llvmTypeConverter.convertType(
73 IntegerType::get(type.getContext(), 32));
74 });
75 llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
76 return llvmTypeConverter.convertType(
77 IntegerType::get(type.getContext(), 64));
78 });
79 llvmTypeConverter.addConversion(
80 [&](nvgpu::WarpgroupAccumulatorType type) -> Type {
81 Type elemType = type.getFragmented().getElementType();
82 int64_t sizeM = type.getFragmented().getDimSize(0);
83 int64_t sizeN = type.getFragmented().getDimSize(1);
84
85 unsigned numMembers;
86 if (elemType.isF32() || elemType.isInteger(32))
87 numMembers = sizeN / 2;
88 else if (elemType.isF16())
89 numMembers = sizeN / 4;
90 else
91 llvm_unreachable("unsupported type for warpgroup accumulator");
92
93 SmallVector<Type> innerStructBody;
94 for (unsigned i = 0; i < numMembers; i++)
95 innerStructBody.push_back(elemType);
96 auto innerStructType = LLVM::LLVMStructType::getLiteral(
97 type.getContext(), innerStructBody);
98
99 SmallVector<Type> structBody;
100 for (int i = 0; i < sizeM; i += kWgmmaSizeM)
101 structBody.push_back(innerStructType);
102
103 auto convertedType =
104 LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
105 return llvmTypeConverter.convertType(convertedType);
106 });
107 llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
108 return llvmTypeConverter.convertType(
109 getMBarrierMemrefType(type.getContext(), type));
110 });
111 llvmTypeConverter.addConversion(
112 [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
113 return llvmTypeConverter.convertType(
114 IntegerType::get(type.getContext(), 64));
115 });
116 llvmTypeConverter.addConversion(
117 [&](nvgpu::TensorMapDescriptorType type) -> Type {
118 return LLVM::LLVMPointerType::get(type.getContext());
119 });
120 populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns);
121}
122
123LogicalResult
124transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
125 transform::TypeConverterBuilderOpInterface builder) {
126 if (builder.getTypeConverterType() != "LLVMTypeConverter")
127 return emitOpError("expected LLVMTypeConverter");
128 return success();
129}
130
131//===---------------------------------------------------------------------===//
132// CreateAsyncGroupsOp
133//===---------------------------------------------------------------------===//
134
// Declares the transform-dialect side effects of CreateAsyncGroupsOp: the
// target handle is consumed (its payload ops may be rewritten), new handles
// are produced for the op's results, and the payload IR is modified.
void transform::CreateAsyncGroupsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::consumesHandle(getTargetMutable(), effects);
  transform::producesHandle(getOperation()->getOpResults(), effects);
  transform::modifiesPayload(effects);
}
141
DiagnosedSilenceableFailure transform::CreateAsyncGroupsOp::applyToOne(
    TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, TransformState &state) {
  // Delegate to the NVGPU transform; `bypass_l1` comes from the op attribute.
  nvgpu::createAsyncGroups(rewriter, target, getBypassL1());
  // Return the (possibly rewritten in place) target as this op's result.
  results.push_back(target);
  return DiagnosedSilenceableFailure::success();
}
149
150//===----------------------------------------------------------------------===//
151// PipelineSharedMemoryCopiesOp
152//===----------------------------------------------------------------------===//
153
154/// Returns true if the given type has the default memory space.
155static bool hasDefaultMemorySpace(BaseMemRefType type) {
156 return !type.getMemorySpace() || type.getMemorySpaceAsInt() == 0;
157}
158
159/// Returns true if the given type has the shared (workgroup) memory space.
160static bool hasSharedMemorySpace(BaseMemRefType type) {
161 auto space =
162 dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
163 return space &&
164 space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
165}
166
167/// Returns the value produced by a load from the default memory space. Returns
168/// null if the operation is not such a load.
169static Value getValueLoadedFromGlobal(Operation *op) {
170 // TODO: consider an interface or leveraging the memory effects interface.
171 auto load = dyn_cast<vector::TransferReadOp>(op);
172 if (!load)
173 return nullptr;
174
175 auto loadType = dyn_cast<MemRefType>(load.getBase().getType());
176 if (!loadType || !hasDefaultMemorySpace(loadType))
177 return nullptr;
178 return load;
179}
180
181/// Returns true if the operation is storing the given value into shared memory.
182static bool isStoreToShared(Operation *op, Value v) {
183 // TOD: consider an interface or leveraging the memory effects interface.
184 auto store = dyn_cast<vector::TransferWriteOp>(op);
185 if (!store || store.getVector() != v)
186 return false;
187
188 auto storeType = dyn_cast<MemRefType>(store.getBase().getType());
189 return storeType || hasSharedMemorySpace(storeType);
190}
191
192/// Returns true if the operation is a load from the default memory space the
193/// result of which is only stored into the shared memory space.
194static bool isLoadFromGlobalStoredToShared(Operation *op) {
195 Value loaded = getValueLoadedFromGlobal(op);
196 if (!loaded || !loaded.hasOneUse())
197 return false;
198
199 return isStoreToShared(op: *loaded.getUsers().begin(), v: loaded);
200}
201
202/// Populate `ops` with the set of operations that belong to the stage 0 of the
203/// pipelined version of the given loop when pipelining copies to shared memory.
204/// Specifically, this collects:
205///
206/// 1. all loads from global memory, both sync and async;
207/// 2. the barriers for async loads.
208///
209/// In particular, barriers are omitted if they do not dominate at least one
210/// async load for which there is not yet a barrier.
211static LogicalResult
212collectStage0PipeliningOps(scf::ForOp forOp,
213 llvm::SmallPtrSet<Operation *, 16> &ops) {
214
215 llvm::SmallPtrSet<Operation *, 4> barriers;
216 for (Operation &op : *forOp.getBody()) {
217 // Bail on nested ops for now.
218 if (op.getNumRegions() > 0)
219 return failure();
220
221 if (isa<gpu::BarrierOp>(op)) {
222 barriers.insert(&op);
223 continue;
224 }
225
226 if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) {
227 ops.insert(&op);
228 ops.insert(std::make_move_iterator(barriers.begin()),
229 std::make_move_iterator(barriers.end()));
230 assert(barriers.empty() &&
231 "expected to have moved the barriers into another set");
232 continue;
233 }
234
235 if (isLoadFromGlobalStoredToShared(&op)) {
236 ops.insert(&op);
237 continue;
238 }
239 }
240
241 return success();
242}
243
244/// Hook for the loop pipeliner that sets the "num groups in flight" attribute
245/// of async wait operations corresponding to pipelined shared memory copies.
246// TODO: this currently assumes that there are no groups that could be in flight
247// in the existing code.
248static void
249setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op,
250 scf::PipeliningOption::PipelinerPart part,
251 unsigned iteration, unsigned depth) {
252 // Based on the order of copies within the loop we need to set the number
253 // of copies in flight, unless it is already set.
254 auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op);
255 if (!waitOp || waitOp.getNumGroups())
256 return;
257
258 int numGroupInFlight = 0;
259 if (part == scf::PipeliningOption::PipelinerPart::Kernel ||
260 part == scf::PipeliningOption::PipelinerPart::Prologue) {
261 numGroupInFlight = depth - 1;
262 } else {
263 // By construction there should be no wait op in the prologue as all the
264 // wait should be in the last stage.
265 assert(part == scf::PipeliningOption::PipelinerPart::Epilogue);
266 // Based on the schedule we pick we know how many groups are in flight for
267 // each iteration of the epilogue.
268 numGroupInFlight = depth - 1 - iteration;
269 }
270 waitOp.setNumGroups(numGroupInFlight);
271}
272
273/// Hook for the loop pipeliner that populates `ops` with the stage information
274/// as follows:
275///
276/// - operations in `stage0Ops` (typically loads from global memory and
277/// related barriers) are at stage 0;
278/// - operations in the backward slice of any stage0Ops are all at stage 0;
279/// - other operations are at stage `depth`;
280/// - the internal order of the pipelined loop has ops at stage `depth` first,
281/// then those at stage 0, with relative order within each group preserved.
282///
283static void getPipelineStages(
284 scf::ForOp forOp,
285 std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
286 unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) {
287 SetVector<Operation *> dependencies;
288 BackwardSliceOptions options([&](Operation *visited) {
289 return visited->getBlock() == forOp.getBody();
290 });
291 options.inclusive = true;
292 for (Operation &op : forOp.getBody()->getOperations()) {
293 if (stage0Ops.contains(&op)) {
294 LogicalResult result = getBackwardSlice(&op, &dependencies, options);
295 assert(result.succeeded() && "expected a backward slice");
296 (void)result;
297 }
298 }
299
300 for (Operation &op : forOp.getBody()->getOperations()) {
301 if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))
302 opsWithPipelineStages.emplace_back(&op, depth);
303 }
304 for (Operation &op : forOp.getBody()->getOperations()) {
305 if (dependencies.contains(&op))
306 opsWithPipelineStages.emplace_back(&op, 0);
307 }
308}
309
310/// Hook for the loop pipeliner. Replaces op with a predicated version and
311/// returns the resulting operation. Returns the original op if the predication
312/// isn't necessary for the given op. Returns null if predication is needed but
313/// not supported.
314static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
315 Operation *op, Value predicate) {
316 // Some operations may be fine to execute "speculatively" more times than the
317 // original number of iterations, in particular side-effect free operations
318 // and barriers, even if they cannot be predicated.
319 if (isMemoryEffectFree(op) ||
320 isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp,
321 nvgpu::DeviceAsyncWaitOp>(op)) {
322 return op;
323 }
324
325 // Otherwise, only async copies can currently be predicated.
326 auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op);
327 if (!asyncCopyOp)
328 return nullptr;
329
330 // Create srcElement Value based on `predicate`. The next lines generate
331 // the following code:
332 //
333 // srcElement = (pred) ? prevSrcElements : 0;
334 //
335 Location loc = asyncCopyOp->getLoc();
336 Value dstElements =
337 rewriter.create<arith::ConstantOp>(loc, asyncCopyOp.getDstElementsAttr());
338 Value originalSrcElement =
339 asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
340 Value c0Index = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
341 auto srcElements = rewriter.create<arith::SelectOp>(
342 loc, predicate, originalSrcElement, c0Index);
343 auto asyncCopyZeroFillOp = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
344 loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
345 asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
346 asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
347 UnitAttr());
348 rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
349 return asyncCopyZeroFillOp;
350}
351
352/// Applies loop pipelining with the given depth to the given loop so that
353/// copies into the shared memory are pipelined. Doesn't affect other loops.
354/// Returns a pair containing the error state and the pipelined op, the latter
355/// being null in case of any failure. The error state contains a definite error
356/// if the IR has been modified and a silenceable error otherwise.
357static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
358pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,
359 bool epiloguePeeling) {
360 llvm::SmallPtrSet<Operation *, 16> stage0Ops;
361 if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) {
362 return std::make_tuple(
363 emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"),
364 scf::ForOp());
365 }
366 if (stage0Ops.empty()) {
367 return std::make_tuple(
368 emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp());
369 }
370
371 scf::PipeliningOption options;
372 unsigned maxDepth = depth;
373 auto setAnnotation = [&](Operation *op,
374 scf::PipeliningOption::PipelinerPart part,
375 unsigned iteration) {
376 return setAsyncWaitGroupsInFlight(builder&: rewriter, op, part, iteration, depth: maxDepth);
377 };
378 options.getScheduleFn =
379 [&](scf::ForOp schedulingFor,
380 std::vector<std::pair<Operation *, unsigned>> &ops) {
381 if (schedulingFor != forOp)
382 return;
383 return getPipelineStages(forOp, ops, maxDepth, stage0Ops);
384 };
385 options.annotateFn = setAnnotation;
386 if (!epiloguePeeling) {
387 options.peelEpilogue = false;
388 options.predicateFn = replaceOpWithPredicatedOp;
389 }
390
391 OpBuilder::InsertionGuard guard(rewriter);
392 rewriter.setInsertionPoint(forOp);
393 bool modifiedIR;
394 FailureOr<scf::ForOp> maybePipelined =
395 pipelineForLoop(rewriter, forOp, options, &modifiedIR);
396 if (succeeded(maybePipelined)) {
397 return std::make_tuple(DiagnosedSilenceableFailure::success(),
398 *maybePipelined);
399 }
400 return std::make_tuple(
401 modifiedIR
402 ? DiagnosedSilenceableFailure::definiteFailure()
403 : emitSilenceableFailure(forOp, "pipelining preconditions failed"),
404 scf::ForOp());
405}
406
407DiagnosedSilenceableFailure PipelineSharedMemoryCopiesOp::applyToOne(
408 TransformRewriter &rewriter, scf::ForOp forOp,
409 ApplyToEachResultList &results, TransformState &state) {
410 auto [diag, pipelined] = pipelineForSharedCopies(
411 rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());
412 if (diag.succeeded()) {
413 results.push_back(pipelined);
414 return DiagnosedSilenceableFailure::success();
415 }
416 if (diag.isDefiniteFailure()) {
417 auto diag = emitDefiniteFailure("irreversible pipelining failure");
418 if (!getPeelEpilogue()) {
419 diag.attachNote(forOp->getLoc()) << "couldn't predicate?";
420 diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName();
421 }
422 return diag;
423 }
424
425 return std::move(diag);
426}
427
428//===----------------------------------------------------------------------===//
429// RewriteMatmulAsMmaSyncOp
430//===----------------------------------------------------------------------===//
431
432/// Helper struct to encode a pair of row/column indexings in the form of
433/// affine expressions.
434struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {
435 RowColIndexing(AffineExpr row, AffineExpr col)
436 : std::pair<AffineExpr, AffineExpr>(row, col) {}
437
438 AffineExpr row() const { return first; };
439 AffineExpr col() const { return second; };
440
441 void print(llvm::raw_ostream &os) const {
442 os << "- indexing: " << first << ", " << second;
443 }
444};
445
446/// Helper struct to provide a simple mapping from matmul operations to the
447/// corresponding mma.sync operation. This is constrained to the case where the
448/// matmul matches the mma.sync operation 1-1.
449struct MmaSyncBuilder {
450 MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
451 : b(b), loc(loc), laneId(laneId) {}
452
453 using IndexCalculator =
454 std::function<SmallVector<RowColIndexing>(MLIRContext *)>;
455
456 /// Create the mma.sync operation corresponding to `linalgOp` along with all
457 /// the supporting load/store and vector operations.
458 FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);
459
460private:
461 struct MmaSyncInfo {
462 std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
463 std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, SmallVector<int64_t>>
464 vectorShapes;
465 SmallVector<int64_t> mmaShape;
466 bool tf32Enabled;
467 };
468
469 /// Return the specific index calculator for the given `linalgOp` or failure
470 /// if the op is not supported. This is the toplevel switch that should just
471 /// be Tablegen'd in the future.
472 FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,
473 TypeRange elementalTypes);
474
475 //===--------------------------------------------------------------------===//
476 // Instruction-specific row, column indexing expression builders.
477 // These should all be declaratively specified via Tablegen in the future.
478 // The Tablegen specification should be as straightforward as possible to
479 // only model the existing size and type combinations.
480 //===--------------------------------------------------------------------===//
481 //
482 // TODO: Tablegen all this.
483 //===--------------------------------------------------------------------===//
484 // m16n8k4 tf32 case.
485 //===--------------------------------------------------------------------===//
486 /// From the NVIDIA doc:
487 /// groupID = %laneid >> 2
488 /// threadIDInGroup = %laneid % 4
489 /// row = groupID for a0
490 /// groupID + 8 for a1
491 /// col = threadIDInGroup
492 static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) {
493 auto dim = getAffineDimExpr(position: 0, context: ctx);
494 AffineExpr groupID = dim.floorDiv(v: 4);
495 AffineExpr threadIDInGroup = dim % 4;
496 return {RowColIndexing{groupID, threadIDInGroup},
497 RowColIndexing{groupID + 8, threadIDInGroup}};
498 }
499
500 /// From the NVIDIA doc:
501 /// groupID = %laneid >> 2
502 /// threadIDInGroup = %laneid % 4
503 /// row = threadIDInGroup
504 /// col = groupID
505 static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) {
506 auto dim = getAffineDimExpr(position: 0, context: ctx);
507 AffineExpr groupID = dim.floorDiv(v: 4);
508 AffineExpr threadIDInGroup = dim % 4;
509 return {RowColIndexing{threadIDInGroup, groupID}};
510 }
511
512 /// From the NVIDIA doc:
513 /// groupID = %laneid >> 2
514 /// threadIDInGroup = %laneid % 4
515 /// row = groupID for c0 and c1
516 /// groupID + 8 for c2 and c3
517 /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3}
518 static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) {
519 auto dim = getAffineDimExpr(position: 0, context: ctx);
520 AffineExpr groupID = dim.floorDiv(v: 4);
521 AffineExpr threadIDInGroup = dim % 4;
522 return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
523 RowColIndexing{groupID, threadIDInGroup * 2 + 1},
524 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
525 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
526 }
527
528 //===--------------------------------------------------------------------===//
529 // m16n8k16 f16 case.
530 //===--------------------------------------------------------------------===//
531 /// From the NVIDIA doc:
532 /// groupID = %laneid >> 2
533 /// threadIDInGroup = %laneid % 4
534 ///
535 /// row = groupID for ai where 0 <= i < 2 || 4 <= i < 6
536 /// groupID + 8 Otherwise
537 ///
538 /// col = (threadIDInGroup * 2) + (i & 0x1) for ai where i < 4
539 /// (threadIDInGroup * 2) + (i & 0x1) + 8 for ai where i >= 4
540 static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) {
541 auto dim = getAffineDimExpr(position: 0, context: ctx);
542 AffineExpr groupID = dim.floorDiv(v: 4);
543 AffineExpr threadIDInGroup = dim % 4;
544 // clang-format off
545 return {
546 RowColIndexing{groupID, threadIDInGroup * 2 + 0}, // i == 0
547 RowColIndexing{groupID, threadIDInGroup * 2 + 1}, // i == 1
548 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2
549 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}, // i == 3
550 RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8}, // i == 4
551 RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8}, // i == 5
552 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8}, // i == 6
553 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8} // i == 7
554 };
555 // clang-format on
556 }
557
558 /// From the NVIDIA doc:
559 /// groupID = %laneid >> 2
560 /// threadIDInGroup = %laneid % 4
561 ///
562 /// row = (threadIDInGroup * 2) + (i & 0x1) for bi where i < 2
563 /// (threadIDInGroup * 2) + (i & 0x1) + 8 for bi where i >= 2
564 ///
565 /// col = groupID
566 static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) {
567 auto dim = getAffineDimExpr(position: 0, context: ctx);
568 AffineExpr groupID = dim.floorDiv(v: 4);
569 AffineExpr threadIDInGroup = dim % 4;
570 // clang-format off
571 return {
572 RowColIndexing{threadIDInGroup * 2 + 0, groupID}, // i == 0
573 RowColIndexing{threadIDInGroup * 2 + 1, groupID}, // i == 1
574 RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID}, // i == 2
575 RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID} // i == 3
576 };
577 // clang-format on
578 }
579
580 /// From the NVIDIA doc:
581 /// groupID = %laneid >> 2
582 /// threadIDInGroup = %laneid % 4
583 ///
584 /// row = groupID for ci where i < 2
585 /// groupID + 8 for ci where i >= 2
586 ///
587 /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3}
588 static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) {
589 auto dim = getAffineDimExpr(position: 0, context: ctx);
590 AffineExpr groupID = dim.floorDiv(v: 4);
591 AffineExpr threadIDInGroup = dim % 4;
592 // clang-format off
593 return {
594 RowColIndexing{groupID, threadIDInGroup * 2 + 0}, // i == 0
595 RowColIndexing{groupID, threadIDInGroup * 2 + 1}, // i == 1
596 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2
597 RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1} // i == 3
598 };
599 // clang-format on
600 }
601
602 //===--------------------------------------------------------------------===//
603 /// Helper functions to create customizable load and stores operations. The
604 /// specific shapes of each MMA instruction are passed via the
605 /// IndexCalculator callback.
606 //===--------------------------------------------------------------------===//
607 /// Build a list of memref.load operations indexed at `(row, col)` indices
608 /// that make sense for a particular MMA instruction and specified via the
609 /// IndexCalculator callback.
610 SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,
611 OpFoldResult laneId, Value memref,
612 const IndexCalculator &indexFn);
613
614 /// Perform a distributed load of a vector operand of `vectorShape` for a
615 /// particular MMA instruction whose `(row, col)` indices are specified via
616 /// the IndexCalculator callback. Each `laneId` loads the subportion of the
617 /// data that makes sense for the particular MMA operation.
618 /// The `vectorShape` matches existing NVGPU dialect op specification but
619 /// could also be flattened in the future if needed for simplification.
620 Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,
621 OpFoldResult laneId, Value memref,
622 IndexCalculator indexFn,
623 ArrayRef<int64_t> vectorShape);
624
625 /// Build a list of memref.store operations indexed at `(row, col)` indices
626 /// that make sense for a particular MMA instruction and specified via the
627 /// IndexCalculator callback.
628 SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,
629 ValueRange toStore,
630 OpFoldResult laneId, Value memref,
631 const IndexCalculator &indexFn);
632
633 /// Perform a distributed store of a vector operand of `vectorShape` for a
634 /// particular MMA instruction whose `(row, col)` indices are specified via
635 /// the IndexCalculator callback. Each `laneId` loads the subportion of the
636 /// data that makes sense for the particular MMA operation.
637 /// The `vectorShape` matches existing NVGPU dialect op specification but
638 /// could also be flattened in the future if needed for simplification.
639 SmallVector<Operation *> buildMmaSyncMemRefStoreOperand(
640 OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
641 Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);
642
643 OpBuilder &b;
644 Location loc;
645 OpFoldResult laneId;
646};
647
648//===--------------------------------------------------------------------===//
649/// Helper functions to create customizable load and stores operations. The
650/// specific shapes of each MMA instruction are passed via the
651/// IndexCalculator callback.
652//===--------------------------------------------------------------------===//
653
654template <typename ApplyFn, typename ReduceFn>
655static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
656 ReduceFn reduceFn) {
657 VectorType vectorType = cast<VectorType>(vector.getType());
658 auto vectorShape = vectorType.getShape();
659 auto strides = computeStrides(vectorShape);
660 for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
661 auto indices = delinearize(idx, strides);
662 reduceFn(applyFn(vector, idx, indices), idx, indices);
663 }
664}
665
666SmallVector<Value>
667MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
668 OpFoldResult laneId, Value memref,
669 const IndexCalculator &indexFn) {
670 auto aff = [&](AffineExpr e) {
671 return affine::makeComposedFoldedAffineApply(b, loc, expr: e, operands: laneId);
672 };
673 SmallVector<Value> res;
674 SmallVector<RowColIndexing> indexings = indexFn(b.getContext());
675 for (auto indexing : indexings) {
676 Value row = getValueOrCreateConstantIndexOp(b, loc, ofr: aff(indexing.row()));
677 Value col = getValueOrCreateConstantIndexOp(b, loc, ofr: aff(indexing.col()));
678 auto load = b.create<memref::LoadOp>(loc, memref, ValueRange{row, col});
679 res.push_back(Elt: load);
680 }
681 return res;
682}
683
684Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
685 OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
686 IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
687 auto loads = buildMemRefLoads(b, loc, laneId, memref, indexFn: std::move(indexFn));
688
689 Type elementType = getElementTypeOrSelf(type: memref.getType());
690 auto vt = VectorType::get(vectorShape, elementType);
691 Value res = b.create<vector::SplatOp>(loc, vt, loads[0]);
692 foreachIndividualVectorElement(
693 vector: res,
694 /*applyFn=*/
695 [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
696 return loads[linearIdx];
697 },
698 /*reduceFn=*/
699 [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
700 res = b.create<vector::InsertOp>(loc, v, res, indices);
701 });
702
703 return res;
704}
705
706SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores(
707 OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId,
708 Value memref, const IndexCalculator &indexFn) {
709 auto aff = [&](AffineExpr e) {
710 return affine::makeComposedFoldedAffineApply(b, loc, expr: e, operands: laneId);
711 };
712 SmallVector<Operation *> res;
713 for (auto [indexing, val] :
714 llvm::zip_equal(t: indexFn(b.getContext()), u&: toStore)) {
715 Value row = getValueOrCreateConstantIndexOp(b, loc, ofr: aff(indexing.row()));
716 Value col = getValueOrCreateConstantIndexOp(b, loc, ofr: aff(indexing.col()));
717 Operation *store =
718 b.create<memref::StoreOp>(loc, val, memref, ValueRange{row, col});
719 res.push_back(Elt: store);
720 }
721 return res;
722}
723
724SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
725 OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
726 Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
727 SmallVector<Value> toStore;
728 toStore.reserve(N: 32);
729 foreachIndividualVectorElement(
730 vector: vectorToStore,
731 /*applyFn=*/
732 [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
733 return b.create<vector::ExtractOp>(loc, vectorToStore, indices);
734 },
735 /*reduceFn=*/
736 [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
737 toStore.push_back(Elt: v);
738 });
739 return buildMemRefStores(b, loc, toStore, laneId, memref, indexFn: std::move(indexFn));
740}
741
742static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
743 SmallVector<int64_t>>
744makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
745 ArrayRef<int64_t> res) {
746 SmallVector<int64_t> vlhs(lhs);
747 SmallVector<int64_t> vrhs(rhs);
748 SmallVector<int64_t> vres(res);
749 return std::make_tuple(args&: vlhs, args&: vrhs, args&: vres);
750}
751
752FailureOr<MmaSyncBuilder::MmaSyncInfo>
753MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
754 TypeRange elementalTypes) {
755 // TODO: Tablegen all this.
756 Type f16 = b.getF16Type();
757 Type f32 = b.getF32Type();
758 if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
759 elementalTypes == TypeRange{f32, f32, f32}) {
760 return MmaSyncInfo{.indexFns: std::make_tuple(args: &MmaSyncBuilder::m16n8k4tf32Lhs,
761 args: &MmaSyncBuilder::m16n8k4tf32Rhs,
762 args: &MmaSyncBuilder::m16n8k4tf32Res),
763 .vectorShapes: makeVectorShapes(lhs: {2, 1}, rhs: {1, 1}, res: {2, 2}),
764 .mmaShape: SmallVector<int64_t>{opShape},
765 /*tf32Enabled=*/true};
766 }
767 // This is the version with f16 accumulation.
768 // TODO: version with f32 accumulation.
769 if (opShape == ArrayRef<int64_t>{16, 8, 16} &&
770 elementalTypes == TypeRange{f16, f16, f16}) {
771 return MmaSyncInfo{.indexFns: std::make_tuple(args: &MmaSyncBuilder::m16n8k16f16Lhs,
772 args: &MmaSyncBuilder::m16n8k16f16Rhs,
773 args: &MmaSyncBuilder::m16n8k16f16Res),
774 .vectorShapes: makeVectorShapes(lhs: {4, 2}, rhs: {2, 2}, res: {2, 2}),
775 .mmaShape: SmallVector<int64_t>{opShape},
776 /*tf32Enabled=*/false};
777 }
778 return failure();
779}
780
781FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
782 Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
783 Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
784 Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
785 assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&
786 "expected lhs to be a 2D memref");
787 assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&
788 "expected rhs to be a 2D memref");
789 assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&
790 "expected res to be a 2D memref");
791
792 int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
793 int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
794 int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];
795 Type lhsType = getElementTypeOrSelf(type: lhsMemRef.getType());
796 Type rhsType = getElementTypeOrSelf(type: rhsMemRef.getType());
797 Type resType = getElementTypeOrSelf(type: resMemRef.getType());
798
799 FailureOr<MmaSyncInfo> maybeInfo =
800 getIndexCalculators(opShape: {m, n, k}, elementalTypes: {lhsType, rhsType, resType});
801 if (failed(Result: maybeInfo))
802 return failure();
803
804 MmaSyncInfo info = *maybeInfo;
805 auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
806 auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
807 Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, memref: lhsMemRef,
808 indexFn: lhsIndexFn, vectorShape: lhsShape);
809 Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, memref: rhsMemRef,
810 indexFn: rhsIndexFn, vectorShape: rhsShape);
811 Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, memref: resMemRef,
812 indexFn: resIndexFn, vectorShape: resShape);
813 res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape,
814 info.tf32Enabled);
815 buildMmaSyncMemRefStoreOperand(b, loc, vectorToStore: res, laneId, memref: resMemRef, indexFn: resIndexFn,
816 vectorShape: resShape);
817 return res.getDefiningOp();
818}
819
820DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
821 transform::TransformRewriter &rewriter, LinalgOp linalgOp,
822 transform::ApplyToEachResultList &results,
823 transform::TransformState &state) {
824 bool fail = true;
825 // TODO: more robust detection of matmulOp, with transposes etc.
826 if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
827 // Check to not let go the matmul with extended semantic, through this
828 // transform.
829 if (linalgOp.hasUserDefinedMaps()) {
830 return emitSilenceableError()
831 << "only matmul ops with non-extended semantics are supported";
832 }
833 Location loc = linalgOp.getLoc();
834 // TODO: more robust computation of laneId, for now assume a single warp.
835 Value laneId = rewriter.create<gpu::ThreadIdOp>(
836 loc, rewriter.getIndexType(), gpu::Dimension::x);
837 if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
838 fail = false;
839 }
840
841 if (fail) {
842 DiagnosedSilenceableFailure diag = emitSilenceableError()
843 << "unsupported target op: " << linalgOp;
844 diag.attachNote(linalgOp->getLoc()) << "target op";
845 return diag;
846 }
847
848 rewriter.eraseOp(linalgOp);
849 return DiagnosedSilenceableFailure::success();
850}
851
852//===----------------------------------------------------------------------===//
853// Hopper builders.
854//===----------------------------------------------------------------------===//
855
/// Helper to create the base Hopper-specific operations that are reused in
/// various other places.
struct HopperBuilder {
  HopperBuilder(RewriterBase &rewriter, Location loc)
      : rewriter(rewriter), loc(loc) {}

  /// Create an mbarrier group in shared memory, initialize it with
  /// `numThreads` and emit a gpu.barrier so all threads see the init.
  TypedValue<nvgpu::MBarrierGroupType>
  buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);

  /// Create tma descriptor op to initiate transfer from global to shared
  /// memory. This must be done before the launch op, on the host.
  TypedValue<nvgpu::TensorMapDescriptorType>
  buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                              gpu::LaunchOp launchOp);

  /// Build a tma load from global memory to shared memory using `barrier` to
  /// synchronize. Return the number of bytes that will be transferred.
  OpFoldResult
  buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
                    TypedValue<MemRefType> sharedMemref,
                    TypedValue<nvgpu::MBarrierGroupType> barrier,
                    SmallVectorImpl<Operation *> &loadOps);
  /// Arrive on `barrier` expecting the sum of `sizes` bytes to be transferred.
  void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierGroupType> barrier,
                            ArrayRef<OpFoldResult> sizes);

  /// If threadIdx.x == 0 does TMA request + wait, else just wait.
  /// Return the operation that performs the transfer on thread0.
  // TODO: In the future, don't hardcode to thread 0 but elect a leader.
  SmallVector<Operation *> buildPredicateLoadsOnThread0(
      ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
      ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
      TypedValue<nvgpu::MBarrierGroupType> barrier);

  /// Spin-wait on `barrier` (phase parity 0) until the transfer completes.
  void buildTryWaitParity(TypedValue<nvgpu::MBarrierGroupType> barrier);

  // Rewriter and location shared by all build* helpers.
  RewriterBase &rewriter;
  Location loc;
};
894
895SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
896 ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
897 ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
898 TypedValue<nvgpu::MBarrierGroupType> barrier) {
899 SmallVector<Operation *> loadOps;
900 Value zero = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
901 Value tidx = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
902 Value cond =
903 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, tidx, zero);
904 // clang-format off
905 rewriter.create<scf::IfOp>(
906 /*location=*/loc,
907 /*conditional=*/cond,
908 /*thenBuilder=*/
909 [&](OpBuilder &lb, Location loc) {
910 SmallVector<OpFoldResult> sizes;
911 sizes.reserve(N: globalDescriptors.size());
912 for (auto [desc, shmem] : llvm::zip_equal(
913 globalDescriptors, sharedMemBuffers)) {
914 OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
915 sizes.push_back(sz);
916 }
917 // TODO: Note that cutlass predeclares the barrier arrive tx before the tma.async.load.
918 // This may or may not have perf implications.
919 buildBarrierArriveTx(barrier, sizes);
920 rewriter.create<scf::YieldOp>(loc);
921 },
922 /*elseBuilder=*/
923 [&](OpBuilder &lb, Location loc) {
924 // TODO: is this for no-thread divergence?
925 // Should we just yield the size and hoist?
926 buildBarrierArriveTx(barrier, sizes: getAsIndexOpFoldResult(ctx: rewriter.getContext(), val: 0));
927 rewriter.create<scf::YieldOp>(loc);
928 });
929 // clang-format on
930 return loadOps;
931}
932
933static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
934 return gpu::AddressSpaceAttr::get(
935 b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
936 // return b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace));
937}
938
939TypedValue<nvgpu::MBarrierGroupType>
940HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
941 auto sharedMemorySpace = getSharedAddressSpaceAttribute(b&: rewriter);
942 Value barrier = rewriter.create<nvgpu::MBarrierCreateOp>(
943 loc,
944 nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
945 Value zero = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
946 rewriter.create<nvgpu::MBarrierInitOp>(
947 loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads),
948 zero, Value());
949 rewriter.create<gpu::BarrierOp>(loc);
950 return cast<TypedValue<nvgpu::MBarrierGroupType>>(Val&: barrier);
951}
952
953TypedValue<nvgpu::TensorMapDescriptorType>
954HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
955 gpu::LaunchOp launchOp) {
956 OpBuilder::InsertionGuard guard(rewriter);
957 rewriter.setInsertionPoint(launchOp);
958 Value unrankedMemRef = rewriter.create<memref::CastOp>(
959 loc,
960 UnrankedMemRefType::get(memref.getType().getElementType(),
961 memref.getType().getMemorySpace()),
962 memref);
963 SmallVector<OpFoldResult> mixedSizes =
964 memref::getMixedSizes(builder&: rewriter, loc, value: memref);
965 SmallVector<Value> sizes =
966 getValueOrCreateConstantIndexOp(b&: rewriter, loc, valueOrAttrVec: mixedSizes);
967
968 auto sharedMemorySpace = getSharedAddressSpaceAttribute(b&: rewriter);
969 Value desc = rewriter.create<nvgpu::TmaCreateDescriptorOp>(
970 loc,
971 nvgpu::TensorMapDescriptorType::get(
972 rewriter.getContext(),
973 MemRefType::Builder(memref.getType())
974 .setMemorySpace(sharedMemorySpace),
975 TensorMapSwizzleKind::SWIZZLE_NONE,
976 TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
977 TensorMapInterleaveKind::INTERLEAVE_NONE),
978 unrankedMemRef, sizes);
979 return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
980}
981
982OpFoldResult HopperBuilder::buildTmaAsyncLoad(
983 TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
984 TypedValue<MemRefType> sharedMemref,
985 TypedValue<nvgpu::MBarrierGroupType> barrier,
986 SmallVectorImpl<Operation *> &loadOps) {
987 MLIRContext *ctx = rewriter.getContext();
988 Value zero = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
989 Operation *loadOp = rewriter.create<nvgpu::TmaAsyncLoadOp>(
990 loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero,
991 Value(), Value());
992 loadOps.push_back(Elt: loadOp);
993 auto mixedSizes = memref::getMixedSizes(builder&: rewriter, loc, value: sharedMemref);
994 SmallVector<AffineExpr> symbols(mixedSizes.size());
995 bindSymbolsList(ctx, exprs: llvm::MutableArrayRef{symbols});
996 AffineExpr prodExprInBytes =
997 computeProduct(ctx, basis: symbols) *
998 (sharedMemref.getType().getElementTypeBitWidth() / 8);
999 auto res = affine::makeComposedFoldedAffineApply(b&: rewriter, loc,
1000 expr: prodExprInBytes, operands: mixedSizes);
1001 return res;
1002}
1003
1004void HopperBuilder::buildBarrierArriveTx(
1005 TypedValue<nvgpu::MBarrierGroupType> barrier,
1006 ArrayRef<OpFoldResult> mixedSizes) {
1007 assert(!mixedSizes.empty() && "expecte non-empty sizes");
1008 MLIRContext *ctx = rewriter.getContext();
1009 SmallVector<AffineExpr> symbols(mixedSizes.size());
1010 bindSymbolsList(ctx, exprs: llvm::MutableArrayRef{symbols});
1011 AffineExpr sumExpr = computeSum(ctx, basis: symbols);
1012 OpFoldResult size =
1013 affine::makeComposedFoldedAffineApply(b&: rewriter, loc, expr: sumExpr, operands: mixedSizes);
1014 Value sizeVal = getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: size);
1015 Value zero = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
1016 rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal, zero,
1017 Value());
1018}
1019
1020void HopperBuilder::buildTryWaitParity(
1021 TypedValue<nvgpu::MBarrierGroupType> barrier) {
1022 Type i1 = rewriter.getI1Type();
1023 Value parity = rewriter.create<LLVM::ConstantOp>(loc, i1, 0);
1024 // 10M is an arbitrary, not too small or too big number to specify the number
1025 // of ticks before retry.
1026 // TODO: hoist this in a default dialect constant.
1027 Value ticksBeforeRetry =
1028 rewriter.create<arith::ConstantIndexOp>(location: loc, args: 10000000);
1029 Value zero = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
1030 rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity,
1031 ticksBeforeRetry, zero);
1032}
1033
1034//===----------------------------------------------------------------------===//
1035// RewriteCopyAsTmaOp
1036//===----------------------------------------------------------------------===//
1037
/// Helper to create the tma operations corresponding to `linalg::CopyOp`.
struct CopyBuilder : public HopperBuilder {
  CopyBuilder(RewriterBase &rewriter, Location loc)
      : HopperBuilder(rewriter, loc) {}

  /// Rewrite `copyOps` (all linalg.copy under one gpu.launch) into tma
  /// transfers; returns the ops that perform the loads on thread 0.
  SmallVector<Operation *> rewrite(ArrayRef<Operation *> copyOps);
};
1045
1046SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
1047 MLIRContext *ctx = rewriter.getContext();
1048 if (copyOps.empty())
1049 return SmallVector<Operation *>();
1050
1051 auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
1052 assert(launchOp && "expected launch op");
1053
1054 // 1. Init a barrier object in shared memory.
1055 OpBuilder::InsertionGuard g(rewriter);
1056 rewriter.setInsertionPoint(copyOps.front());
1057 AffineExpr bx, by, bz;
1058 bindSymbols(ctx, exprs&: bx, exprs&: by, exprs&: bz);
1059 AffineExpr prod = computeProduct(ctx, basis: ArrayRef<AffineExpr>{bx, by, bz});
1060 OpFoldResult numThreads = affine::makeComposedFoldedAffineApply(
1061 b&: rewriter, loc, expr: prod,
1062 operands: ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
1063 launchOp.getBlockSizeZ()});
1064
1065 TypedValue<nvgpu::MBarrierGroupType> barrier =
1066 buildAndInitBarrierInSharedMemory(numThreads);
1067
1068 SmallVector<TypedValue<MemRefType>> shmems;
1069 SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs;
1070 for (Operation *op : copyOps) {
1071 auto copyOp = cast<linalg::CopyOp>(op);
1072 auto inMemRef =
1073 cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
1074 assert(inMemRef.getType().getRank() == 2 &&
1075 "expected in to be a 2D memref");
1076
1077 // 2. Build global memory descriptor.
1078 TypedValue<nvgpu::TensorMapDescriptorType> globalDesc =
1079 buildGlobalMemRefDescriptor(inMemRef, launchOp);
1080 globalDescs.push_back(globalDesc);
1081
1082 // 3. Shared memory and descriptor for the tmp array.
1083 auto shmem =
1084 cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
1085 shmems.push_back(Elt: shmem);
1086 }
1087
1088 // 4. Load in from global memory to shared memory using tma.
1089 OpBuilder::InsertionGuard g2(rewriter);
1090 rewriter.setInsertionPoint(copyOps.front());
1091 SmallVector<Operation *> results =
1092 buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);
1093
1094 // 5. Spin-loop until data is ready.
1095 buildTryWaitParity(barrier);
1096
1097 // 6. Erase the ops that have now been rewritten.
1098 for (Operation *op : copyOps)
1099 rewriter.eraseOp(op);
1100
1101 return results;
1102}
1103
1104DiagnosedSilenceableFailure
1105transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
1106 transform::TransformResults &results,
1107 transform::TransformState &state) {
1108 auto payloadOps = state.getPayloadOps(getTarget());
1109 gpu::LaunchOp commonLaunchOp;
1110 Operation *firstOp, *failingOp;
1111 if (llvm::any_of(payloadOps, [&](Operation *op) {
1112 if (!commonLaunchOp) {
1113 commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
1114 firstOp = op;
1115 }
1116 auto fail = !op->getParentOfType<gpu::LaunchOp>() ||
1117 commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||
1118 !isa<linalg::CopyOp>(op);
1119 if (fail)
1120 failingOp = op;
1121 return fail;
1122 })) {
1123 DiagnosedSilenceableFailure diag =
1124 emitSilenceableError()
1125 << "target ops must be linalg::CopyOp nested under a common "
1126 "gpu.LaunchOp to be rewritten because the tma descriptors need to "
1127 "be created on the host.\nBut got: "
1128 << *firstOp << "\nand " << *failingOp;
1129 return diag;
1130 }
1131
1132 // TODO: more robust detection of copy, with transposes etc.
1133 CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps));
1134
1135 return DiagnosedSilenceableFailure::success();
1136}
1137
1138//===----------------------------------------------------------------------===//
1139// Transform op registration
1140//===----------------------------------------------------------------------===//
1141
namespace {
/// Transform dialect extension that registers the NVGPU transform ops and
/// declares the dialects they may generate.
class NVGPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          NVGPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension)

  NVGPUTransformDialectExtension() {
    // Dialects that applying these transform ops may introduce into the
    // payload IR.
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<affine::AffineDialect>();
    declareGeneratedDialect<nvgpu::NVGPUDialect>();
    declareGeneratedDialect<NVVM::NVVMDialect>();
    declareGeneratedDialect<vector::VectorDialect>();
    // Op list is tablegen-generated.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace
1162
1163#define GET_OP_CLASSES
1164#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
1165
/// Public hook: add the NVGPU transform dialect extension (and its transform
/// ops) to `registry`.
void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<NVGPUTransformDialectExtension>();
}
1169

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp