1//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the GPU kernel-related dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/GPU/IR/GPUDialect.h"
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
17#include "mlir/Dialect/MemRef/IR/MemRef.h"
18#include "mlir/IR/Attributes.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/BuiltinAttributes.h"
21#include "mlir/IR/BuiltinOps.h"
22#include "mlir/IR/BuiltinTypes.h"
23#include "mlir/IR/Diagnostics.h"
24#include "mlir/IR/DialectImplementation.h"
25#include "mlir/IR/Matchers.h"
26#include "mlir/IR/OpImplementation.h"
27#include "mlir/IR/PatternMatch.h"
28#include "mlir/IR/SymbolTable.h"
29#include "mlir/IR/TypeUtilities.h"
30#include "mlir/Interfaces/FunctionImplementation.h"
31#include "mlir/Interfaces/SideEffectInterfaces.h"
32#include "mlir/Interfaces/ValueBoundsOpInterface.h"
33#include "mlir/Transforms/InliningUtils.h"
34#include "llvm/ADT/STLExtras.h"
35#include "llvm/ADT/TypeSwitch.h"
36#include "llvm/Support/CommandLine.h"
37#include "llvm/Support/ErrorHandling.h"
38#include "llvm/Support/FormatVariadic.h"
39#include "llvm/Support/InterleavedRange.h"
40#include "llvm/Support/StringSaver.h"
41#include <cassert>
42#include <numeric>
43
44using namespace mlir;
45using namespace mlir::gpu;
46
47#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
48
49//===----------------------------------------------------------------------===//
50// GPU Device Mapping Attributes
51//===----------------------------------------------------------------------===//
52
53int64_t GPUBlockMappingAttr::getMappingId() const {
54 return static_cast<int64_t>(getBlock());
55}
56
57bool GPUBlockMappingAttr::isLinearMapping() const {
58 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
59}
60
61int64_t GPUBlockMappingAttr::getRelativeIndex() const {
62 return isLinearMapping()
63 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
64 : getMappingId();
65}
66
67int64_t GPUWarpgroupMappingAttr::getMappingId() const {
68 return static_cast<int64_t>(getWarpgroup());
69}
70
71bool GPUWarpgroupMappingAttr::isLinearMapping() const {
72 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
73}
74
75int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
76 return isLinearMapping()
77 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
78 : getMappingId();
79}
80
81int64_t GPUWarpMappingAttr::getMappingId() const {
82 return static_cast<int64_t>(getWarp());
83}
84
85bool GPUWarpMappingAttr::isLinearMapping() const {
86 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
87}
88
89int64_t GPUWarpMappingAttr::getRelativeIndex() const {
90 return isLinearMapping()
91 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
92 : getMappingId();
93}
94
95int64_t GPUThreadMappingAttr::getMappingId() const {
96 return static_cast<int64_t>(getThread());
97}
98
99bool GPUThreadMappingAttr::isLinearMapping() const {
100 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
101}
102
103int64_t GPUThreadMappingAttr::getRelativeIndex() const {
104 return isLinearMapping()
105 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
106 : getMappingId();
107}
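
// Note (illustrative, derived from the accessors above): non-linear ids such
// as DimX/DimY/DimZ keep their raw enum value as the relative index, while
// linear ids are rebased against MappingId::LinearDim0, so e.g. LinearDim2
// yields a relative index of 2.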
108
109int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
110 return static_cast<int64_t>(getAddressSpace());
111}
112
113bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
114 llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
115}
116
117int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
118 llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
119}
120
121//===----------------------------------------------------------------------===//
122// MMAMatrixType
123//===----------------------------------------------------------------------===//
124
125MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
126 StringRef operand) {
127 return Base::get(elementType.getContext(), shape, elementType, operand);
128}
129
130MMAMatrixType
131MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
132 ArrayRef<int64_t> shape, Type elementType,
133 StringRef operand) {
  return Base::getChecked(emitError, elementType.getContext(), shape,
                          elementType, operand);
136}
137
138unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }
139
140ArrayRef<int64_t> MMAMatrixType::getShape() const {
141 return getImpl()->getShape();
142}
143
144Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
145
146StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
147
148bool MMAMatrixType::isValidElementType(Type elementType) {
  return elementType.isF16() || elementType.isF32() ||
         elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
         elementType.isInteger(32);
152}
153
154LogicalResult
155MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
156 ArrayRef<int64_t> shape, Type elementType,
157 StringRef operand) {
158 if (operand != "AOp" && operand != "BOp" && operand != "COp")
159 return emitError() << "operand expected to be one of AOp, BOp or COp";
160
161 if (shape.size() != 2)
162 return emitError() << "MMAMatrixType must have exactly two dimensions";
163
164 if (!MMAMatrixType::isValidElementType(elementType))
165 return emitError()
166 << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";
167
168 return success();
169}
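
// For reference, a type accepted by these invariants prints as, e.g.,
// `!gpu.mma_matrix<16x16xf16, "AOp">` (illustrative; see parseType/printType
// below for the exact syntax).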
170
171//===----------------------------------------------------------------------===//
172// GPUDialect
173//===----------------------------------------------------------------------===//
174
175bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
176 if (!memorySpace)
177 return false;
178 if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
179 return gpuAttr.getValue() == getWorkgroupAddressSpace();
180 return false;
181}
182
183bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
184 Attribute memorySpace = type.getMemorySpace();
185 return isWorkgroupMemoryAddressSpace(memorySpace);
186}
187
188bool GPUDialect::isKernel(Operation *op) {
189 UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
190 return static_cast<bool>(isKernelAttr);
191}
192
193namespace {
194/// This class defines the interface for handling inlining with gpu
195/// operations.
196struct GPUInlinerInterface : public DialectInlinerInterface {
197 using DialectInlinerInterface::DialectInlinerInterface;
198
199 /// All gpu dialect ops can be inlined.
200 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
201 return true;
202 }
203};
204} // namespace
205
206void GPUDialect::initialize() {
207 addTypes<AsyncTokenType>();
208 addTypes<MMAMatrixType>();
209 addTypes<SparseDnTensorHandleType>();
210 addTypes<SparseSpMatHandleType>();
211 addTypes<SparseSpGEMMOpHandleType>();
212 addOperations<
213#define GET_OP_LIST
214#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
215 >();
216 addAttributes<
217#define GET_ATTRDEF_LIST
218#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
219 >();
220 addInterfaces<GPUInlinerInterface>();
221 declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
222 TerminatorOp>();
223 declarePromisedInterfaces<
224 ValueBoundsOpInterface, ClusterDimOp, ClusterDimBlocksOp, ClusterIdOp,
225 ClusterBlockIdOp, BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp, LaneIdOp,
226 SubgroupIdOp, GlobalIdOp, NumSubgroupsOp, SubgroupSizeOp, LaunchOp>();
227}
228
229static std::string getSparseHandleKeyword(SparseHandleKind kind) {
230 switch (kind) {
231 case SparseHandleKind::DnTensor:
232 return "sparse.dntensor_handle";
233 case SparseHandleKind::SpMat:
234 return "sparse.spmat_handle";
235 case SparseHandleKind::SpGEMMOp:
236 return "sparse.spgemmop_handle";
237 }
238 llvm_unreachable("unknown sparse handle kind");
239 return "";
240}
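
// These keywords follow the dialect prefix in the textual IR, so a handle
// type prints as, e.g., `!gpu.sparse.spmat_handle` (illustrative).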
241
242Type GPUDialect::parseType(DialectAsmParser &parser) const {
243 // Parse the main keyword for the type.
244 StringRef keyword;
245 if (parser.parseKeyword(&keyword))
246 return Type();
247 MLIRContext *context = getContext();
248
249 // Handle 'async token' types.
250 if (keyword == "async.token")
251 return AsyncTokenType::get(context);
252
253 if (keyword == "mma_matrix") {
254 SMLoc beginLoc = parser.getNameLoc();
255
256 // Parse '<'.
257 if (parser.parseLess())
258 return nullptr;
259
260 // Parse the size and elementType.
261 SmallVector<int64_t> shape;
262 Type elementType;
263 if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
264 parser.parseType(elementType))
265 return nullptr;
266
267 // Parse ','
268 if (parser.parseComma())
269 return nullptr;
270
271 // Parse operand.
272 std::string operand;
273 if (failed(parser.parseOptionalString(&operand)))
274 return nullptr;
275
276 // Parse '>'.
277 if (parser.parseGreater())
278 return nullptr;
279
280 return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
281 parser.getEncodedSourceLoc(beginLoc)),
282 shape, elementType, operand);
283 }
284
285 if (keyword == getSparseHandleKeyword(SparseHandleKind::DnTensor))
286 return SparseDnTensorHandleType::get(context);
287 if (keyword == getSparseHandleKeyword(SparseHandleKind::SpMat))
288 return SparseSpMatHandleType::get(context);
289 if (keyword == getSparseHandleKeyword(SparseHandleKind::SpGEMMOp))
290 return SparseSpGEMMOpHandleType::get(context);
291
292 parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
293 return Type();
294}
// TODO: Print a more refined type here. Note that this should correspond to
// the parser.
297void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
298 TypeSwitch<Type>(type)
299 .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
300 .Case<SparseDnTensorHandleType>([&](Type) {
301 os << getSparseHandleKeyword(SparseHandleKind::DnTensor);
302 })
303 .Case<SparseSpMatHandleType>(
304 [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
305 .Case<SparseSpGEMMOpHandleType>([&](Type) {
306 os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
307 })
308 .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
309 os << "mma_matrix<";
310 auto shape = fragTy.getShape();
311 for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
312 os << *dim << 'x';
313 os << shape.back() << 'x' << fragTy.getElementType();
314 os << ", \"" << fragTy.getOperand() << "\"" << '>';
315 })
316 .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
317}
318
319static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
320 NamedAttribute attr) {
321 auto array = dyn_cast<DenseI32ArrayAttr>(attr.getValue());
322 if (!array)
    return op->emitOpError(Twine(attr.getName()) +
                           " must be a dense i32 array");
  if (array.size() != 3)
    return op->emitOpError(Twine(attr.getName()) +
                           " must contain exactly 3 elements");
328 return success();
329}
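
// A well-formed launch-size attribute therefore looks like (illustrative):
//   gpu.known_block_size = array<i32: 128, 1, 1>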
330
331LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
332 NamedAttribute attr) {
333 if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
334 return verifyKnownLaunchSizeAttr(op, attr);
335 if (attr.getName() == getKnownGridSizeAttrHelper().getName())
336 return verifyKnownLaunchSizeAttr(op, attr);
337 if (!llvm::isa<UnitAttr>(attr.getValue()) ||
338 attr.getName() != getContainerModuleAttrName())
339 return success();
340
341 auto module = dyn_cast<ModuleOp>(op);
342 if (!module)
343 return op->emitError("expected '")
344 << getContainerModuleAttrName() << "' attribute to be attached to '"
345 << ModuleOp::getOperationName() << '\'';
346
347 auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deeply than functions in
    // the module we are currently checking.
350 if (!launchOp->getParentOp() ||
351 launchOp->getParentOp()->getParentOp() != module)
352 return success();
353
354 // Ignore launch ops with missing attributes here. The errors will be
355 // reported by the verifiers of those ops.
356 if (!launchOp->getAttrOfType<SymbolRefAttr>(
357 LaunchFuncOp::getKernelAttrName(launchOp->getName())))
358 return success();
359
360 // Check that `launch_func` refers to a well-formed GPU kernel container.
361 StringAttr kernelContainerName = launchOp.getKernelModuleName();
362 Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
363 if (!kernelContainer)
364 return launchOp.emitOpError()
365 << "kernel container '" << kernelContainerName.getValue()
366 << "' is undefined";
367
368 // If the container is a GPU binary op return success.
369 if (isa<BinaryOp>(kernelContainer))
370 return success();
371
372 auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
373 if (!kernelModule)
374 return launchOp.emitOpError()
375 << "kernel module '" << kernelContainerName.getValue()
376 << "' is undefined";
377
378 // Check that `launch_func` refers to a well-formed kernel function.
379 Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
380 if (!kernelFunc)
381 return launchOp.emitOpError("kernel function '")
382 << launchOp.getKernel() << "' is undefined";
383 auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
384 if (!kernelConvertedFunction) {
385 InFlightDiagnostic diag = launchOp.emitOpError()
386 << "referenced kernel '" << launchOp.getKernel()
387 << "' is not a function";
388 diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
389 return diag;
390 }
391
392 if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
393 GPUDialect::getKernelFuncAttrName()))
394 return launchOp.emitOpError("kernel function is missing the '")
395 << GPUDialect::getKernelFuncAttrName() << "' attribute";
396
397 // TODO: If the kernel isn't a GPU function (which happens during separate
398 // compilation), do not check type correspondence as it would require the
399 // verifier to be aware of the type conversion.
400 auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
401 if (!kernelGPUFunction)
402 return success();
403
404 unsigned actualNumArguments = launchOp.getNumKernelOperands();
405 unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
406 if (expectedNumArguments != actualNumArguments)
407 return launchOp.emitOpError("got ")
408 << actualNumArguments << " kernel operands but expected "
409 << expectedNumArguments;
410
411 auto functionType = kernelGPUFunction.getFunctionType();
412 for (unsigned i = 0; i < expectedNumArguments; ++i) {
413 if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
414 return launchOp.emitOpError("type of function argument ")
415 << i << " does not match";
416 }
417 }
418
419 return success();
420 });
421
422 return walkResult.wasInterrupted() ? failure() : success();
423}
424
425/// Parses an optional list of async operands with an optional leading keyword.
426/// (`async`)? (`[` ssa-id-list `]`)?
427///
428/// This method is used by the tablegen assembly format for async ops as well.
429static ParseResult parseAsyncDependencies(
430 OpAsmParser &parser, Type &asyncTokenType,
431 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
432 auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
440}
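
// As an illustrative example, this parses the `async [%dep0, %dep1]` portion
// of an op such as:
//   %token = gpu.wait async [%dep0, %dep1]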
441
/// Prints optional async dependencies with their leading keyword.
443/// (`async`)? (`[` ssa-id-list `]`)?
444// Used by the tablegen assembly format for several async ops.
445static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
446 Type asyncTokenType,
447 OperandRange asyncDependencies) {
448 if (asyncTokenType)
449 printer << "async";
450 if (asyncDependencies.empty())
451 return;
452 if (asyncTokenType)
453 printer << ' ';
  printer << llvm::interleaved_array(asyncDependencies);
455}
456
// GPU memory attribution functions shared by LaunchOp and GPUFuncOp.
458/// Parses a GPU function memory attribution.
459///
460/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
461/// (`private` `(` ssa-id-and-type-list `)`)?
462///
463/// Note that this function parses only one of the two similar parts, with the
464/// keyword provided as argument.
465static ParseResult
466parseAttributions(OpAsmParser &parser, StringRef keyword,
467 SmallVectorImpl<OpAsmParser::Argument> &args) {
468 // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
470 return success();
471
  return parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                                  /*allowType=*/true);
474}
475
476/// Prints a GPU function memory attribution.
477static void printAttributions(OpAsmPrinter &p, StringRef keyword,
478 ArrayRef<BlockArgument> values) {
479 if (values.empty())
480 return;
481
482 auto printBlockArg = [](BlockArgument v) {
    return llvm::formatv("{} : {}", v, v.getType());
484 };
485 p << ' ' << keyword << '('
    << llvm::interleaved(llvm::map_range(values, printBlockArg)) << ')';
487}
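
// Printed form (illustrative):
//   workgroup(%buf : memref<32xf32, #gpu.address_space<workgroup>>)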
488
489/// Verifies a GPU function memory attribution.
490static LogicalResult verifyAttributions(Operation *op,
491 ArrayRef<BlockArgument> attributions,
492 gpu::AddressSpace memorySpace) {
493 for (Value v : attributions) {
494 auto type = llvm::dyn_cast<MemRefType>(v.getType());
495 if (!type)
496 return op->emitOpError() << "expected memref type in attribution";
497
498 // We can only verify the address space if it hasn't already been lowered
499 // from the AddressSpaceAttr to a target-specific numeric value.
500 auto addressSpace =
501 llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
502 if (!addressSpace)
503 continue;
504 if (addressSpace.getValue() != memorySpace)
505 return op->emitOpError()
506 << "expected memory space " << stringifyAddressSpace(memorySpace)
507 << " in attribution";
508 }
509 return success();
510}
511
512//===----------------------------------------------------------------------===//
513// AllReduceOp
514//===----------------------------------------------------------------------===//
515
516static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
517 Type resType) {
518 using Kind = gpu::AllReduceOperation;
519 if (llvm::is_contained(
520 {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
521 opName)) {
    if (!isa<FloatType>(resType))
523 return failure();
524 }
525
526 if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
527 Kind::AND, Kind::OR, Kind::XOR},
528 opName)) {
    if (!isa<IntegerType>(resType))
530 return failure();
531 }
532
533 return success();
534}
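
// For instance (illustrative), `minnumf` requires a float result type, so
// pairing it with `i32` fails here, while operations listed in neither set
// (e.g. `add`, `mul`) are accepted for both integer and float types.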
535
536LogicalResult gpu::AllReduceOp::verifyRegions() {
537 if (getBody().empty() != getOp().has_value())
538 return emitError("expected either an op attribute or a non-empty body");
539 if (!getBody().empty()) {
540 if (getBody().getNumArguments() != 2)
541 return emitError("expected two region arguments");
542 for (auto argument : getBody().getArguments()) {
543 if (argument.getType() != getType())
544 return emitError("incorrect region argument type");
545 }
546 unsigned yieldCount = 0;
547 for (Block &block : getBody()) {
548 if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
549 if (yield.getNumOperands() != 1)
550 return emitError("expected one gpu.yield operand");
551 if (yield.getOperand(0).getType() != getType())
552 return emitError("incorrect gpu.yield type");
553 ++yieldCount;
554 }
555 }
556 if (yieldCount == 0)
557 return emitError("expected gpu.yield op in region");
558 } else {
559 gpu::AllReduceOperation opName = *getOp();
560 if (failed(verifyReduceOpAndType(opName, getType()))) {
561 return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
562 << "` reduction operation is not compatible with type "
563 << getType();
564 }
565 }
566
567 return success();
568}
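
// A region-based all-reduce that satisfies these checks looks roughly like
// (illustrative):
//   %sum = gpu.all_reduce %val {
//   ^bb(%lhs : f32, %rhs : f32):
//     %0 = arith.addf %lhs, %rhs : f32
//     gpu.yield %0 : f32
//   } : (f32) -> f32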
569
570static bool canMakeGroupOpUniform(Operation *op) {
571 auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
572 if (!launchOp)
573 return false;
574
575 Region &body = launchOp.getBody();
576 assert(!body.empty() && "Invalid region");
577
578 // Only convert ops in gpu::launch entry block for now.
579 return op->getBlock() == &body.front();
580}
581
582OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
583 if (!getUniform() && canMakeGroupOpUniform(*this)) {
584 setUniform(true);
585 return getResult();
586 }
587
588 return nullptr;
589}
590
591// TODO: Support optional custom attributes (without dialect prefix).
592static ParseResult parseAllReduceOperation(AsmParser &parser,
593 AllReduceOperationAttr &attr) {
594 StringRef enumStr;
  if (!parser.parseOptionalKeyword(&enumStr)) {
596 std::optional<AllReduceOperation> op =
597 gpu::symbolizeAllReduceOperation(enumStr);
598 if (!op)
      return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
600 attr = AllReduceOperationAttr::get(parser.getContext(), *op);
601 }
602 return success();
603}
604
605static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
606 AllReduceOperationAttr attr) {
607 if (attr)
608 attr.print(printer);
609}
610
611//===----------------------------------------------------------------------===//
612// SubgroupReduceOp
613//===----------------------------------------------------------------------===//
614
615LogicalResult gpu::SubgroupReduceOp::verify() {
616 Type elemType = getType();
617 if (auto vecTy = dyn_cast<VectorType>(elemType)) {
618 if (vecTy.isScalable())
619 return emitOpError() << "is not compatible with scalable vector types";
620
621 elemType = vecTy.getElementType();
622 }
623
624 gpu::AllReduceOperation opName = getOp();
625 if (failed(verifyReduceOpAndType(opName, elemType))) {
626 return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
627 << "` reduction operation is not compatible with type "
628 << getType();
629 }
630
631 auto clusterSize = getClusterSize();
632 if (clusterSize) {
633 uint32_t size = *clusterSize;
634 if (!llvm::isPowerOf2_32(size)) {
635 return emitOpError() << "cluster size " << size
636 << " is not a power of two";
637 }
638 }
639
640 uint32_t stride = getClusterStride();
641 if (stride != 1 && !clusterSize) {
642 return emitOpError() << "cluster stride can only be specified if cluster "
643 "size is specified";
644 }
645 if (!llvm::isPowerOf2_32(stride)) {
646 return emitOpError() << "cluster stride " << stride
647 << " is not a power of two";
648 }
649
650 return success();
651}
652
653OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
654 if (getClusterSize() == 1)
655 return getValue();
656
657 if (!getUniform() && canMakeGroupOpUniform(*this)) {
658 setUniform(true);
659 return getResult();
660 }
661
662 return nullptr;
663}
664
665//===----------------------------------------------------------------------===//
666// AsyncOpInterface
667//===----------------------------------------------------------------------===//
668
669void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
671 if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
672 return;
673 auto attrName =
674 OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
675 auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);
676
  // The async dependencies are the only variadic operand group.
678 if (!sizeAttr)
679 return;
680
681 SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
682 ++sizes.front();
683 op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
684}
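
// Illustrative effect: for an op whose operand segment sizes start as
// [2, 1, 1, ...] (two async dependencies), prepending a token through this
// helper updates them to [3, 1, 1, ...].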
685
686//===----------------------------------------------------------------------===//
687// LaunchOp
688//===----------------------------------------------------------------------===//
689
690void LaunchOp::build(OpBuilder &builder, OperationState &result,
691 Value gridSizeX, Value gridSizeY, Value gridSizeZ,
692 Value getBlockSizeX, Value getBlockSizeY,
693 Value getBlockSizeZ, Value dynamicSharedMemorySize,
694 Type asyncTokenType, ValueRange asyncDependencies,
695 TypeRange workgroupAttributions,
696 TypeRange privateAttributions, Value clusterSizeX,
697 Value clusterSizeY, Value clusterSizeZ) {
698 OpBuilder::InsertionGuard g(builder);
699
  // Add a WorkGroup attribution attribute. This attribute is required to
  // identify private attributions in the list of block arguments.
702 result.addAttribute(getNumWorkgroupAttributionsAttrName(),
703 builder.getI64IntegerAttr(workgroupAttributions.size()));
704
705 // Add Op operands.
706 result.addOperands(asyncDependencies);
707 if (asyncTokenType)
708 result.types.push_back(builder.getType<AsyncTokenType>());
709
710 // Add grid and block sizes as op operands, followed by the data operands.
711 result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
712 getBlockSizeY, getBlockSizeZ});
713 if (clusterSizeX)
714 result.addOperands(clusterSizeX);
715 if (clusterSizeY)
716 result.addOperands(clusterSizeY);
717 if (clusterSizeZ)
718 result.addOperands(clusterSizeZ);
719 if (dynamicSharedMemorySize)
720 result.addOperands(dynamicSharedMemorySize);
721
722 // Create a kernel body region with kNumConfigRegionAttributes + N memory
723 // attributions, where the first kNumConfigRegionAttributes arguments have
724 // `index` type and the rest have the same types as the data operands.
725 Region *kernelRegion = result.addRegion();
726 Block *body = builder.createBlock(kernelRegion);
727 // TODO: Allow passing in proper locations here.
728 for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
729 body->addArgument(builder.getIndexType(), result.location);
730 // Add WorkGroup & Private attributions to the region arguments.
731 for (Type argTy : workgroupAttributions)
732 body->addArgument(argTy, result.location);
733 for (Type argTy : privateAttributions)
734 body->addArgument(argTy, result.location);
735 // Fill OperandSegmentSize Attribute.
736 SmallVector<int32_t, 11> segmentSizes(11, 1);
737 segmentSizes.front() = asyncDependencies.size();
738 segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
739 segmentSizes[7] = clusterSizeX ? 1 : 0;
740 segmentSizes[8] = clusterSizeY ? 1 : 0;
741 segmentSizes[9] = clusterSizeZ ? 1 : 0;
742 result.addAttribute(getOperandSegmentSizeAttr(),
743 builder.getDenseI32ArrayAttr(segmentSizes));
744}
745
746KernelDim3 LaunchOp::getBlockIds() {
747 assert(!getBody().empty() && "LaunchOp body must not be empty.");
748 auto args = getBody().getArguments();
749 return KernelDim3{args[0], args[1], args[2]};
750}
751
752KernelDim3 LaunchOp::getThreadIds() {
753 assert(!getBody().empty() && "LaunchOp body must not be empty.");
754 auto args = getBody().getArguments();
755 return KernelDim3{args[3], args[4], args[5]};
756}
757
758KernelDim3 LaunchOp::getGridSize() {
759 assert(!getBody().empty() && "LaunchOp body must not be empty.");
760 auto args = getBody().getArguments();
761 return KernelDim3{args[6], args[7], args[8]};
762}
763
764KernelDim3 LaunchOp::getBlockSize() {
765 assert(!getBody().empty() && "LaunchOp body must not be empty.");
766 auto args = getBody().getArguments();
767 return KernelDim3{args[9], args[10], args[11]};
768}
769
770std::optional<KernelDim3> LaunchOp::getClusterIds() {
771 assert(!getBody().empty() && "LaunchOp body must not be empty.");
772 if (!hasClusterSize())
773 return std::nullopt;
774 auto args = getBody().getArguments();
775 return KernelDim3{args[12], args[13], args[14]};
776}
777
778std::optional<KernelDim3> LaunchOp::getClusterSize() {
779 assert(!getBody().empty() && "LaunchOp body must not be empty.");
780 if (!hasClusterSize())
781 return std::nullopt;
782 auto args = getBody().getArguments();
783 return KernelDim3{args[15], args[16], args[17]};
784}
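
// Summary of the entry block argument layout assumed by the getters above
// (cluster arguments are only present when hasClusterSize() is true):
//   args[0..2]   block ids      args[3..5]   thread ids
//   args[6..8]   grid size      args[9..11]  block size
//   args[12..14] cluster ids    args[15..17] cluster size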
785
786KernelDim3 LaunchOp::getGridSizeOperandValues() {
787 auto operands = getOperands().drop_front(getAsyncDependencies().size());
788 return KernelDim3{operands[0], operands[1], operands[2]};
789}
790
791KernelDim3 LaunchOp::getBlockSizeOperandValues() {
792 auto operands = getOperands().drop_front(getAsyncDependencies().size());
793 return KernelDim3{operands[3], operands[4], operands[5]};
794}
795
796std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
797 auto operands = getOperands().drop_front(getAsyncDependencies().size());
798 if (!hasClusterSize())
799 return std::nullopt;
800 return KernelDim3{operands[6], operands[7], operands[8]};
801}
802
803LogicalResult LaunchOp::verify() {
804 if (!(hasClusterSize()) &&
805 (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
806 return emitOpError() << "cluster size must be all present";
807 return success();
808}
809
810LogicalResult LaunchOp::verifyRegions() {
811 // Kernel launch takes kNumConfigOperands leading operands for grid/block
812 // sizes and transforms them into kNumConfigRegionAttributes region arguments
813 // for block/thread identifiers and grid/block sizes.
814 if (!getBody().empty()) {
815 if (getBody().getNumArguments() <
816 kNumConfigRegionAttributes + getNumWorkgroupAttributions())
817 return emitOpError("unexpected number of region arguments");
818 }
819
820 // Verify Attributions Address Spaces.
821 if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
822 GPUDialect::getWorkgroupAddressSpace())) ||
823 failed(verifyAttributions(getOperation(), getPrivateAttributions(),
824 GPUDialect::getPrivateAddressSpace())))
825 return failure();
826
827 // Block terminators without successors are expected to exit the kernel region
828 // and must be `gpu.terminator`.
829 for (Block &block : getBody()) {
830 if (block.empty())
831 continue;
832 if (block.back().getNumSuccessors() != 0)
833 continue;
834 if (!isa<gpu::TerminatorOp>(&block.back())) {
835 return block.back()
836 .emitError()
837 .append("expected '", gpu::TerminatorOp::getOperationName(),
838 "' or a terminator with successors")
839 .attachNote(getLoc())
840 .append("in '", LaunchOp::getOperationName(), "' body region");
841 }
842 }
843
844 if (getNumResults() == 0 && getAsyncToken())
845 return emitOpError("needs to be named when async keyword is specified");
846
847 return success();
848}
849
850// Pretty-print the kernel grid/block size assignment as
851// (%iter-x, %iter-y, %iter-z) in
852// (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
853// where %size-* and %iter-* will correspond to the body region arguments.
854static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
855 KernelDim3 operands, KernelDim3 ids) {
856 p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
857 p << size.x << " = " << operands.x << ", ";
858 p << size.y << " = " << operands.y << ", ";
859 p << size.z << " = " << operands.z << ')';
860}
861
862void LaunchOp::print(OpAsmPrinter &p) {
863 if (getAsyncToken()) {
864 p << " async";
865 if (!getAsyncDependencies().empty())
866 p << " [" << getAsyncDependencies() << ']';
867 }
868 // Print the launch configuration.
869 if (hasClusterSize()) {
870 p << ' ' << getClustersKeyword();
871 printSizeAssignment(p, getClusterSize().value(),
872 getClusterSizeOperandValues().value(),
873 getClusterIds().value());
874 }
875 p << ' ' << getBlocksKeyword();
876 printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
877 getBlockIds());
878 p << ' ' << getThreadsKeyword();
879 printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
880 getThreadIds());
881 if (getDynamicSharedMemorySize())
882 p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
883 << getDynamicSharedMemorySize();
884
885 printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
886 printAttributions(p, getPrivateKeyword(), getPrivateAttributions());
887
888 p << ' ';
889
890 p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
891 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
892 LaunchOp::getOperandSegmentSizeAttr(),
893 getNumWorkgroupAttributionsAttrName()});
894}
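
// The printed form therefore reads roughly as (illustrative):
//   gpu.launch blocks(%bx, %by, %bz) in (%gx = %0, %gy = %1, %gz = %2)
//              threads(%tx, %ty, %tz) in (%sx = %3, %sy = %4, %sz = %5) {
//     ...
//     gpu.terminator
//   }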
895
896// Parse the size assignment blocks for blocks and threads. These have the form
897// (%region_arg, %region_arg, %region_arg) in
898// (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
899// where %region_arg are percent-identifiers for the region arguments to be
900// introduced further (SSA defs), and %operand are percent-identifiers for the
901// SSA value uses.
902static ParseResult
903parseSizeAssignment(OpAsmParser &parser,
904 MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
905 MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
906 MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
907 assert(indices.size() == 3 && "space for three indices expected");
908 SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
                              /*allowResultNumber=*/false) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());
914
915 for (int i = 0; i < 3; ++i) {
916 if (i != 0 && parser.parseComma())
917 return failure();
    if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
        parser.parseEqual() || parser.parseOperand(sizes[i]))
920 return failure();
921 }
922
923 return parser.parseRParen();
924}
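
// For instance (illustrative), this parses one segment such as:
//   (%bx, %by, %bz) in (%gx = %c8, %gy = %c4, %gz = %c1)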
925
926/// Parses a Launch operation.
927/// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
///               (`clusters` `(` ssa-id-list `)` `in` ssa-reassignment)?
929/// `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
930/// `threads` `(` ssa-id-list `)` `in` ssa-reassignment
931/// memory-attribution
932/// region attr-dict?
933/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
934ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
935 // Sizes of the grid and block.
936 SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
937 sizes(LaunchOp::kNumConfigOperands);
938
939 // Region arguments to be created.
940 SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
941 LaunchOp::kNumConfigRegionAttributes);
942
943 // Parse optional async dependencies.
944 SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
945 Type asyncTokenType;
946 if (failed(
947 parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
948 parser.resolveOperands(asyncDependencies, asyncTokenType,
949 result.operands))
950 return failure();
951 if (parser.getNumResults() > 0)
952 result.types.push_back(asyncTokenType);
953
954 bool hasCluster = false;
955 if (succeeded(
956 parser.parseOptionalKeyword(LaunchOp::getClustersKeyword().data()))) {
957 hasCluster = true;
958 sizes.resize(9);
959 regionArgs.resize(18);
960 }
961 MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
962 MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);
963
  // The last three size operands assign the cluster size. In the region
  // argument list, the cluster ids and sizes are the last six arguments.
966 if (hasCluster) {
967 if (parseSizeAssignment(parser, sizesRef.drop_front(6),
968 regionArgsRef.slice(15, 3),
969 regionArgsRef.slice(12, 3)))
970 return failure();
971 }
972 // Parse the size assignment segments: the first segment assigns grid sizes
973 // and defines values for block identifiers; the second segment assigns block
974 // sizes and defines values for thread identifiers. In the region argument
975 // list, identifiers precede sizes, and block-related values precede
976 // thread-related values.
977 if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
978 parseSizeAssignment(parser, sizesRef.take_front(3),
979 regionArgsRef.slice(6, 3),
980 regionArgsRef.slice(0, 3)) ||
981 parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
982 parseSizeAssignment(parser, sizesRef.drop_front(3),
983 regionArgsRef.slice(9, 3),
984 regionArgsRef.slice(3, 3)) ||
985 parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
986 result.operands))
987 return failure();
988
989 OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
990 bool hasDynamicSharedMemorySize = false;
991 if (!parser.parseOptionalKeyword(
992 LaunchOp::getDynamicSharedMemorySizeKeyword())) {
993 hasDynamicSharedMemorySize = true;
994 if (parser.parseOperand(dynamicSharedMemorySize) ||
995 parser.resolveOperand(dynamicSharedMemorySize,
996 parser.getBuilder().getI32Type(),
997 result.operands))
998 return failure();
999 }
1000
  // Create the region arguments. The region has kNumConfigRegionAttributes
  // arguments that correspond to block/thread identifiers and grid/block
  // sizes, all of `index` type, followed by a variadic number of workgroup
  // attributions and a variadic number of private attributions. The number of
  // workgroup attributions is stored in the attribute named by
  // LaunchOp::getNumWorkgroupAttributionsAttrName().
1007 Type index = parser.getBuilder().getIndexType();
1008 SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
1009 LaunchOp::kNumConfigRegionAttributes + 6, index);
1010
1011 SmallVector<OpAsmParser::Argument> regionArguments;
1012 for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
1013 OpAsmParser::Argument arg;
1014 arg.ssaName = std::get<0>(ssaValueAndType);
1015 arg.type = std::get<1>(ssaValueAndType);
1016 regionArguments.push_back(arg);
1017 }
1018
1019 Builder &builder = parser.getBuilder();
1020 // Parse workgroup memory attributions.
1021 if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
1022 regionArguments)))
1023 return failure();
1024
  // Store the number of region arguments we just parsed as the number of
  // workgroup memory attributions.
1027 unsigned numWorkgroupAttrs = regionArguments.size() -
1028 LaunchOp::kNumConfigRegionAttributes -
1029 (hasCluster ? 6 : 0);
1030 result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
1031 builder.getI64IntegerAttr(numWorkgroupAttrs));
1032
1033 // Parse private memory attributions.
1034 if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
1035 regionArguments)))
1036 return failure();
1037
1038 // Introduce the body region and parse it. The region has
1039 // kNumConfigRegionAttributes arguments that correspond to
1040 // block/thread identifiers and grid/block sizes, all having `index` type.
1041 Region *body = result.addRegion();
1042 if (parser.parseRegion(*body, regionArguments) ||
1043 parser.parseOptionalAttrDict(result.attributes))
1044 return failure();
1045
1046 SmallVector<int32_t, 11> segmentSizes(11, 1);
1047 segmentSizes.front() = asyncDependencies.size();
1048
1049 if (!hasCluster) {
1050 segmentSizes[7] = 0;
1051 segmentSizes[8] = 0;
1052 segmentSizes[9] = 0;
1053 }
1054 segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
1055 result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
1056 parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
1057 return success();
1058}
1059
1060/// Simplify the gpu.launch when the range of a thread or block ID is
1061/// trivially known to be one.
1062struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
1063 using OpRewritePattern<LaunchOp>::OpRewritePattern;
1064 LogicalResult matchAndRewrite(LaunchOp op,
1065 PatternRewriter &rewriter) const override {
1066 // If the range implies a single value for `id`, replace `id`'s uses by
1067 // zero.
1068 Value zero;
1069 bool simplified = false;
1070 auto constPropIdUses = [&](Value id, Value size) {
1071 // Check if size is trivially one.
      if (!matchPattern(size, m_One()))
1073 return;
1074 if (id.getUses().empty())
1075 return;
1076 if (!simplified) {
1077 // Create a zero value the first time.
1078 OpBuilder::InsertionGuard guard(rewriter);
1079 rewriter.setInsertionPointToStart(&op.getBody().front());
1080 zero =
1081 rewriter.create<arith::ConstantIndexOp>(op.getLoc(), /*value=*/0);
1082 }
      rewriter.replaceAllUsesWith(id, zero);
1084 simplified = true;
1085 };
1086 constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
1087 constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
1088 constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
1089 constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
1090 constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
1091 constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());
1092
    return success(simplified);
1094 }
1095};
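
// Illustrative effect: in `gpu.launch blocks(...) in (%gx = %c1, ...)`, all
// uses of the corresponding block id inside the body are replaced by a newly
// created `arith.constant 0 : index`.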
1096
1097void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
1098 MLIRContext *context) {
1099 rewrites.add<FoldLaunchArguments>(context);
1100}
1101
1102/// Adds a new block argument that corresponds to buffers located in
1103/// workgroup memory.
1104BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
1105 auto attrName = getNumWorkgroupAttributionsAttrName();
1106 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1107 (*this)->setAttr(attrName,
1108 IntegerAttr::get(attr.getType(), attr.getValue() + 1));
1109 return getBody().insertArgument(
1110 LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
1111}
1112
1113/// Adds a new block argument that corresponds to buffers located in
1114/// private memory.
1115BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers in private memory always come after buffers in workgroup
  // memory.
1118 return getBody().addArgument(type, loc);
1119}
1120
1121//===----------------------------------------------------------------------===//
1122// LaunchFuncOp
1123//===----------------------------------------------------------------------===//
1124
1125void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1126 SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
1127 KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1128 ValueRange kernelOperands, Type asyncTokenType,
1129 ValueRange asyncDependencies,
1130 std::optional<KernelDim3> clusterSize) {
1131 assert(kernelSymbol.getNestedReferences().size() == 1 &&
1132 "expected a symbol reference with a single nested reference");
1133 result.addOperands(asyncDependencies);
1134 if (asyncTokenType)
1135 result.types.push_back(builder.getType<AsyncTokenType>());
1136
1137 // Add grid and block sizes as op operands, followed by the data operands.
1138 result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
1139 getBlockSize.y, getBlockSize.z});
1140 if (clusterSize.has_value())
1141 result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1142 if (dynamicSharedMemorySize)
1143 result.addOperands(dynamicSharedMemorySize);
1144 result.addOperands(kernelOperands);
1145
1146 Properties &prop = result.getOrAddProperties<Properties>();
1147 prop.kernel = kernelSymbol;
1148 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1149 // Initialize the segment sizes to 1.
1150 for (auto &sz : prop.operandSegmentSizes)
1151 sz = 1;
1152 prop.operandSegmentSizes[0] = asyncDependencies.size();
1153 if (!clusterSize.has_value()) {
1154 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1155 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1156 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1157 }
1158 prop.operandSegmentSizes[segmentSizesLen - 3] =
1159 dynamicSharedMemorySize ? 1 : 0;
1160 prop.operandSegmentSizes[segmentSizesLen - 2] =
1161 static_cast<int32_t>(kernelOperands.size());
1162 prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
1163}
1164
1165void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1166 GPUFuncOp kernelFunc, KernelDim3 gridSize,
1167 KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1168 ValueRange kernelOperands, Type asyncTokenType,
1169 ValueRange asyncDependencies,
1170 std::optional<KernelDim3> clusterSize) {
1171 auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
1172 auto kernelSymbol =
1173 SymbolRefAttr::get(kernelModule.getNameAttr(),
1174 {SymbolRefAttr::get(kernelFunc.getNameAttr())});
1175 build(builder, result, kernelSymbol, gridSize, getBlockSize,
1176 dynamicSharedMemorySize, kernelOperands, asyncTokenType,
1177 asyncDependencies, clusterSize);
1178}
1179
1180void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1181 SymbolRefAttr kernel, KernelDim3 gridSize,
1182 KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1183 ValueRange kernelOperands, Value asyncObject,
1184 std::optional<KernelDim3> clusterSize) {
1185 // Add grid and block sizes as op operands, followed by the data operands.
1186 result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
1187 getBlockSize.y, getBlockSize.z});
1188 if (clusterSize.has_value())
1189 result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
1190 if (dynamicSharedMemorySize)
1191 result.addOperands(dynamicSharedMemorySize);
1192 result.addOperands(kernelOperands);
1193 if (asyncObject)
1194 result.addOperands(asyncObject);
1195 Properties &prop = result.getOrAddProperties<Properties>();
1196 prop.kernel = kernel;
1197 size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
1198 // Initialize the segment sizes to 1.
1199 for (auto &sz : prop.operandSegmentSizes)
1200 sz = 1;
1201 prop.operandSegmentSizes[0] = 0;
1202 if (!clusterSize.has_value()) {
1203 prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
1204 prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
1205 prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
1206 }
1207 prop.operandSegmentSizes[segmentSizesLen - 3] =
1208 dynamicSharedMemorySize ? 1 : 0;
1209 prop.operandSegmentSizes[segmentSizesLen - 2] =
1210 static_cast<int32_t>(kernelOperands.size());
1211 prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
1212}
1213
1214StringAttr LaunchFuncOp::getKernelModuleName() {
1215 return getKernel().getRootReference();
1216}
1217
1218StringAttr LaunchFuncOp::getKernelName() {
1219 return getKernel().getLeafReference();
1220}
1221
1222unsigned LaunchFuncOp::getNumKernelOperands() {
1223 return getKernelOperands().size();
1224}
1225
1226Value LaunchFuncOp::getKernelOperand(unsigned i) {
1227 return getKernelOperands()[i];
1228}
1229
1230KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
1231 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1232 return KernelDim3{operands[0], operands[1], operands[2]};
1233}
1234
1235KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
1236 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1237 return KernelDim3{operands[3], operands[4], operands[5]};
1238}
1239
1240KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
1241 assert(hasClusterSize() &&
1242 "cluster size is not set, check hasClusterSize() first");
1243 auto operands = getOperands().drop_front(getAsyncDependencies().size());
1244 return KernelDim3{operands[6], operands[7], operands[8]};
1245}
1246
1247LogicalResult LaunchFuncOp::verify() {
1248 auto module = (*this)->getParentOfType<ModuleOp>();
1249 if (!module)
1250 return emitOpError("expected to belong to a module");
1251
1252 if (!module->getAttrOfType<UnitAttr>(
1253 GPUDialect::getContainerModuleAttrName()))
1254 return emitOpError("expected the closest surrounding module to have the '" +
1255 GPUDialect::getContainerModuleAttrName() +
1256 "' attribute");
1257
1258 if (hasClusterSize()) {
1259 if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
1260 getClusterSizeZ().getType() != getClusterSizeX().getType())
1261 return emitOpError()
1262 << "expects types of the cluster dimensions must be the same";
1263 }
1264
1265 return success();
1266}
1267
1268static ParseResult
1269parseLaunchDimType(OpAsmParser &parser, Type &dimTy,
1270 std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
1271 Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
  if (succeeded(parser.parseOptionalColon())) {
    if (parser.parseType(dimTy))
1274 return failure();
1275 } else {
1276 dimTy = IndexType::get(parser.getContext());
1277 }
1278 if (clusterValue.has_value()) {
1279 clusterXTy = clusterYTy = clusterZTy = dimTy;
1280 }
1281 return success();
1282}
1283
1284static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
1285 Value clusterValue, Type clusterXTy,
1286 Type clusterYTy, Type clusterZTy) {
1287 if (!dimTy.isIndex())
1288 printer << ": " << dimTy;
1289}
1290
1291static ParseResult parseLaunchFuncOperands(
1292 OpAsmParser &parser,
1293 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
1294 SmallVectorImpl<Type> &argTypes) {
  if (parser.parseOptionalKeyword("args"))
    return success();

  auto parseElement = [&]() -> ParseResult {
    return failure(parser.parseOperand(argNames.emplace_back()) ||
                   parser.parseColonType(argTypes.emplace_back()));
  };

  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                        parseElement, " in argument list");
1305}
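
// The kernel argument list handled by this parser (and the printer below)
// looks like (illustrative):
//   args(%arg0 : memref<?xf32>, %n : index)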
1306
1307static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
1308 OperandRange operands, TypeRange types) {
1309 if (operands.empty())
1310 return;
1311 printer << "args(";
  llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
                        [&](const auto &pair) {
                          auto [operand, type] = pair;
                          printer << operand << " : " << type;
                        });
1317 printer << ")";
1318}
1319
1320//===----------------------------------------------------------------------===//
1321// ShuffleOp
1322//===----------------------------------------------------------------------===//
1323
1324void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
1325 int32_t offset, int32_t width, ShuffleMode mode) {
1326 build(builder, result, value,
1327 builder.create<arith::ConstantOp>(result.location,
1328 builder.getI32IntegerAttr(offset)),
1329 builder.create<arith::ConstantOp>(result.location,
1330 builder.getI32IntegerAttr(width)),
1331 mode);
1332}
1333
1334//===----------------------------------------------------------------------===//
1335// BarrierOp
1336//===----------------------------------------------------------------------===//
1337
1338namespace {
1339
/// Erases a gpu.barrier that is immediately followed by another gpu.barrier;
/// the threads are already synchronized.
1341LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
1342 PatternRewriter &rewriter) {
1343 if (isa_and_nonnull<BarrierOp>(op->getNextNode())) {
    rewriter.eraseOp(op);
1345 return success();
1346 }
1347 return failure();
1348}
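
// Illustrative effect: of two back-to-back `gpu.barrier` ops, the first is
// erased, leaving a single barrier.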
1349
1350} // end anonymous namespace
1351
1352void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
1353 MLIRContext *context) {
1354 results.add(eraseRedundantGpuBarrierOps);
1355}
1356
1357//===----------------------------------------------------------------------===//
1358// GPUFuncOp
1359//===----------------------------------------------------------------------===//
1360
1361/// Adds a new block argument that corresponds to buffers located in
1362/// workgroup memory.
1363BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
1364 auto attrName = getNumWorkgroupAttributionsAttrName();
1365 auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
1366 (*this)->setAttr(attrName,
1367 IntegerAttr::get(attr.getType(), attr.getValue() + 1));
1368 return getBody().insertArgument(
1369 getFunctionType().getNumInputs() + attr.getInt(), type, loc);
1370}
1371
1372/// Adds a new block argument that corresponds to buffers located in
1373/// private memory.
1374BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers in private memory always come after buffers in workgroup
  // memory.
1377 return getBody().addArgument(type, loc);
1378}
1379
1380void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
1381 StringRef name, FunctionType type,
1382 TypeRange workgroupAttributions,
1383 TypeRange privateAttributions,
1384 ArrayRef<NamedAttribute> attrs) {
1385 OpBuilder::InsertionGuard g(builder);
1386
1387 result.addAttribute(SymbolTable::getSymbolAttrName(),
1388 builder.getStringAttr(name));
1389 result.addAttribute(getFunctionTypeAttrName(result.name),
1390 TypeAttr::get(type));
1391 result.addAttribute(getNumWorkgroupAttributionsAttrName(),
1392 builder.getI64IntegerAttr(workgroupAttributions.size()));
1393 result.addAttributes(attrs);
1394 Region *body = result.addRegion();
1395 Block *entryBlock = builder.createBlock(body);
1396
1397 // TODO: Allow passing in proper locations here.
1398 for (Type argTy : type.getInputs())
1399 entryBlock->addArgument(argTy, result.location);
1400 for (Type argTy : workgroupAttributions)
1401 entryBlock->addArgument(argTy, result.location);
1402 for (Type argTy : privateAttributions)
1403 entryBlock->addArgument(argTy, result.location);
1404}
1405
1406/// Parses a GPU function memory attribution.
1407///
1408/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
1409/// (`private` `(` ssa-id-and-type-list `)`)?
1410///
1411/// Note that this function parses only one of the two similar parts, with the
1412/// keyword provided as argument.
1413static ParseResult
1414parseAttributions(OpAsmParser &parser, StringRef keyword,
1415 SmallVectorImpl<OpAsmParser::Argument> &args,
1416 Attribute &attributionAttrs) {
1417 // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
1419 return success();
1420
1421 size_t existingArgs = args.size();
  ParseResult result =
      parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                               /*allowType=*/true, /*allowAttrs=*/true);
  if (failed(result))
    return result;
1427
  bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
                               [](const OpAsmParser::Argument &arg) -> bool {
                                 return arg.attrs && !arg.attrs.empty();
                               });
1432 if (!hadAttrs) {
1433 attributionAttrs = nullptr;
1434 return result;
1435 }
1436
1437 Builder &builder = parser.getBuilder();
1438 SmallVector<Attribute> attributionAttrsVec;
  for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
1440 if (!argument.attrs)
1441 attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
1442 else
      attributionAttrsVec.push_back(argument.attrs);
1444 }
1445 attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
1446 return result;
1447}
1448
1449/// Parses a GPU function.
1450///
1451/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
1452/// (`->` function-result-list)? memory-attribution `kernel`?
1453/// function-attributes? region
1454ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
1455 SmallVector<OpAsmParser::Argument> entryArgs;
1456 SmallVector<DictionaryAttr> resultAttrs;
1457 SmallVector<Type> resultTypes;
1458 bool isVariadic;
1459
1460 // Parse the function name.
1461 StringAttr nameAttr;
1462 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
1463 result.attributes))
1464 return failure();
1465
1466 auto signatureLocation = parser.getCurrentLocation();
1467 if (failed(function_interface_impl::parseFunctionSignatureWithArguments(
1468 parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
1469 resultAttrs)))
1470 return failure();
1471
1472 if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
1473 return parser.emitError(signatureLocation)
1474 << "gpu.func requires named arguments";
1475
1476 // Construct the function type. More types will be added to the region, but
1477 // not to the function type.
1478 Builder &builder = parser.getBuilder();
1479
1480 SmallVector<Type> argTypes;
1481 for (auto &arg : entryArgs)
1482 argTypes.push_back(arg.type);
1483 auto type = builder.getFunctionType(argTypes, resultTypes);
1484 result.addAttribute(getFunctionTypeAttrName(result.name),
1485 TypeAttr::get(type));
1486
1487 call_interface_impl::addArgAndResultAttrs(
1488 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
1489 getResAttrsAttrName(result.name));
1490
1491 Attribute workgroupAttributionAttrs;
1492 // Parse workgroup memory attributions.
1493 if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
1494 entryArgs, workgroupAttributionAttrs)))
1495 return failure();
1496
1497 // Store the number of operands we just parsed as the number of workgroup
1498 // memory attributions.
1499 unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
1500 result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
1501 builder.getI64IntegerAttr(numWorkgroupAttrs));
1502 if (workgroupAttributionAttrs)
1503 result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
1504 workgroupAttributionAttrs);
1505
1506 Attribute privateAttributionAttrs;
1507 // Parse private memory attributions.
1508 if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
1509 entryArgs, privateAttributionAttrs)))
1510 return failure();
1511 if (privateAttributionAttrs)
1512 result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
1513 privateAttributionAttrs);
1514
1515 // Parse the kernel attribute if present.
1516 if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
1517 result.addAttribute(GPUDialect::getKernelFuncAttrName(),
1518 builder.getUnitAttr());
1519
1520 // Parse attributes.
1521 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
1522 return failure();
1523
1524 // Parse the region. If no argument names were provided, take all names
1525 // (including those of attributions) from the entry block.
1526 auto *body = result.addRegion();
1527 return parser.parseRegion(*body, entryArgs);
1528}
1529
1530static void printAttributions(OpAsmPrinter &p, StringRef keyword,
1531 ArrayRef<BlockArgument> values,
1532 ArrayAttr attributes) {
1533 if (values.empty())
1534 return;
1535
1536 p << ' ' << keyword << '(';
1537 llvm::interleaveComma(
1538 llvm::enumerate(values), p, [&p, attributes](auto pair) {
1539 BlockArgument v = pair.value();
1540 p << v << " : " << v.getType();
1541
1542 size_t attributionIndex = pair.index();
1543 DictionaryAttr attrs;
1544 if (attributes && attributionIndex < attributes.size())
1545 attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
1546 if (attrs)
1547 p.printOptionalAttrDict(attrs.getValue());
1548 });
1549 p << ')';
1550}
1551
1552void GPUFuncOp::print(OpAsmPrinter &p) {
1553 p << ' ';
1554 p.printSymbolName(getName());
1555
1556 FunctionType type = getFunctionType();
1557 function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
1558 /*isVariadic=*/false,
1559 type.getResults());
1560
1561 printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
1562 getWorkgroupAttribAttrs().value_or(nullptr));
1563 printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
1564 getPrivateAttribAttrs().value_or(nullptr));
1565 if (isKernel())
1566 p << ' ' << getKernelKeyword();
1567
1568 function_interface_impl::printFunctionAttributes(
1569 p, *this,
1570 {getNumWorkgroupAttributionsAttrName(),
1571 GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
1572 getArgAttrsAttrName(), getResAttrsAttrName(),
1573 getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
1574 p << ' ';
1575 p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
1576}
1577
1578static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
1579 StringAttr attrName) {
1580 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1581 if (!allAttrs || index >= allAttrs.size())
1582 return DictionaryAttr();
1583 return llvm::cast<DictionaryAttr>(allAttrs[index]);
1584}
1585
1586DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
1587 return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
1588}
1589
1590DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
1591 return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
1592}
1593
1594static void setAttributionAttrs(GPUFuncOp op, unsigned index,
1595 DictionaryAttr value, StringAttr attrName) {
1596 MLIRContext *ctx = op.getContext();
1597 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
1598 SmallVector<Attribute> elements;
1599 if (allAttrs)
1600 elements.append(allAttrs.begin(), allAttrs.end());
1601 while (elements.size() <= index)
1602 elements.push_back(DictionaryAttr::get(ctx));
1603 if (!value)
1604 elements[index] = DictionaryAttr::get(ctx);
1605 else
1606 elements[index] = value;
1607 ArrayAttr newValue = ArrayAttr::get(ctx, elements);
1608 op->setAttr(attrName, newValue);
1609}
1610
1611void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
1612 DictionaryAttr value) {
1613 setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
1614}
1615
1616void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
1617 DictionaryAttr value) {
1618 setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
1619}
1620
1621static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
1622 StringAttr name, StringAttr attrsName) {
1623 DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
1624 if (!dict)
1625 return Attribute();
1626 return dict.get(name);
1627}
1628
1629Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
1630 StringAttr name) {
1631 assert(index < getNumWorkgroupAttributions() &&
1632 "index must map to a workgroup attribution");
1633 return getAttributionAttr(*this, index, name,
1634 getWorkgroupAttribAttrsAttrName());
1635}
1636
1637Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
1638 StringAttr name) {
1639 assert(index < getNumPrivateAttributions() &&
1640 "index must map to a private attribution");
1641 return getAttributionAttr(*this, index, name,
1642 getPrivateAttribAttrsAttrName());
1643}
1644
1645static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
1646 Attribute value, StringAttr attrsName) {
1647 MLIRContext *ctx = op.getContext();
1648 SmallVector<NamedAttribute> elems;
1649 DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
1650 if (oldDict)
1651 elems.append(oldDict.getValue().begin(), oldDict.getValue().end());
1652
1653 bool found = false;
1654 bool mustSort = true;
1655 for (unsigned i = 0, e = elems.size(); i < e; ++i) {
1656 if (elems[i].getName() == name) {
1657 found = true;
1658 if (!value) {
1659 std::swap(elems[i], elems[elems.size() - 1]);
1660 elems.pop_back();
1661 } else {
1662 mustSort = false;
1663 elems[i] = NamedAttribute(elems[i].getName(), value);
1664 }
1665 break;
1666 }
1667 }
1668 if (!found) {
1669 if (!value)
1670 return;
1671 elems.emplace_back(name, value);
1672 }
1673 if (mustSort) {
1674 DictionaryAttr::sortInPlace(elems);
1675 }
1676 auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
1677 setAttributionAttrs(op, index, newDict, attrsName);
1678}
1679
1680void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
1681 Attribute value) {
1682 assert(index < getNumWorkgroupAttributions() &&
1683 "index must map to a workgroup attribution");
1684 setAttributionAttr(*this, index, name, value,
1685 getWorkgroupAttribAttrsAttrName());
1686}
1687
1688void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
1689 Attribute value) {
1690 assert(index < getNumPrivateAttributions() &&
1691 "index must map to a private attribution");
1692 setAttributionAttr(*this, index, name, value,
1693 getPrivateAttribAttrsAttrName());
1694}
1695
1696LogicalResult GPUFuncOp::verifyType() {
1697 if (isKernel() && getFunctionType().getNumResults() != 0)
1698 return emitOpError() << "expected void return type for kernel function";
1699
1700 return success();
1701}
1702
1703/// Verifies the body of the function.
1704LogicalResult GPUFuncOp::verifyBody() {
1705 if (empty())
1706 return emitOpError() << "expected body with at least one block";
1707 unsigned numFuncArguments = getNumArguments();
1708 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1709 unsigned numBlockArguments = front().getNumArguments();
1710 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1711 return emitOpError() << "expected at least "
1712 << numFuncArguments + numWorkgroupAttributions
1713 << " arguments to body region";
1714
1715 ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1716 for (unsigned i = 0; i < numFuncArguments; ++i) {
1717 Type blockArgType = front().getArgument(i).getType();
1718 if (funcArgTypes[i] != blockArgType)
1719 return emitOpError() << "expected body region argument #" << i
1720 << " to be of type " << funcArgTypes[i] << ", got "
1721 << blockArgType;
1722 }
1723
1724 if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
1725 GPUDialect::getWorkgroupAddressSpace())) ||
1726 failed(verifyAttributions(getOperation(), getPrivateAttributions(),
1727 GPUDialect::getPrivateAddressSpace())))
1728 return failure();
1729
1730 return success();
1731}
1732
1733//===----------------------------------------------------------------------===//
1734// ReturnOp
1735//===----------------------------------------------------------------------===//
1736
1737LogicalResult gpu::ReturnOp::verify() {
1738 GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();
1739
1740 FunctionType funType = function.getFunctionType();
1741
1742 if (funType.getNumResults() != getOperands().size())
1743 return emitOpError()
1744 .append("expected ", funType.getNumResults(), " result operands")
1745 .attachNote(function.getLoc())
1746 .append("return type declared here");
1747
1748 for (const auto &pair : llvm::enumerate(
1749 llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
1750 auto [type, operand] = pair.value();
1751 if (type != operand.getType())
1752 return emitOpError() << "unexpected type `" << operand.getType()
1753 << "' for operand #" << pair.index();
1754 }
1755 return success();
1756}
1757
1758//===----------------------------------------------------------------------===//
1759// GPUModuleOp
1760//===----------------------------------------------------------------------===//
1761
1762void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1763 StringRef name, ArrayAttr targets,
1764 Attribute offloadingHandler) {
1765 result.addRegion()->emplaceBlock();
1766 Properties &props = result.getOrAddProperties<Properties>();
1767 if (targets)
1768 props.targets = targets;
1769 props.setSymName(builder.getStringAttr(name));
1770 props.offloadingHandler = offloadingHandler;
1771}
1772
1773void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1774 StringRef name, ArrayRef<Attribute> targets,
1775 Attribute offloadingHandler) {
1776 build(builder, result, name,
1777 targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
1778 offloadingHandler);
1779}
1780
1781bool GPUModuleOp::hasTarget(Attribute target) {
1782 if (ArrayAttr targets = getTargetsAttr())
1783 return llvm::count(targets.getValue(), target);
1784 return false;
1785}
1786
1787void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
1788 ArrayAttr &targetsAttr = getProperties().targets;
1789 SmallVector<Attribute> targetsVector(targets);
1790 targetsAttr = ArrayAttr::get(getContext(), targetsVector);
1791}
1792
1793LogicalResult GPUModuleOp::verify() {
1794 auto targets = getOperation()->getAttrOfType<ArrayAttr>("targets");
1795
1796 if (!targets)
1797 return success();
1798
1799 for (auto target : targets) {
1800 if (auto verifyTargetAttr =
1801 llvm::dyn_cast<TargetAttrVerifyInterface>(target)) {
1802 if (verifyTargetAttr.verifyTarget(getOperation()).failed())
1803 return failure();
1804 }
1805 }
1806 return success();
1807}
1808
1809//===----------------------------------------------------------------------===//
1810// GPUBinaryOp
1811//===----------------------------------------------------------------------===//
1812void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1813 Attribute offloadingHandler, ArrayAttr objects) {
1814 auto &properties = result.getOrAddProperties<Properties>();
1815 result.attributes.push_back(builder.getNamedAttr(
1816 SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
1817 properties.objects = objects;
1818 if (offloadingHandler)
1819 properties.offloadingHandler = offloadingHandler;
1820 else
1821 properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
1822}
1823
1824void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1825 Attribute offloadingHandler, ArrayRef<Attribute> objects) {
1826 build(builder, result, name, offloadingHandler,
1827 objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
1828}
1829
1830static ParseResult parseOffloadingHandler(OpAsmParser &parser,
1831 Attribute &offloadingHandler) {
1832 if (succeeded(parser.parseOptionalLess())) {
1833 if (parser.parseAttribute(offloadingHandler))
1834 return failure();
1835 if (parser.parseGreater())
1836 return failure();
1837 }
1838 if (!offloadingHandler)
1839 offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
1840 return success();
1841}
1842
1843static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op,
1844 Attribute offloadingHandler) {
1845 if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
1846 printer << '<' << offloadingHandler << '>';
1847}
1848
1849//===----------------------------------------------------------------------===//
1850// GPUMemcpyOp
1851//===----------------------------------------------------------------------===//
1852
1853LogicalResult MemcpyOp::verify() {
1854 auto srcType = getSrc().getType();
1855 auto dstType = getDst().getType();
1856
1857 if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
1858 return emitOpError("arguments have incompatible element type");
1859
1860 if (failed(verifyCompatibleShape(srcType, dstType)))
1861 return emitOpError("arguments have incompatible shape");
1862
1863 return success();
1864}
1865
1866namespace {
1867
1868/// Erases a common case of copy ops where a destination value is used only by
1869/// the copy op, alloc and dealloc ops.
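///
/// A sketch of the IR shape this pattern targets (illustrative values and
/// types):
///   %dst = ...                     // op with an Allocate effect on %dst
///   gpu.memcpy %dst, %src : memref<16xf32>, memref<16xf32>
///   ...                            // remaining uses of %dst only free it
/// The copy is erased; its async token, if present, is replaced by its single
/// async dependency.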
1870struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
1871 using OpRewritePattern<MemcpyOp>::OpRewritePattern;
1872
1873 LogicalResult matchAndRewrite(MemcpyOp op,
1874 PatternRewriter &rewriter) const override {
1875 Value dest = op.getDst();
1876 Operation *destDefOp = dest.getDefiningOp();
1877 // `dest` must be defined by an op having Allocate memory effect in order to
1878 // perform the folding.
1879 if (!destDefOp ||
1880 !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
1881 return failure();
1882 // We can erase `op` iff `dest` has no other use apart from its
1883 // use by `op` and dealloc ops.
1884 if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
1885 return user != op &&
1886 !hasSingleEffect<MemoryEffects::Free>(user, dest);
1887 }))
1888 return failure();
1889 // We can perform the folding if and only if op has a single async
1890 // dependency and produces an async token as result, or if it does not have
1891 // any async dependency and does not produce any async token result.
1892 if (op.getAsyncDependencies().size() > 1 ||
1893 ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
1894 (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
1895 return failure();
1896 rewriter.replaceOp(op, op.getAsyncDependencies());
1897 return success();
1898 }
1899};
1900
1901} // end anonymous namespace
1902
1903void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1904 MLIRContext *context) {
1905 results.add<EraseTrivialCopyOp>(context);
1906}
1907
1908//===----------------------------------------------------------------------===//
1909// GPU_SubgroupMmaLoadMatrixOp
1910//===----------------------------------------------------------------------===//
1911
1912LogicalResult SubgroupMmaLoadMatrixOp::verify() {
1913 auto srcType = getSrcMemref().getType();
1914 auto resType = getRes().getType();
1915 auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
1916 auto operand = resMatrixType.getOperand();
1917 auto srcMemrefType = llvm::cast<MemRefType>(srcType);
1918
1919 if (!srcMemrefType.isLastDimUnitStride())
1920 return emitError(
1921 "expected source memref most minor dim must have unit stride");
1922
1923 if (operand != "AOp" && operand != "BOp" && operand != "COp")
1924 return emitError("only AOp, BOp and COp can be loaded");
1925
1926 return success();
1927}
1928
1929//===----------------------------------------------------------------------===//
1930// GPU_SubgroupMmaStoreMatrixOp
1931//===----------------------------------------------------------------------===//
1932
1933LogicalResult SubgroupMmaStoreMatrixOp::verify() {
1934 auto srcType = getSrc().getType();
1935 auto dstType = getDstMemref().getType();
1936 auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
1937 auto dstMemrefType = llvm::cast<MemRefType>(dstType);
1938
1939 if (!dstMemrefType.isLastDimUnitStride())
1940 return emitError(
1941 "expected destination memref most minor dim must have unit stride");
1942
1943 if (srcMatrixType.getOperand() != "COp")
1944 return emitError(
1945 "expected the operand matrix being stored to have 'COp' operand type");
1946
1947 return success();
1948}
1949
1950//===----------------------------------------------------------------------===//
1951// GPU_SubgroupMmaComputeOp
1952//===----------------------------------------------------------------------===//
1953
1954LogicalResult SubgroupMmaComputeOp::verify() {
1955 enum OperandMap { A, B, C };
1956 SmallVector<MMAMatrixType, 3> opTypes;
1957 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
1958 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
1959 opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
1960
1961 if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
1962 opTypes[C].getOperand() != "COp")
1963 return emitError("operands must be in the order AOp, BOp, COp");
1964
1965 ArrayRef<int64_t> aShape, bShape, cShape;
1966 aShape = opTypes[A].getShape();
1967 bShape = opTypes[B].getShape();
1968 cShape = opTypes[C].getShape();
1969
1970 if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
1971 bShape[1] != cShape[1])
1972 return emitError("operand shapes do not satisfy matmul constraints");
1973
1974 return success();
1975}
1976
1977LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
1978 SmallVectorImpl<::mlir::OpFoldResult> &results) {
1979 return memref::foldMemRefCast(*this);
1980}
1981
1982LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
1983 SmallVectorImpl<::mlir::OpFoldResult> &results) {
1984 return memref::foldMemRefCast(*this);
1985}
1986
1987//===----------------------------------------------------------------------===//
1988// GPU_WaitOp
1989//===----------------------------------------------------------------------===//
1990
1991namespace {
1992
1993/// Remove gpu.wait op use of gpu.wait op def without async dependencies.
1994/// %t = gpu.wait async [] // No async dependencies.
1995/// ... gpu.wait ... [%t, ...] // %t can be removed.
1996struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
1997public:
1998 using OpRewritePattern::OpRewritePattern;
1999
2000 LogicalResult matchAndRewrite(WaitOp op,
2001 PatternRewriter &rewriter) const final {
2002 auto predicate = [](Value value) {
2003 auto waitOp = value.getDefiningOp<WaitOp>();
2004 return waitOp && waitOp->getNumOperands() == 0;
2005 };
2006 if (llvm::none_of(op.getAsyncDependencies(), predicate))
2007 return failure();
2008 SmallVector<Value> validOperands;
2009 for (Value operand : op->getOperands()) {
2010 if (predicate(operand))
2011 continue;
2012 validOperands.push_back(operand);
2013 }
2014 rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
2015 return success();
2016 }
2017};
2018
2019/// Simplify trivial gpu.wait ops for the following patterns.
2020/// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
2021/// dependencies).
2022/// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with
2023/// %t0.
2024 /// 3. gpu.wait [] ops, i.e. gpu.wait ops that neither have any async
2025/// dependencies nor return any token.
2026struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
2027public:
2028 using OpRewritePattern::OpRewritePattern;
2029
2030 LogicalResult matchAndRewrite(WaitOp op,
2031 PatternRewriter &rewriter) const final {
2032 // Erase gpu.wait ops that neither have any async dependencies nor return
2033 // any async token.
2034 if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
2035 rewriter.eraseOp(op);
2036 return success();
2037 }
2038 // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
2039 if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
2040 op.getAsyncToken()) {
2041 rewriter.replaceOp(op, op.getAsyncDependencies());
2042 return success();
2043 }
2044 // Erase %t = gpu.wait async ... ops, where %t has no uses.
2045 if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
2046 rewriter.eraseOp(op);
2047 return success();
2048 }
2049 return failure();
2050 }
2051};
2052
2053} // end anonymous namespace
2054
2055void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
2056 MLIRContext *context) {
2057 results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
2058}
2059
2060//===----------------------------------------------------------------------===//
2061// GPU_AllocOp
2062//===----------------------------------------------------------------------===//
2063
2064LogicalResult AllocOp::verify() {
2065 auto memRefType = llvm::cast<MemRefType>(getMemref().getType());
2066
2067 if (getDynamicSizes().size() != memRefType.getNumDynamicDims())
2068 return emitOpError("dimension operand count does not equal memref "
2069 "dynamic dimension count");
2070
2071 unsigned numSymbols = 0;
2072 if (!memRefType.getLayout().isIdentity())
2073 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2074 if (getSymbolOperands().size() != numSymbols) {
2075 return emitOpError(
2076 "symbol operand count does not equal memref symbol count");
2077 }
2078
2079 return success();
2080}
2081
2082namespace {
2083
2084/// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
2085/// `memref::AllocOp`.
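///
/// Sketch of the rewrite (illustrative, with %c0 denoting the constant
/// index 0):
///   %mem = gpu.alloc (%size) : memref<?xf32>
///   %dim = memref.dim %mem, %c0 : memref<?xf32>
/// All uses of %dim are replaced by %size.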
2086struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
2087 using OpRewritePattern<memref::DimOp>::OpRewritePattern;
2088
2089 LogicalResult matchAndRewrite(memref::DimOp dimOp,
2090 PatternRewriter &rewriter) const override {
2091 std::optional<int64_t> index = dimOp.getConstantIndex();
2092 if (!index)
2093 return failure();
2094
2095 auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
2096 if (!memrefType || index.value() >= memrefType.getRank() ||
2097 !memrefType.isDynamicDim(index.value()))
2098 return failure();
2099
2100 auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
2101 if (!alloc)
2102 return failure();
2103
2104 Value substituteOp = *(alloc.getDynamicSizes().begin() +
2105 memrefType.getDynamicDimIndex(index.value()));
2106 rewriter.replaceOp(dimOp, substituteOp);
2107 return success();
2108 }
2109};
2110
2111} // namespace
2112
2113void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
2114 MLIRContext *context) {
2115 results.add<SimplifyDimOfAllocOp>(context);
2116}
2117
2118//===----------------------------------------------------------------------===//
2119// GPU object attribute
2120//===----------------------------------------------------------------------===//
2121
2122LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2123 Attribute target, CompilationTarget format,
2124 StringAttr object, DictionaryAttr properties,
2125 KernelTableAttr kernels) {
2126 if (!target)
2127 return emitError() << "the target attribute cannot be null";
2128 if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2129 return success();
2130 return emitError() << "the target attribute must implement or promise the "
2131 "`gpu::TargetAttrInterface`";
2132}
2133
2134namespace {
2135ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
2136 StringAttr &object) {
2137 std::optional<CompilationTarget> formatResult;
2138 StringRef enumKeyword;
2139 auto loc = odsParser.getCurrentLocation();
2140 if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
2141 formatResult = CompilationTarget::Fatbin;
2142 if (!formatResult &&
2143 (formatResult =
2144 gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
2145 odsParser.parseEqual())
2146 return odsParser.emitError(loc, "expected an equal sign");
2147 if (!formatResult)
2148 return odsParser.emitError(loc, "expected keyword for GPU object format");
2149 FailureOr<StringAttr> objectResult =
2150 FieldParser<StringAttr>::parse(odsParser);
2151 if (failed(objectResult))
2152 return odsParser.emitError(odsParser.getCurrentLocation(),
2153 "failed to parse GPU_ObjectAttr parameter "
2154 "'object' which is to be a `StringAttr`");
2155 format = *formatResult;
2156 object = *objectResult;
2157 return success();
2158}
2159
2160void printObject(AsmPrinter &odsParser, CompilationTarget format,
2161 StringAttr object) {
2162 if (format != CompilationTarget::Fatbin)
2163 odsParser << stringifyEnum(format) << " = ";
2164 odsParser << object;
2165}
2166} // namespace
2167
2168//===----------------------------------------------------------------------===//
2169// GPU select object attribute
2170//===----------------------------------------------------------------------===//
2171
2172LogicalResult
2173gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2174 Attribute target) {
2175 // Check `target`, it can be null, an integer attr or a GPU Target attribute.
2176 if (target) {
2177 if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
2178 if (intAttr.getInt() < 0) {
2179 return emitError() << "the object index must be positive";
2180 }
2181 } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
2182 return emitError()
2183 << "the target attribute must be a GPU Target attribute";
2184 }
2185 }
2186 return success();
2187}
2188
2189//===----------------------------------------------------------------------===//
2190// DynamicSharedMemoryOp
2191//===----------------------------------------------------------------------===//
2192
2193LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2194 if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2195 return emitOpError() << "must be inside an op with symbol table";
2196
2197 MemRefType memrefType = getResultMemref().getType();
2198 // Check address space
2199 if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
2200 return emitOpError() << "address space must be "
2201 << gpu::AddressSpaceAttr::getMnemonic() << "<"
2202 << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
2203 }
2204 if (memrefType.hasStaticShape()) {
2205 return emitOpError() << "result memref type must be memref<?xi8, "
2206 "#gpu.address_space<workgroup>>";
2207 }
2208 return success();
2209}
2210
2211//===----------------------------------------------------------------------===//
2212// GPU WarpExecuteOnLane0Op
2213//===----------------------------------------------------------------------===//
2214
2215void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
2216 p << "(" << getLaneid() << ")";
2217
2218 SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
2219 auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
2220 p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";
2221
2222 if (!getArgs().empty())
2223 p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
2224 if (!getResults().empty())
2225 p << " -> (" << getResults().getTypes() << ')';
2226 p << " ";
2227 p.printRegion(getRegion(),
2228 /*printEntryBlockArgs=*/true,
2229 /*printBlockTerminators=*/!getResults().empty());
2230 p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
2231}
2232
2233ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
2234 OperationState &result) {
2235 // Create the region.
2236 result.regions.reserve(1);
2237 Region *warpRegion = result.addRegion();
2238
2239 auto &builder = parser.getBuilder();
2240 OpAsmParser::UnresolvedOperand laneId;
2241
2242 // Parse predicate operand.
2243 if (parser.parseLParen() ||
2244 parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
2245 parser.parseRParen())
2246 return failure();
2247
2248 int64_t warpSize;
2249 if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
2250 parser.parseRSquare())
2251 return failure();
2252 result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
2253 builder.getContext())),
2254 builder.getI64IntegerAttr(warpSize));
2255
2256 if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
2257 return failure();
2258
2259 llvm::SMLoc inputsOperandsLoc;
2260 SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
2261 SmallVector<Type> inputTypes;
2262 if (succeeded(parser.parseOptionalKeyword("args"))) {
2263 if (parser.parseLParen())
2264 return failure();
2265
2266 inputsOperandsLoc = parser.getCurrentLocation();
2267 if (parser.parseOperandList(inputsOperands) ||
2268 parser.parseColonTypeList(inputTypes) || parser.parseRParen())
2269 return failure();
2270 }
2271 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
2272 result.operands))
2273 return failure();
2274
2275 // Parse optional results type list.
2276 if (parser.parseOptionalArrowTypeList(result.types))
2277 return failure();
2278 // Parse the region.
2279 if (parser.parseRegion(*warpRegion, /*arguments=*/{},
2280 /*argTypes=*/{}))
2281 return failure();
2282 WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
2283
2284 // Parse the optional attribute list.
2285 if (parser.parseOptionalAttrDict(result.attributes))
2286 return failure();
2287 return success();
2288}
2289
2290void WarpExecuteOnLane0Op::getSuccessorRegions(
2291 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2292 if (!point.isParent()) {
2293 regions.push_back(RegionSuccessor(getResults()));
2294 return;
2295 }
2296
2297 // The warp region is always executed
2298 regions.push_back(RegionSuccessor(&getWarpRegion()));
2299}
2300
2301void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2302 TypeRange resultTypes, Value laneId,
2303 int64_t warpSize) {
2304 build(builder, result, resultTypes, laneId, warpSize,
2305 /*operands=*/std::nullopt, /*argTypes=*/std::nullopt);
2306}
2307
2308void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2309 TypeRange resultTypes, Value laneId,
2310 int64_t warpSize, ValueRange args,
2311 TypeRange blockArgTypes) {
2312 result.addOperands(laneId);
2313 result.addAttribute(getAttributeNames()[0],
2314 builder.getI64IntegerAttr(warpSize));
2315 result.addTypes(resultTypes);
2316 result.addOperands(args);
2317 assert(args.size() == blockArgTypes.size());
2318 OpBuilder::InsertionGuard guard(builder);
2319 Region *warpRegion = result.addRegion();
2320 Block *block = builder.createBlock(warpRegion);
2321 for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
2322 block->addArgument(type, arg.getLoc());
2323}
2324
2325/// Helper check if the distributed vector type is consistent with the expanded
2326/// type and distributed size.
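///
/// For example, with warpSize = 32, an expanded type vector<64x1xf32> may be
/// distributed as vector<2x1xf32>: the per-dimension factors 32 and 1 multiply
/// to the warp size. Distributing it as vector<4x1xf32> would be rejected,
/// since 16 * 1 != 32.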
2327static LogicalResult verifyDistributedType(Type expanded, Type distributed,
2328 int64_t warpSize, Operation *op) {
2329 // If the types match, there is no distribution.
2330 if (expanded == distributed)
2331 return success();
2332 auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
2333 auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
2334 if (!expandedVecType || !distributedVecType)
2335 return op->emitOpError("expected vector type for distributed operands.");
2336 if (expandedVecType.getRank() != distributedVecType.getRank() ||
2337 expandedVecType.getElementType() != distributedVecType.getElementType())
2338 return op->emitOpError(
2339 message: "expected distributed vectors to have same rank and element type.");
2340
2341 SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
2342 for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
2343 int64_t eDim = expandedVecType.getDimSize(i);
2344 int64_t dDim = distributedVecType.getDimSize(i);
2345 if (eDim == dDim)
2346 continue;
2347 if (eDim % dDim != 0)
2348 return op->emitOpError()
2349 << "expected expanded vector dimension #" << i << " (" << eDim
2350 << ") to be a multipler of the distributed vector dimension ("
2351 << dDim << ")";
2352 scales[i] = eDim / dDim;
2353 }
2354 if (std::accumulate(scales.begin(), scales.end(), 1,
2355 std::multiplies<int64_t>()) != warpSize)
2356 return op->emitOpError()
2357 << "incompatible distribution dimensions from " << expandedVecType
2358 << " to " << distributedVecType << " with warp size = " << warpSize;
2359
2360 return success();
2361}
2362
2363LogicalResult WarpExecuteOnLane0Op::verify() {
2364 if (getArgs().size() != getWarpRegion().getNumArguments())
2365 return emitOpError(
2366 "expected same number op arguments and block arguments.");
2367 auto yield =
2368 cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
2369 if (yield.getNumOperands() != getNumResults())
2370 return emitOpError(
2371 "expected same number of yield operands and return values.");
2372 int64_t warpSize = getWarpSize();
2373 for (auto [regionArg, arg] :
2374 llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
2375 if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
2376 warpSize, getOperation())))
2377 return failure();
2378 }
2379 for (auto [yieldOperand, result] :
2380 llvm::zip_equal(yield.getOperands(), getResults())) {
2381 if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
2382 warpSize, getOperation())))
2383 return failure();
2384 }
2385 return success();
2386}
2387bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
2388 return succeeded(
2389 verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
2390}
2391
2392//===----------------------------------------------------------------------===//
2393// GPU KernelMetadataAttr
2394//===----------------------------------------------------------------------===//
2395
2396KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2397 DictionaryAttr metadata) {
2398 assert(kernel && "invalid kernel");
2399 return get(kernel.getNameAttr(), kernel.getFunctionType(),
2400 kernel.getAllArgAttrs(), metadata);
2401}
2402
2403KernelMetadataAttr
2404KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
2405 FunctionOpInterface kernel,
2406 DictionaryAttr metadata) {
2407 assert(kernel && "invalid kernel");
2408 return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
2409 kernel.getAllArgAttrs(), metadata);
2410}
2411
2412KernelMetadataAttr
2413KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
2414 if (attrs.empty())
2415 return *this;
2416 NamedAttrList attrList;
2417 if (DictionaryAttr dict = getMetadata())
2418 attrList.append(dict);
2419 attrList.append(attrs);
2420 return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
2421 attrList.getDictionary(getContext()));
2422}
2423
2424LogicalResult
2425KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2426 StringAttr name, Type functionType,
2427 ArrayAttr argAttrs, DictionaryAttr metadata) {
2428 if (name.empty())
2429 return emitError() << "the kernel name can't be empty";
2430 if (argAttrs) {
2431 if (llvm::any_of(argAttrs, [](Attribute attr) {
2432 return !llvm::isa<DictionaryAttr>(attr);
2433 }))
2434 return emitError()
2435 << "all attributes in the array must be a dictionary attribute";
2436 }
2437 return success();
2438}
2439
2440//===----------------------------------------------------------------------===//
2441// GPU KernelTableAttr
2442//===----------------------------------------------------------------------===//
2443
2444KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2445 ArrayRef<KernelMetadataAttr> kernels,
2446 bool isSorted) {
2447 // Note that `is_sorted` is always only invoked once even with assertions ON.
2448 assert((!isSorted || llvm::is_sorted(kernels)) &&
2449 "expected a sorted kernel array");
2450 // Immediately return the attribute if the array is sorted.
2451 if (isSorted || llvm::is_sorted(kernels))
2452 return Base::get(context, kernels);
2453 // Sort the array.
2454 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2455 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2456 return Base::get(context, kernelsTmp);
2457}
2458
2459KernelTableAttr KernelTableAttr::getChecked(
2460 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
2461 ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
2462 // Note that `is_sorted` is always only invoked once even with assertions ON.
2463 assert((!isSorted || llvm::is_sorted(kernels)) &&
2464 "expected a sorted kernel array");
2465 // Immediately return the attribute if the array is sorted.
2466 if (isSorted || llvm::is_sorted(kernels))
2467 return Base::getChecked(emitError, context, kernels);
2468 // Sort the array.
2469 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2470 llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
2471 return Base::getChecked(emitError, context, kernelsTmp);
2472}
2473
2474LogicalResult
2475KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2476 ArrayRef<KernelMetadataAttr> kernels) {
2477 if (kernels.size() < 2)
2478 return success();
2479 // Check that the kernels are uniquely named.
2480 if (std::adjacent_find(kernels.begin(), kernels.end(),
2481 [](KernelMetadataAttr l, KernelMetadataAttr r) {
2482 return l.getName() == r.getName();
2483 }) != kernels.end()) {
2484 return emitError() << "expected all kernels to be uniquely named";
2485 }
2486 return success();
2487}
2488
2489KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
2490 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2491 return found ? *iterator : KernelMetadataAttr();
2492}
2493
2494KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
2495 auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
2496 return found ? *iterator : KernelMetadataAttr();
2497}
2498
2499//===----------------------------------------------------------------------===//
2500// GPU target options
2501//===----------------------------------------------------------------------===//
2502
2503TargetOptions::TargetOptions(
2504 StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
2505 StringRef cmdOptions, StringRef elfSection,
2506 CompilationTarget compilationTarget,
2507 function_ref<SymbolTable *()> getSymbolTableCallback,
2508 function_ref<void(llvm::Module &)> initialLlvmIRCallback,
2509 function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2510 function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2511 function_ref<void(StringRef)> isaCallback)
2512 : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
2513 cmdOptions, elfSection, compilationTarget,
2514 getSymbolTableCallback, initialLlvmIRCallback,
2515 linkedLlvmIRCallback, optimizedLlvmIRCallback,
2516 isaCallback) {}
2517
2518TargetOptions::TargetOptions(
2519 TypeID typeID, StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
2520 StringRef cmdOptions, StringRef elfSection,
2521 CompilationTarget compilationTarget,
2522 function_ref<SymbolTable *()> getSymbolTableCallback,
2523 function_ref<void(llvm::Module &)> initialLlvmIRCallback,
2524 function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2525 function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2526 function_ref<void(StringRef)> isaCallback)
2527 : toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
2528 cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
2529 compilationTarget(compilationTarget),
2530 getSymbolTableCallback(getSymbolTableCallback),
2531 initialLlvmIRCallback(initialLlvmIRCallback),
2532 linkedLlvmIRCallback(linkedLlvmIRCallback),
2533 optimizedLlvmIRCallback(optimizedLlvmIRCallback),
2534 isaCallback(isaCallback), typeID(typeID) {}
2535
2536TypeID TargetOptions::getTypeID() const { return typeID; }
2537
2538StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2539
2540ArrayRef<Attribute> TargetOptions::getLibrariesToLink() const {
2541 return librariesToLink;
2542}
2543
2544StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2545
2546StringRef TargetOptions::getELFSection() const { return elfSection; }
2547
2548SymbolTable *TargetOptions::getSymbolTable() const {
2549 return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
2550}
2551
2552function_ref<void(llvm::Module &)>
2553TargetOptions::getInitialLlvmIRCallback() const {
2554 return initialLlvmIRCallback;
2555}
2556
2557function_ref<void(llvm::Module &)>
2558TargetOptions::getLinkedLlvmIRCallback() const {
2559 return linkedLlvmIRCallback;
2560}
2561
2562function_ref<void(llvm::Module &)>
2563TargetOptions::getOptimizedLlvmIRCallback() const {
2564 return optimizedLlvmIRCallback;
2565}
2566
2567function_ref<void(StringRef)> TargetOptions::getISACallback() const {
2568 return isaCallback;
2569}
2570
2571CompilationTarget TargetOptions::getCompilationTarget() const {
2572 return compilationTarget;
2573}
2574
2575CompilationTarget TargetOptions::getDefaultCompilationTarget() {
2576 return CompilationTarget::Fatbin;
2577}
2578
2579std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2580TargetOptions::tokenizeCmdOptions(const std::string &cmdOptions) {
2581 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2582 llvm::StringSaver stringSaver(options.first);
2583 StringRef opts = cmdOptions;
2584 // For a correct tokenization of the command line options `opts` must be
2585 // unquoted, otherwise the tokenization function returns a single string: the
2586 // unquoted `cmdOptions`, which is not the desired behavior.
2587 // Remove any quotes if they are at the beginning and end of the string:
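// For example, a quoted value such as "-O3 -g" (quotes included) becomes the
// two tokens -O3 and -g once the surrounding quotes are dropped (illustrative
// flags, not specific to any particular GPU toolchain).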
2588 if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2589 opts.consume_front("\""), opts.consume_back("\"");
2590 if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
2591 opts.consume_front("'"), opts.consume_back("'");
2592#ifdef _WIN32
2593 llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
2594 /*MarkEOLs=*/false);
2595#else
2596 llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
2597 /*MarkEOLs=*/false);
2598#endif // _WIN32
2599 return options;
2600}
2601
2602std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2603TargetOptions::tokenizeCmdOptions() const {
2604 return tokenizeCmdOptions(cmdOptions);
2605}
2606
2607std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2608TargetOptions::tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith) {
2609 size_t startPos = cmdOptions.find(startsWith);
2610 if (startPos == std::string::npos)
2611 return {llvm::BumpPtrAllocator(), SmallVector<const char *>()};
2612
2613 auto tokenized =
2614 tokenizeCmdOptions(cmdOptions.substr(startPos + startsWith.size()));
2615 cmdOptions.resize(startPos);
2616 return tokenized;
2617}
2618
2619MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::gpu::TargetOptions)
2620
2621#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2622#include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2623
2624#define GET_ATTRDEF_CLASSES
2625#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2626
2627#define GET_OP_CLASSES
2628#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2629
2630#include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
2631
