1//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the GPU kernel-related dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/GPU/IR/GPUDialect.h"
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
17#include "mlir/Dialect/Math/IR/Math.h"
18#include "mlir/Dialect/MemRef/IR/MemRef.h"
19#include "mlir/IR/Attributes.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/IR/BuiltinAttributes.h"
22#include "mlir/IR/BuiltinOps.h"
23#include "mlir/IR/BuiltinTypes.h"
24#include "mlir/IR/Diagnostics.h"
25#include "mlir/IR/DialectImplementation.h"
26#include "mlir/IR/Matchers.h"
27#include "mlir/IR/OpImplementation.h"
28#include "mlir/IR/PatternMatch.h"
29#include "mlir/IR/SymbolTable.h"
30#include "mlir/IR/TypeUtilities.h"
31#include "mlir/Interfaces/FunctionImplementation.h"
32#include "mlir/Interfaces/SideEffectInterfaces.h"
33#include "mlir/Interfaces/ValueBoundsOpInterface.h"
34#include "mlir/Transforms/InliningUtils.h"
35#include "llvm/ADT/STLExtras.h"
36#include "llvm/ADT/TypeSwitch.h"
37#include "llvm/Support/CommandLine.h"
38#include "llvm/Support/ErrorHandling.h"
39#include "llvm/Support/FormatVariadic.h"
40#include "llvm/Support/InterleavedRange.h"
41#include "llvm/Support/StringSaver.h"
42#include <cassert>
43#include <numeric>
44
45using namespace mlir;
46using namespace mlir::gpu;
47
48#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
49
50//===----------------------------------------------------------------------===//
51// GPU Device Mapping Attributes
52//===----------------------------------------------------------------------===//
53
54int64_t GPUBlockMappingAttr::getMappingId() const {
55 return static_cast<int64_t>(getBlock());
56}
57
58bool GPUBlockMappingAttr::isLinearMapping() const {
59 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
60}
61
62int64_t GPUBlockMappingAttr::getRelativeIndex() const {
63 return isLinearMapping()
64 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
65 : getMappingId();
66}
67
68int64_t GPUWarpgroupMappingAttr::getMappingId() const {
69 return static_cast<int64_t>(getWarpgroup());
70}
71
72bool GPUWarpgroupMappingAttr::isLinearMapping() const {
73 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
74}
75
76int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
77 return isLinearMapping()
78 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
79 : getMappingId();
80}
81
82int64_t GPUWarpMappingAttr::getMappingId() const {
83 return static_cast<int64_t>(getWarp());
84}
85
86bool GPUWarpMappingAttr::isLinearMapping() const {
87 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
88}
89
90int64_t GPUWarpMappingAttr::getRelativeIndex() const {
91 return isLinearMapping()
92 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
93 : getMappingId();
94}
95
96int64_t GPUThreadMappingAttr::getMappingId() const {
97 return static_cast<int64_t>(getThread());
98}
99
100bool GPUThreadMappingAttr::isLinearMapping() const {
101 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
102}
103
104int64_t GPUThreadMappingAttr::getRelativeIndex() const {
105 return isLinearMapping()
106 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
107 : getMappingId();
108}
109
110int64_t GPULaneMappingAttr::getMappingId() const {
111 return static_cast<int64_t>(getLane());
112}
113
114bool GPULaneMappingAttr::isLinearMapping() const {
115 return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
116}
117
118int64_t GPULaneMappingAttr::getRelativeIndex() const {
119 return isLinearMapping()
120 ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
121 : getMappingId();
122}
123
/// The mask is stored as an i64 bitset, so at most 64 physical ids fit.
int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds() const { return 64; }
125
126/// 8 4 0
127/// Example mask : 0 0 0 1 1 0 1 0 0
128///
129/// Active physical (resp. logical) is 2 (0), 4 (1) and 5 (2).
130/// Logical id for e.g. 5 (2) constructs filter (1 << 5 - 1).
131///
132/// Example mask : 0 0 0 1 1 0 1 0 0
133/// Example filter: 0 0 0 0 1 1 1 1 1
134/// Intersection : 0 0 0 0 1 0 1 0 0
135/// PopCnt : 2
136Value GPUMappingMaskAttr::createLogicalLinearMappingId(
137 OpBuilder &b, Value physicalLinearMappingId) const {
138 Location loc = physicalLinearMappingId.getLoc();
139 Value mask = b.create<arith::ConstantOp>(location: loc, args: b.getI64IntegerAttr(value: getMask()));
140 Value one = b.create<arith::ConstantOp>(location: loc, args: b.getI64IntegerAttr(value: 1));
141 Value filter = b.create<arith::ShLIOp>(location: loc, args&: one, args&: physicalLinearMappingId);
142 filter = b.create<arith::SubIOp>(location: loc, args&: filter, args&: one);
143 Value filteredId = b.create<arith::AndIOp>(location: loc, args&: mask, args&: filter);
144 return b.create<math::CtPopOp>(location: loc, args&: filteredId);
145}
146
147/// 8 4 0
148/// Example mask : 0 0 0 1 1 0 1 0 0
149///
150/// Active physical (resp. logical) is 2 (0), 4 (1) and 5 (2).
151/// Logical id for e.g. 5 (2) constructs filter (1 << 5).
152///
153/// Example mask : 0 0 0 1 1 0 1 0 0
154/// Example filter: 0 0 0 1 0 0 0 0 0
155/// Intersection : 0 0 0 1 0 0 0 0 0
156/// Cmp : 1
157Value GPUMappingMaskAttr::createIsActiveIdPredicate(
158 OpBuilder &b, Value physicalLinearMappingId) const {
159 Location loc = physicalLinearMappingId.getLoc();
160 Value mask = b.create<arith::ConstantOp>(location: loc, args: b.getI64IntegerAttr(value: getMask()));
161 Value one = b.create<arith::ConstantOp>(location: loc, args: b.getI64IntegerAttr(value: 1));
162 Value filter = b.create<arith::ShLIOp>(location: loc, args&: one, args&: physicalLinearMappingId);
163 Value filtered = b.create<arith::AndIOp>(location: loc, args&: mask, args&: filter);
164 Value zero = b.create<arith::ConstantOp>(location: loc, args: b.getI64IntegerAttr(value: 0));
165 return b.create<arith::CmpIOp>(location: loc, args: arith::CmpIPredicate::ne, args&: filtered, args&: zero);
166}
167
168int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
169 return static_cast<int64_t>(getAddressSpace());
170}
171
172bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
173 llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
174}
175
176int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
177 llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
178}
179
180//===----------------------------------------------------------------------===//
181// MMAMatrixType
182//===----------------------------------------------------------------------===//
183
184MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
185 StringRef operand) {
186 return Base::get(ctx: elementType.getContext(), args&: shape, args&: elementType, args&: operand);
187}
188
189MMAMatrixType
190MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
191 ArrayRef<int64_t> shape, Type elementType,
192 StringRef operand) {
193 return Base::getChecked(emitErrorFn: emitError, ctx: elementType.getContext(), args: shape,
194 args: elementType, args: operand);
195}
196
/// Returns the number of dimensions of this matrix type.
unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }

/// Returns the shape of this matrix type.
ArrayRef<int64_t> MMAMatrixType::getShape() const {
  return getImpl()->getShape();
}

/// Returns the element type of this matrix type.
Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }

/// Returns the operand tag ("AOp", "BOp" or "COp" — see verifyInvariants).
StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
206
207bool MMAMatrixType::isValidElementType(Type elementType) {
208 return elementType.isF16() || elementType.isF32() ||
209 elementType.isUnsignedInteger(width: 8) || elementType.isSignedInteger(width: 8) ||
210 elementType.isInteger(width: 32);
211}
212
213LogicalResult
214MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
215 ArrayRef<int64_t> shape, Type elementType,
216 StringRef operand) {
217 if (operand != "AOp" && operand != "BOp" && operand != "COp")
218 return emitError() << "operand expected to be one of AOp, BOp or COp";
219
220 if (shape.size() != 2)
221 return emitError() << "MMAMatrixType must have exactly two dimensions";
222
223 if (!MMAMatrixType::isValidElementType(elementType))
224 return emitError()
225 << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";
226
227 return success();
228}
229
230//===----------------------------------------------------------------------===//
231// GPUDialect
232//===----------------------------------------------------------------------===//
233
234bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
235 if (!memorySpace)
236 return false;
237 if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(Val&: memorySpace))
238 return gpuAttr.getValue() == getWorkgroupAddressSpace();
239 return false;
240}
241
242bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
243 Attribute memorySpace = type.getMemorySpace();
244 return isWorkgroupMemoryAddressSpace(memorySpace);
245}
246
247bool GPUDialect::isKernel(Operation *op) {
248 UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(name: getKernelFuncAttrName());
249 return static_cast<bool>(isKernelAttr);
250}
251
namespace {
/// This class defines the interface for handling inlining with gpu
/// operations.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All gpu dialect ops can be inlined; the parameters are ignored because
  /// the answer is unconditionally true.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace
264
void GPUDialect::initialize() {
  // Register the dialect's hand-written types.
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  // Register the ODS-generated operations and attributes.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
  // Interface implementations registered lazily from other libraries.
  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
                           TerminatorOp>();
  declarePromisedInterfaces<
      ValueBoundsOpInterface, ClusterDimOp, ClusterDimBlocksOp, ClusterIdOp,
      ClusterBlockIdOp, BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp, LaneIdOp,
      SubgroupIdOp, GlobalIdOp, NumSubgroupsOp, SubgroupSizeOp, LaunchOp>();
}
287
288static std::string getSparseHandleKeyword(SparseHandleKind kind) {
289 switch (kind) {
290 case SparseHandleKind::DnTensor:
291 return "sparse.dntensor_handle";
292 case SparseHandleKind::SpMat:
293 return "sparse.spmat_handle";
294 case SparseHandleKind::SpGEMMOp:
295 return "sparse.spgemmop_handle";
296 }
297 llvm_unreachable("unknown sparse handle kind");
298 return "";
299}
300
/// Parses a type registered to this dialect: `async.token`,
/// `mma_matrix<AxBxelt, "operand">`, or one of the sparse handle keywords.
/// Returns a null type (after emitting an error) on failure.
Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(keyword: &keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(ctx: context);

  if (keyword == "mma_matrix") {
    // Remember where the type started so verification diagnostics point at
    // the whole type, not at the token where parsing stopped.
    SMLoc beginLoc = parser.getNameLoc();

    // Parse '<'.
    if (parser.parseLess())
      return nullptr;

    // Parse the size and elementType.
    SmallVector<int64_t> shape;
    Type elementType;
    if (parser.parseDimensionList(dimensions&: shape, /*allowDynamic=*/false) ||
        parser.parseType(result&: elementType))
      return nullptr;

    // Parse ','
    if (parser.parseComma())
      return nullptr;

    // Parse operand.
    std::string operand;
    if (failed(Result: parser.parseOptionalString(string: &operand)))
      return nullptr;

    // Parse '>'.
    if (parser.parseGreater())
      return nullptr;

    // `getChecked` runs MMAMatrixType::verifyInvariants and reports failures
    // at the recorded location.
    return MMAMatrixType::getChecked(emitError: mlir::detail::getDefaultDiagnosticEmitFn(
                                         loc: parser.getEncodedSourceLoc(loc: beginLoc)),
                                     shape, elementType, operand);
  }

  if (keyword == getSparseHandleKeyword(kind: SparseHandleKind::DnTensor))
    return SparseDnTensorHandleType::get(ctx: context);
  if (keyword == getSparseHandleKeyword(kind: SparseHandleKind::SpMat))
    return SparseSpMatHandleType::get(ctx: context);
  if (keyword == getSparseHandleKeyword(kind: SparseHandleKind::SpGEMMOp))
    return SparseSpGEMMOpHandleType::get(ctx: context);

  parser.emitError(loc: parser.getNameLoc(), message: "unknown gpu type: " + keyword);
  return Type();
}
// TODO: print refined type here. Notice that should be corresponding to the
// parser
/// Prints a type of this dialect; keywords must stay in sync with parseType.
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>(caseFn: [&](Type) { os << "async.token"; })
      .Case<SparseDnTensorHandleType>(caseFn: [&](Type) {
        os << getSparseHandleKeyword(kind: SparseHandleKind::DnTensor);
      })
      .Case<SparseSpMatHandleType>(
          caseFn: [&](Type) { os << getSparseHandleKeyword(kind: SparseHandleKind::SpMat); })
      .Case<SparseSpGEMMOpHandleType>(caseFn: [&](Type) {
        os << getSparseHandleKeyword(kind: SparseHandleKind::SpGEMMOp);
      })
      .Case<MMAMatrixType>(caseFn: [&](MMAMatrixType fragTy) {
        // Printed as e.g. mma_matrix<16x16xf16, "AOp">.
        os << "mma_matrix<";
        auto shape = fragTy.getShape();
        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          os << *dim << 'x';
        os << shape.back() << 'x' << fragTy.getElementType();
        os << ", \"" << fragTy.getOperand() << "\"" << '>';
      })
      .Default(defaultFn: [](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
}
377
378static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
379 NamedAttribute attr) {
380 auto array = dyn_cast<DenseI32ArrayAttr>(Val: attr.getValue());
381 if (!array)
382 return op->emitOpError(message: Twine(attr.getName()) +
383 " must be a dense i32 array");
384 if (array.size() != 3)
385 return op->emitOpError(message: Twine(attr.getName()) +
386 " must contain exactly 3 elements");
387 return success();
388}
389
/// Verifies dialect attributes attached to arbitrary operations. Launch-size
/// attributes must be 3-element dense i32 arrays. For the container-module
/// unit attribute, every gpu.launch_func directly nested under a function of
/// the module must reference a well-formed kernel.
LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownGridSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  // Everything below only concerns the container-module unit attribute.
  if (!llvm::isa<UnitAttr>(Val: attr.getValue()) ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(Val: op);
  if (!module)
    return op->emitError(message: "expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk(callback: [&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deep than functions in the
    // module we are currently checking.
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            name: LaunchFuncOp::getKernelAttrName(name: launchOp->getName())))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel container.
    StringAttr kernelContainerName = launchOp.getKernelModuleName();
    Operation *kernelContainer = module.lookupSymbol(name: kernelContainerName);
    if (!kernelContainer)
      return launchOp.emitOpError()
             << "kernel container '" << kernelContainerName.getValue()
             << "' is undefined";

    // If the container is a GPU binary op return success.
    if (isa<BinaryOp>(Val: kernelContainer))
      return success();

    auto kernelModule = dyn_cast<GPUModuleOp>(Val: kernelContainer);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelContainerName.getValue()
             << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(symbol: launchOp.getKernelAttr());
    if (!kernelFunc)
      return launchOp.emitOpError(message: "kernel function '")
             << launchOp.getKernel() << "' is undefined";
    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(Val: kernelFunc);
    if (!kernelConvertedFunction) {
      InFlightDiagnostic diag = launchOp.emitOpError()
                                << "referenced kernel '" << launchOp.getKernel()
                                << "' is not a function";
      diag.attachNote(noteLoc: kernelFunc->getLoc()) << "see the kernel definition here";
      return diag;
    }

    // The referenced function must be marked as a kernel.
    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            name: GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError(message: "kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // TODO: If the kernel isn't a GPU function (which happens during separate
    // compilation), do not check type correspondence as it would require the
    // verifier to be aware of the type conversion.
    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(Val: kernelFunc);
    if (!kernelGPUFunction)
      return success();

    // Operand count and types must match the kernel signature.
    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError(message: "got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getFunctionType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError(message: "type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  // Any diagnostic emitted in the walk interrupts it; map that to failure.
  return walkResult.wasInterrupted() ? failure() : success();
}
483
/// Parses an optional list of async operands with an optional leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
///
/// This method is used by the tablegen assembly format for async ops as well.
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(Result: parser.parseOptionalKeyword(keyword: "async"))) {
    // An async op produces a token; it must be bound to a result name.
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, message: "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  // The dependency list itself is optional (`[...]`).
  return parser.parseOperandList(result&: asyncDependencies,
                                 delimiter: OpAsmParser::Delimiter::OptionalSquare);
}
500
/// Prints optional async dependencies with its leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
// Used by the tablegen assembly format for several async ops.
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async";
  if (asyncDependencies.empty())
    return;
  // Separate the keyword from the dependency list with a space.
  if (asyncTokenType)
    printer << ' ';
  printer << llvm::interleaved_array(R: asyncDependencies);
}
515
// GPU Memory attributions functions shared by LaunchOp and GPUFuncOp.
/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(Result: parser.parseOptionalKeyword(keyword)))
    return success();

  return parser.parseArgumentList(result&: args, delimiter: OpAsmParser::Delimiter::Paren,
                                  /*allowType=*/true);
}
534
/// Prints a GPU function memory attribution as
/// ` <keyword>(<arg> : <type>, ...)`; prints nothing when `values` is empty.
static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values) {
  if (values.empty())
    return;

  // Each attribution is printed as "<ssa-value> : <type>".
  auto printBlockArg = [](BlockArgument v) {
    return llvm::formatv(Fmt: "{} : {}", Vals&: v, Vals: v.getType());
  };
  p << ' ' << keyword << '('
    << llvm::interleaved(R: llvm::map_range(C&: values, F: printBlockArg)) << ')';
}
547
548/// Verifies a GPU function memory attribution.
549static LogicalResult verifyAttributions(Operation *op,
550 ArrayRef<BlockArgument> attributions,
551 gpu::AddressSpace memorySpace) {
552 for (Value v : attributions) {
553 auto type = llvm::dyn_cast<MemRefType>(Val: v.getType());
554 if (!type)
555 return op->emitOpError() << "expected memref type in attribution";
556
557 // We can only verify the address space if it hasn't already been lowered
558 // from the AddressSpaceAttr to a target-specific numeric value.
559 auto addressSpace =
560 llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(Val: type.getMemorySpace());
561 if (!addressSpace)
562 continue;
563 if (addressSpace.getValue() != memorySpace)
564 return op->emitOpError()
565 << "expected memory space " << stringifyAddressSpace(memorySpace)
566 << " in attribution";
567 }
568 return success();
569}
570
571//===----------------------------------------------------------------------===//
572// AllReduceOp
573//===----------------------------------------------------------------------===//
574
575static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
576 Type resType) {
577 using Kind = gpu::AllReduceOperation;
578 if (llvm::is_contained(
579 Set: {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
580 Element: opName)) {
581 if (!isa<FloatType>(Val: resType))
582 return failure();
583 }
584
585 if (llvm::is_contained(Set: {Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
586 Kind::AND, Kind::OR, Kind::XOR},
587 Element: opName)) {
588 if (!isa<IntegerType>(Val: resType))
589 return failure();
590 }
591
592 return success();
593}
594
LogicalResult gpu::AllReduceOp::verifyRegions() {
  // Exactly one of the `op` attribute and a non-empty reduction region must
  // be present.
  if (getBody().empty() != getOp().has_value())
    return emitError(message: "expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    // Custom reduction region: two arguments of the result type...
    if (getBody().getNumArguments() != 2)
      return emitError(message: "expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError(message: "incorrect region argument type");
    }
    // ...and at least one gpu.yield terminator returning the result type.
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(Val: block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError(message: "expected one gpu.yield operand");
        if (yield.getOperand(i: 0).getType() != getType())
          return emitError(message: "incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError(message: "expected gpu.yield op in region");
  } else {
    // Named reduction: the kind must be compatible with the result type.
    gpu::AllReduceOperation opName = *getOp();
    if (failed(Result: verifyReduceOpAndType(opName, resType: getType()))) {
      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                         << "` reduction operation is not compatible with type "
                         << getType();
    }
  }

  return success();
}
628
629static bool canMakeGroupOpUniform(Operation *op) {
630 auto launchOp = dyn_cast<gpu::LaunchOp>(Val: op->getParentOp());
631 if (!launchOp)
632 return false;
633
634 Region &body = launchOp.getBody();
635 assert(!body.empty() && "Invalid region");
636
637 // Only convert ops in gpu::launch entry block for now.
638 return op->getBlock() == &body.front();
639}
640
641OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
642 if (!getUniform() && canMakeGroupOpUniform(op: *this)) {
643 setUniform(true);
644 return getResult();
645 }
646
647 return nullptr;
648}
649
// TODO: Support optional custom attributes (without dialect prefix).
/// Parses an optional all-reduce operation kind into `attr`. Absence of the
/// keyword is not an error; `attr` is simply left unset.
static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {
  StringRef enumStr;
  // Note: parseOptionalKeyword signals success by returning a falsy
  // ParseResult, hence the negation.
  if (!parser.parseOptionalKeyword(keyword: &enumStr)) {
    std::optional<AllReduceOperation> op =
        gpu::symbolizeAllReduceOperation(enumStr);
    if (!op)
      return parser.emitError(loc: parser.getCurrentLocation(), message: "invalid op kind");
    attr = AllReduceOperationAttr::get(context: parser.getContext(), value: *op);
  }
  return success();
}
663
664static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
665 AllReduceOperationAttr attr) {
666 if (attr)
667 attr.print(odsPrinter&: printer);
668}
669
670//===----------------------------------------------------------------------===//
671// SubgroupReduceOp
672//===----------------------------------------------------------------------===//
673
LogicalResult gpu::SubgroupReduceOp::verify() {
  Type elemType = getType();
  // Vector results are reduced element-wise: validate the element type.
  if (auto vecTy = dyn_cast<VectorType>(Val&: elemType)) {
    if (vecTy.isScalable())
      return emitOpError() << "is not compatible with scalable vector types";

    elemType = vecTy.getElementType();
  }

  // The reduction kind must be applicable to the (element) type.
  gpu::AllReduceOperation opName = getOp();
  if (failed(Result: verifyReduceOpAndType(opName, resType: elemType))) {
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` reduction operation is not compatible with type "
                       << getType();
  }

  // An explicit cluster size must be a power of two.
  auto clusterSize = getClusterSize();
  if (clusterSize) {
    uint32_t size = *clusterSize;
    if (!llvm::isPowerOf2_32(Value: size)) {
      return emitOpError() << "cluster size " << size
                           << " is not a power of two";
    }
  }

  // A stride other than the default 1 only makes sense with an explicit
  // cluster size, and must itself be a power of two.
  uint32_t stride = getClusterStride();
  if (stride != 1 && !clusterSize) {
    return emitOpError() << "cluster stride can only be specified if cluster "
                            "size is specified";
  }
  if (!llvm::isPowerOf2_32(Value: stride)) {
    return emitOpError() << "cluster stride " << stride
                         << " is not a power of two";
  }

  return success();
}
711
712OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
713 if (getClusterSize() == 1)
714 return getValue();
715
716 if (!getUniform() && canMakeGroupOpUniform(op: *this)) {
717 setUniform(true);
718 return getResult();
719 }
720
721 return nullptr;
722}
723
724//===----------------------------------------------------------------------===//
725// AsyncOpInterface
726//===----------------------------------------------------------------------===//
727
/// Prepends `token` to `op`'s operands and, when the op uses attribute-sized
/// operand segments, bumps the leading (async dependency) segment size.
void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(index: 0, operands: {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(name: attrName);

  // Async dependencies is the only variadic operand.
  if (!sizeAttr)
    return;

  // Account for the newly inserted token in the first segment.
  SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
  ++sizes.front();
  op->setAttr(name: attrName, value: Builder(op->getContext()).getDenseI32ArrayAttr(values: sizes));
}
744
745//===----------------------------------------------------------------------===//
746// LaunchOp
747//===----------------------------------------------------------------------===//
748
void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value getBlockSizeX, Value getBlockSizeY,
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     Type asyncTokenType, ValueRange asyncDependencies,
                     TypeRange workgroupAttributions,
                     TypeRange privateAttributions, Value clusterSizeX,
                     Value clusterSizeY, Value clusterSizeZ) {
  OpBuilder::InsertionGuard g(builder);

  // Add a WorkGroup attribution attribute. This attribute is required to
  // identify private attributions in the list of block arguments.
  result.addAttribute(name: getNumWorkgroupAttributionsAttrName(),
                      attr: builder.getI64IntegerAttr(value: workgroupAttributions.size()));

  // Add Op operands.
  result.addOperands(newOperands: asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(Elt: builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands(newOperands: {gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                    getBlockSizeY, getBlockSizeZ});
  if (clusterSizeX)
    result.addOperands(newOperands: clusterSizeX);
  if (clusterSizeY)
    result.addOperands(newOperands: clusterSizeY);
  if (clusterSizeZ)
    result.addOperands(newOperands: clusterSizeZ);
  if (dynamicSharedMemorySize)
    result.addOperands(newOperands: dynamicSharedMemorySize);

  // Create a kernel body region with kNumConfigRegionAttributes + N memory
  // attributions, where the first kNumConfigRegionAttributes arguments have
  // `index` type and the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = builder.createBlock(parent: kernelRegion);
  // TODO: Allow passing in proper locations here.
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(type: builder.getIndexType(), loc: result.location);
  // Add WorkGroup & Private attributions to the region arguments.
  for (Type argTy : workgroupAttributions)
    body->addArgument(type: argTy, loc: result.location);
  for (Type argTy : privateAttributions)
    body->addArgument(type: argTy, loc: result.location);
  // Fill OperandSegmentSize Attribute. The 11 segments match the operand
  // order above: [0] async deps, [1-3] grid sizes, [4-6] block sizes,
  // [7-9] optional cluster sizes, [10] optional dynamic shared memory size.
  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes[7] = clusterSizeX ? 1 : 0;
  segmentSizes[8] = clusterSizeY ? 1 : 0;
  segmentSizes[9] = clusterSizeZ ? 1 : 0;
  result.addAttribute(name: getOperandSegmentSizeAttr(),
                      attr: builder.getDenseI32ArrayAttr(values: segmentSizes));
}
804
805KernelDim3 LaunchOp::getBlockIds() {
806 assert(!getBody().empty() && "LaunchOp body must not be empty.");
807 auto args = getBody().getArguments();
808 return KernelDim3{.x: args[0], .y: args[1], .z: args[2]};
809}
810
811KernelDim3 LaunchOp::getThreadIds() {
812 assert(!getBody().empty() && "LaunchOp body must not be empty.");
813 auto args = getBody().getArguments();
814 return KernelDim3{.x: args[3], .y: args[4], .z: args[5]};
815}
816
817KernelDim3 LaunchOp::getGridSize() {
818 assert(!getBody().empty() && "LaunchOp body must not be empty.");
819 auto args = getBody().getArguments();
820 return KernelDim3{.x: args[6], .y: args[7], .z: args[8]};
821}
822
823KernelDim3 LaunchOp::getBlockSize() {
824 assert(!getBody().empty() && "LaunchOp body must not be empty.");
825 auto args = getBody().getArguments();
826 return KernelDim3{.x: args[9], .y: args[10], .z: args[11]};
827}
828
829std::optional<KernelDim3> LaunchOp::getClusterIds() {
830 assert(!getBody().empty() && "LaunchOp body must not be empty.");
831 if (!hasClusterSize())
832 return std::nullopt;
833 auto args = getBody().getArguments();
834 return KernelDim3{.x: args[12], .y: args[13], .z: args[14]};
835}
836
837std::optional<KernelDim3> LaunchOp::getClusterSize() {
838 assert(!getBody().empty() && "LaunchOp body must not be empty.");
839 if (!hasClusterSize())
840 return std::nullopt;
841 auto args = getBody().getArguments();
842 return KernelDim3{.x: args[15], .y: args[16], .z: args[17]};
843}
844
845KernelDim3 LaunchOp::getGridSizeOperandValues() {
846 auto operands = getOperands().drop_front(n: getAsyncDependencies().size());
847 return KernelDim3{.x: operands[0], .y: operands[1], .z: operands[2]};
848}
849
850KernelDim3 LaunchOp::getBlockSizeOperandValues() {
851 auto operands = getOperands().drop_front(n: getAsyncDependencies().size());
852 return KernelDim3{.x: operands[3], .y: operands[4], .z: operands[5]};
853}
854
855std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
856 auto operands = getOperands().drop_front(n: getAsyncDependencies().size());
857 if (!hasClusterSize())
858 return std::nullopt;
859 return KernelDim3{.x: operands[6], .y: operands[7], .z: operands[8]};
860}
861
862LogicalResult LaunchOp::verify() {
863 if (!(hasClusterSize()) &&
864 (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
865 return emitOpError() << "cluster size must be all present";
866 return success();
867}
868
LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!getBody().empty()) {
    // Attribution arguments (workgroup, then private) follow the config
    // arguments, so the total must be at least config + workgroup count.
    if (getBody().getNumArguments() <
        kNumConfigRegionAttributes + getNumWorkgroupAttributions())
      return emitOpError("unexpected number of region arguments");
  }

  // Verify Attributions Address Spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel region
  // and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      // Emit the error on the offending terminator and attach a note pointing
      // back at the launch op that owns the region.
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  // An async launch must produce its token as a result.
  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}
908
909// Pretty-print the kernel grid/block size assignment as
910// (%iter-x, %iter-y, %iter-z) in
911// (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
912// where %size-* and %iter-* will correspond to the body region arguments.
913static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
914 KernelDim3 operands, KernelDim3 ids) {
915 p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
916 p << size.x << " = " << operands.x << ", ";
917 p << size.y << " = " << operands.y << ", ";
918 p << size.z << " = " << operands.z << ')';
919}
920
/// Prints a gpu.launch op: optional `async [deps]`, the optional cluster
/// configuration, the mandatory blocks/threads configurations, the optional
/// dynamic shared memory size, memory attributions, the body region, and
/// finally the attribute dictionary.
void LaunchOp::print(OpAsmPrinter &p) {
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  if (hasClusterSize()) {
    p << ' ' << getClustersKeyword();
    printSizeAssignment(p, getClusterSize().value(),
                        getClusterSizeOperandValues().value(),
                        getClusterIds().value());
  }
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  // Workgroup and private memory attributions, if any.
  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());

  p << ' ';

  // Entry block arguments were already printed by the size assignments above.
  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  // Segment sizes and the attribution count are reconstructed by the parser,
  // so they are elided from the printed form.
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr(),
                              getNumWorkgroupAttributionsAttrName()});
}
954
955// Parse the size assignment blocks for blocks and threads. These have the form
956// (%region_arg, %region_arg, %region_arg) in
957// (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
958// where %region_arg are percent-identifiers for the region arguments to be
959// introduced further (SSA defs), and %operand are percent-identifiers for the
960// SSA value uses.
// Parse the size assignment blocks for blocks and threads. These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  // First parenthesized group: the three iteration-variable region args.
  SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
                              /*allowResultNumber=*/false) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  // Second group: three `%size-region-arg = %operand` assignments.
  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
        parser.parseEqual() || parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}
984
985/// Parses a Launch operation.
986/// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
987/// `clusters` `(` ssa-id-list `)` `in` ssa-reassignment (Optional)
988/// `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
989/// `threads` `(` ssa-id-list `)` `in` ssa-reassignment
990/// memory-attribution
991/// region attr-dict?
992/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
/// Parses a Launch operation.
/// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
///       `clusters` `(` ssa-id-list `)` `in` ssa-reassignment (Optional)
///       `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
///       `threads` `(` ssa-id-list `)` `in` ssa-reassignment
///       memory-attribution
///       region attr-dict?
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
      sizes(LaunchOp::kNumConfigOperands);

  // Region arguments to be created.
  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);

  // Parse optional async dependencies.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
  Type asyncTokenType;
  if (failed(
          parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
      parser.resolveOperands(asyncDependencies, asyncTokenType,
                             result.operands))
    return failure();
  // The async token result is only added when the op is named.
  if (parser.getNumResults() > 0)
    result.types.push_back(asyncTokenType);

  // The optional `clusters` keyword extends both the operand list (3 extra
  // cluster sizes) and the region argument list (3 ids + 3 sizes).
  bool hasCluster = false;
  if (succeeded(
          parser.parseOptionalKeyword(LaunchOp::getClustersKeyword().data()))) {
    hasCluster = true;
    sizes.resize(9);
    regionArgs.resize(18);
  }
  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);

  // Last three segment assigns the cluster size. In the region argument
  // list, this is last 6 arguments.
  if (hasCluster) {
    if (parseSizeAssignment(parser, sizesRef.drop_front(6),
                            regionArgsRef.slice(15, 3),
                            regionArgsRef.slice(12, 3)))
      return failure();
  }
  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  // Optional `dynamic_shared_memory_size %operand` (an i32 operand).
  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
  bool hasDynamicSharedMemorySize = false;
  if (!parser.parseOptionalKeyword(
          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
    hasDynamicSharedMemorySize = true;
    if (parser.parseOperand(dynamicSharedMemorySize) ||
        parser.resolveOperand(dynamicSharedMemorySize,
                              parser.getBuilder().getI32Type(),
                              result.operands))
      return failure();
  }

  // Create the region arguments, it has kNumConfigRegionAttributes arguments
  // that correspond to block/thread identifiers and grid/block sizes, all
  // having `index` type, a variadic number of WorkGroup Attributions and
  // a variadic number of Private Attributions. The number of WorkGroup
  // Attributions is stored in the attr with name:
  // LaunchOp::getNumWorkgroupAttributionsAttrName().
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes + 6, index);

  // Pair each parsed region-argument name with its `index` type; when there
  // is no cluster, the zip stops at the shorter regionArgs list.
  SmallVector<OpAsmParser::Argument> regionArguments;
  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
    OpAsmParser::Argument arg;
    arg.ssaName = std::get<0>(ssaValueAndType);
    arg.type = std::get<1>(ssaValueAndType);
    regionArguments.push_back(arg);
  }

  Builder &builder = parser.getBuilder();
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
                               regionArguments)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = regionArguments.size() -
                               LaunchOp::kNumConfigRegionAttributes -
                               (hasCluster ? 6 : 0);
  result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
                               regionArguments)))
    return failure();

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes arguments that correspond to
  // block/thread identifiers and grid/block sizes, all having `index` type.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArguments) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Fill the operand segment sizes: [async deps, grid x/y/z, block x/y/z,
  // cluster x/y/z, dynamic shared memory size].
  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();

  if (!hasCluster) {
    segmentSizes[7] = 0;
    segmentSizes[8] = 0;
    segmentSizes[9] = 0;
  }
  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
  return success();
}
1118
1119/// Simplify the gpu.launch when the range of a thread or block ID is
1120/// trivially known to be one.
/// Simplify the gpu.launch when the range of a thread or block ID is
/// trivially known to be one.
struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;
  /// Replaces uses of any id region argument whose corresponding size operand
  /// is the constant 1 with a single shared `arith.constant 0 : index`.
  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter &rewriter) const override {
    // If the range implies a single value for `id`, replace `id`'s uses by
    // zero.
    Value zero;
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {
      // Check if size is trivially one.
      if (!matchPattern(size, m_One()))
        return;
      // Nothing to rewrite if the id is unused.
      if (id.getUses().empty())
        return;
      if (!simplified) {
        // Create a zero value the first time.
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(&op.getBody().front());
        zero =
            rewriter.create<arith::ConstantIndexOp>(op.getLoc(), /*value=*/0);
      }
      rewriter.replaceAllUsesWith(id, zero);
      simplified = true;
    };
    // Try each of the six id/size pairs.
    constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
    constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
    constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
    constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
    constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());

    return success(simplified);
  }
};
1155
1156void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
1157 MLIRContext *context) {
1158 rewrites.add<FoldLaunchArguments>(arg&: context);
1159}
1160
1161/// Adds a new block argument that corresponds to buffers located in
1162/// workgroup memory.
1163BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
1164 auto attrName = getNumWorkgroupAttributionsAttrName();
1165 auto attr = (*this)->getAttrOfType<IntegerAttr>(name: attrName);
1166 (*this)->setAttr(name: attrName,
1167 value: IntegerAttr::get(type: attr.getType(), value: attr.getValue() + 1));
1168 return getBody().insertArgument(
1169 index: LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
1170}
1171
1172/// Adds a new block argument that corresponds to buffers located in
1173/// private memory.
1174BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
1175 // Buffers on the private memory always come after buffers on the workgroup
1176 // memory.
1177 return getBody().addArgument(type, loc);
1178}
1179
1180//===----------------------------------------------------------------------===//
1181// LaunchFuncOp
1182//===----------------------------------------------------------------------===//
1183
/// Builds a gpu.launch_func that references the kernel through a nested
/// symbol (@gpu_module::@kernel), with optional async token/dependencies,
/// dynamic shared memory size, and cluster configuration.
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  assert(kernelSymbol.getNestedReferences().size() == 1 &&
         "expected a symbol reference with a single nested reference");
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);

  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernelSymbol;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  llvm::fill(prop.operandSegmentSizes, 1);
  // Then patch the variadic/optional segments: [0] is the async dependencies;
  // the trailing six are cluster x/y/z, dynamic shared memory size, kernel
  // operands, and the async object (always absent in this builder).
  prop.operandSegmentSizes[0] = asyncDependencies.size();
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
}
1222
1223void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
1224 GPUFuncOp kernelFunc, KernelDim3 gridSize,
1225 KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
1226 ValueRange kernelOperands, Type asyncTokenType,
1227 ValueRange asyncDependencies,
1228 std::optional<KernelDim3> clusterSize) {
1229 auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
1230 auto kernelSymbol =
1231 SymbolRefAttr::get(rootReference: kernelModule.getNameAttr(),
1232 nestedReferences: {SymbolRefAttr::get(value: kernelFunc.getNameAttr())});
1233 build(builder, result, kernelSymbol, gridSize, getBlockSize,
1234 dynamicSharedMemorySize, kernelOperands, asyncTokenType,
1235 asyncDependencies, clusterSize);
1236}
1237
/// Builds a gpu.launch_func that carries an opaque async object operand
/// (e.g. a stream) instead of async token dependencies.
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernel, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Value asyncObject,
                         std::optional<KernelDim3> clusterSize) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  if (asyncObject)
    result.addOperands(asyncObject);
  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernel;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  llvm::fill(prop.operandSegmentSizes, 1);
  // No async token dependencies in this form; patch the variadic/optional
  // segments: cluster x/y/z, dynamic shared memory size, kernel operands,
  // and the trailing async object.
  prop.operandSegmentSizes[0] = 0;
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
}
1270
1271StringAttr LaunchFuncOp::getKernelModuleName() {
1272 return getKernel().getRootReference();
1273}
1274
1275StringAttr LaunchFuncOp::getKernelName() {
1276 return getKernel().getLeafReference();
1277}
1278
1279unsigned LaunchFuncOp::getNumKernelOperands() {
1280 return getKernelOperands().size();
1281}
1282
1283Value LaunchFuncOp::getKernelOperand(unsigned i) {
1284 return getKernelOperands()[i];
1285}
1286
1287KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
1288 auto operands = getOperands().drop_front(n: getAsyncDependencies().size());
1289 return KernelDim3{.x: operands[0], .y: operands[1], .z: operands[2]};
1290}
1291
1292KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
1293 auto operands = getOperands().drop_front(n: getAsyncDependencies().size());
1294 return KernelDim3{.x: operands[3], .y: operands[4], .z: operands[5]};
1295}
1296
1297KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
1298 assert(hasClusterSize() &&
1299 "cluster size is not set, check hasClusterSize() first");
1300 auto operands = getOperands().drop_front(n: getAsyncDependencies().size());
1301 return KernelDim3{.x: operands[6], .y: operands[7], .z: operands[8]};
1302}
1303
1304LogicalResult LaunchFuncOp::verify() {
1305 auto module = (*this)->getParentOfType<ModuleOp>();
1306 if (!module)
1307 return emitOpError(message: "expected to belong to a module");
1308
1309 if (!module->getAttrOfType<UnitAttr>(
1310 name: GPUDialect::getContainerModuleAttrName()))
1311 return emitOpError(message: "expected the closest surrounding module to have the '" +
1312 GPUDialect::getContainerModuleAttrName() +
1313 "' attribute");
1314
1315 if (hasClusterSize()) {
1316 if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
1317 getClusterSizeZ().getType() != getClusterSizeX().getType())
1318 return emitOpError()
1319 << "expects types of the cluster dimensions must be the same";
1320 }
1321
1322 return success();
1323}
1324
1325static ParseResult
1326parseLaunchDimType(OpAsmParser &parser, Type &dimTy,
1327 std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
1328 Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
1329 if (succeeded(Result: parser.parseOptionalColon())) {
1330 if (parser.parseType(result&: dimTy))
1331 return failure();
1332 } else {
1333 dimTy = IndexType::get(context: parser.getContext());
1334 }
1335 if (clusterValue.has_value()) {
1336 clusterXTy = clusterYTy = clusterZTy = dimTy;
1337 }
1338 return success();
1339}
1340
1341static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
1342 Value clusterValue, Type clusterXTy,
1343 Type clusterYTy, Type clusterZTy) {
1344 if (!dimTy.isIndex())
1345 printer << ": " << dimTy;
1346}
1347
1348static ParseResult parseLaunchFuncOperands(
1349 OpAsmParser &parser,
1350 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
1351 SmallVectorImpl<Type> &argTypes) {
1352 if (parser.parseOptionalKeyword(keyword: "args"))
1353 return success();
1354
1355 auto parseElement = [&]() -> ParseResult {
1356 return failure(IsFailure: parser.parseOperand(result&: argNames.emplace_back()) ||
1357 parser.parseColonType(result&: argTypes.emplace_back()));
1358 };
1359
1360 return parser.parseCommaSeparatedList(delimiter: OpAsmParser::Delimiter::Paren,
1361 parseElementFn: parseElement, contextMessage: " in argument list");
1362}
1363
1364static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
1365 OperandRange operands, TypeRange types) {
1366 if (operands.empty())
1367 return;
1368 printer << "args(";
1369 llvm::interleaveComma(c: llvm::zip_equal(t&: operands, u&: types), os&: printer,
1370 each_fn: [&](const auto &pair) {
1371 auto [operand, type] = pair;
1372 printer << operand << " : " << type;
1373 });
1374 printer << ")";
1375}
1376
1377//===----------------------------------------------------------------------===//
1378// ShuffleOp
1379//===----------------------------------------------------------------------===//
1380
1381void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
1382 int32_t offset, int32_t width, ShuffleMode mode) {
1383 build(odsBuilder&: builder, odsState&: result, value,
1384 offset: builder.create<arith::ConstantOp>(location: result.location,
1385 args: builder.getI32IntegerAttr(value: offset)),
1386 width: builder.create<arith::ConstantOp>(location: result.location,
1387 args: builder.getI32IntegerAttr(value: width)),
1388 mode);
1389}
1390
1391//===----------------------------------------------------------------------===//
1392// RotateOp
1393//===----------------------------------------------------------------------===//
1394
1395void RotateOp::build(OpBuilder &builder, OperationState &result, Value value,
1396 int32_t offset, int32_t width) {
1397 build(odsBuilder&: builder, odsState&: result, value,
1398 offset: builder.create<arith::ConstantOp>(location: result.location,
1399 args: builder.getI32IntegerAttr(value: offset)),
1400 width: builder.create<arith::ConstantOp>(location: result.location,
1401 args: builder.getI32IntegerAttr(value: width)));
1402}
1403
1404LogicalResult RotateOp::verify() {
1405 auto offsetConstOp = getOffset().getDefiningOp<arith::ConstantOp>();
1406 if (!offsetConstOp)
1407 return emitOpError() << "offset is not a constant value";
1408
1409 auto offsetIntAttr =
1410 llvm::dyn_cast<mlir::IntegerAttr>(Val: offsetConstOp.getValue());
1411
1412 auto widthConstOp = getWidth().getDefiningOp<arith::ConstantOp>();
1413 if (!widthConstOp)
1414 return emitOpError() << "width is not a constant value";
1415
1416 auto widthIntAttr =
1417 llvm::dyn_cast<mlir::IntegerAttr>(Val: widthConstOp.getValue());
1418
1419 llvm::APInt offsetValue = offsetIntAttr.getValue();
1420 llvm::APInt widthValue = widthIntAttr.getValue();
1421
1422 if (!widthValue.isPowerOf2())
1423 return emitOpError() << "width must be a power of two";
1424
1425 if (offsetValue.sge(RHS: widthValue) || offsetValue.slt(RHS: 0)) {
1426 int64_t widthValueInt = widthValue.getSExtValue();
1427 return emitOpError() << "offset must be in the range [0, " << widthValueInt
1428 << ")";
1429 }
1430
1431 return success();
1432}
1433
1434//===----------------------------------------------------------------------===//
1435// BarrierOp
1436//===----------------------------------------------------------------------===//
1437
1438namespace {
1439
1440/// Remove gpu.barrier after gpu.barrier, the threads are already synchronized!
1441LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op,
1442 PatternRewriter &rewriter) {
1443 if (isa_and_nonnull<BarrierOp>(Val: op->getNextNode())) {
1444 rewriter.eraseOp(op);
1445 return success();
1446 }
1447 return failure();
1448}
1449
1450} // end anonymous namespace
1451
1452void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
1453 MLIRContext *context) {
1454 results.add(implFn: eraseRedundantGpuBarrierOps);
1455}
1456
1457//===----------------------------------------------------------------------===//
1458// GPUFuncOp
1459//===----------------------------------------------------------------------===//
1460
1461/// Adds a new block argument that corresponds to buffers located in
1462/// workgroup memory.
1463BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
1464 auto attrName = getNumWorkgroupAttributionsAttrName();
1465 auto attr = (*this)->getAttrOfType<IntegerAttr>(name: attrName);
1466 (*this)->setAttr(name: attrName,
1467 value: IntegerAttr::get(type: attr.getType(), value: attr.getValue() + 1));
1468 return getBody().insertArgument(
1469 index: getFunctionType().getNumInputs() + attr.getInt(), type, loc);
1470}
1471
1472/// Adds a new block argument that corresponds to buffers located in
1473/// private memory.
1474BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
1475 // Buffers on the private memory always come after buffers on the workgroup
1476 // memory.
1477 return getBody().addArgument(type, loc);
1478}
1479
1480void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
1481 StringRef name, FunctionType type,
1482 TypeRange workgroupAttributions,
1483 TypeRange privateAttributions,
1484 ArrayRef<NamedAttribute> attrs) {
1485 OpBuilder::InsertionGuard g(builder);
1486
1487 result.addAttribute(name: SymbolTable::getSymbolAttrName(),
1488 attr: builder.getStringAttr(bytes: name));
1489 result.addAttribute(name: getFunctionTypeAttrName(name: result.name),
1490 attr: TypeAttr::get(type));
1491 result.addAttribute(name: getNumWorkgroupAttributionsAttrName(),
1492 attr: builder.getI64IntegerAttr(value: workgroupAttributions.size()));
1493 result.addAttributes(newAttributes: attrs);
1494 Region *body = result.addRegion();
1495 Block *entryBlock = builder.createBlock(parent: body);
1496
1497 // TODO: Allow passing in proper locations here.
1498 for (Type argTy : type.getInputs())
1499 entryBlock->addArgument(type: argTy, loc: result.location);
1500 for (Type argTy : workgroupAttributions)
1501 entryBlock->addArgument(type: argTy, loc: result.location);
1502 for (Type argTy : privateAttributions)
1503 entryBlock->addArgument(type: argTy, loc: result.location);
1504}
1505
1506/// Parses a GPU function memory attribution.
1507///
1508/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
1509/// (`private` `(` ssa-id-and-type-list `)`)?
1510///
1511/// Note that this function parses only one of the two similar parts, with the
1512/// keyword provided as argument.
/// Parses a GPU function memory attribution.
///
///   memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                          (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument. The parsed arguments are appended to `args`;
/// `attributionAttrs` receives an ArrayAttr of per-argument dictionaries, or
/// null when no argument carried attributes.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args,
                  Attribute &attributionAttrs) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  // Remember where this attribution segment starts within `args`, which may
  // already hold arguments from an earlier segment.
  size_t existingArgs = args.size();
  ParseResult result =
      parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                               /*allowType=*/true, /*allowAttrs=*/true);
  if (failed(result))
    return result;

  // If no argument in this segment had attributes, leave the attr null.
  bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
                               [](const OpAsmParser::Argument &arg) -> bool {
                                 return arg.attrs && !arg.attrs.empty();
                               });
  if (!hadAttrs) {
    attributionAttrs = nullptr;
    return result;
  }

  // Otherwise build one dictionary per argument, using an empty dictionary
  // for arguments without attributes so indices line up.
  Builder &builder = parser.getBuilder();
  SmallVector<Attribute> attributionAttrsVec;
  for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
    if (!argument.attrs)
      attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
    else
      attributionAttrsVec.push_back(argument.attrs);
  }
  attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
  return result;
}
1548
1549/// Parses a GPU function.
1550///
1551/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
1552/// (`->` function-result-list)? memory-attribution `kernel`?
1553/// function-attributes? region
/// Parses a GPU function.
///
/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
///                 (`->` function-result-list)? memory-attribution `kernel`?
///                 function-attributes? region
ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  bool isVariadic;

  // Parse the function name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  // Parse the signature: arguments, results, and their attributes.
  auto signatureLocation = parser.getCurrentLocation();
  if (failed(function_interface_impl::parseFunctionSignatureWithArguments(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs)))
    return failure();

  // Arguments must be named so the region entry block can bind them.
  if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
    return parser.emitError(signatureLocation)
           << "gpu.func requires named arguments";

  // Construct the function type. More types will be added to the region, but
  // not to the function type.
  Builder &builder = parser.getBuilder();

  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto type = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));

  call_interface_impl::addArgAndResultAttrs(
      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));

  Attribute workgroupAttributionAttrs;
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
                               entryArgs, workgroupAttributionAttrs)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));
  if (workgroupAttributionAttrs)
    result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
                        workgroupAttributionAttrs);

  Attribute privateAttributionAttrs;
  // Parse private memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
                               entryArgs, privateAttributionAttrs)))
    return failure();
  if (privateAttributionAttrs)
    result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
                        privateAttributionAttrs);

  // Parse the kernel attribute if present.
  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
    result.addAttribute(GPUDialect::getKernelFuncAttrName(),
                        builder.getUnitAttr());

  // Parse attributes.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();

  // Parse the region. If no argument names were provided, take all names
  // (including those of attributions) from the entry block.
  auto *body = result.addRegion();
  return parser.parseRegion(*body, entryArgs);
}
1629
1630static void printAttributions(OpAsmPrinter &p, StringRef keyword,
1631 ArrayRef<BlockArgument> values,
1632 ArrayAttr attributes) {
1633 if (values.empty())
1634 return;
1635
1636 p << ' ' << keyword << '(';
1637 llvm::interleaveComma(
1638 c: llvm::enumerate(First&: values), os&: p, each_fn: [&p, attributes](auto pair) {
1639 BlockArgument v = pair.value();
1640 p << v << " : " << v.getType();
1641
1642 size_t attributionIndex = pair.index();
1643 DictionaryAttr attrs;
1644 if (attributes && attributionIndex < attributes.size())
1645 attrs = llvm::cast<DictionaryAttr>(Val: attributes[attributionIndex]);
1646 if (attrs)
1647 p.printOptionalAttrDict(attrs: attrs.getValue());
1648 });
1649 p << ')';
1650}
1651
/// Prints the gpu.func op; the output mirrors the grammar accepted by
/// GPUFuncOp::parse.
void GPUFuncOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(symbolRef: getName());

  FunctionType type = getFunctionType();
  function_interface_impl::printFunctionSignature(p, op: *this, argTypes: type.getInputs(),
                                                  /*isVariadic=*/false,
                                                  resultTypes: type.getResults());

  // Memory attributions and the optional `kernel` marker appear between the
  // signature and the attribute dictionary.
  printAttributions(p, keyword: getWorkgroupKeyword(), values: getWorkgroupAttributions(),
                    attributes: getWorkgroupAttribAttrs().value_or(u: nullptr));
  printAttributions(p, keyword: getPrivateKeyword(), values: getPrivateAttributions(),
                    attributes: getPrivateAttribAttrs().value_or(u: nullptr));
  if (isKernel())
    p << ' ' << getKernelKeyword();

  // Elide attributes that are already rendered as dedicated syntax above.
  function_interface_impl::printFunctionAttributes(
      p, op: *this,
      elided: {getNumWorkgroupAttributionsAttrName(),
       GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
       getArgAttrsAttrName(), getResAttrsAttrName(),
       getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
  p << ' ';
  // Entry block args were already printed via the signature/attributions.
  p.printRegion(blocks&: getBody(), /*printEntryBlockArgs=*/false);
}
1677
1678static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
1679 StringAttr attrName) {
1680 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(Val: op->getAttr(name: attrName));
1681 if (!allAttrs || index >= allAttrs.size())
1682 return DictionaryAttr();
1683 return llvm::cast<DictionaryAttr>(Val: allAttrs[index]);
1684}
1685
/// Returns the attribute dictionary of the `index`-th workgroup attribution,
/// or null if none was set. (The lowercase 'w' matches the declaration in
/// the op's interface.)
DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
  return getAttributionAttrs(op: *this, index, attrName: getWorkgroupAttribAttrsAttrName());
}
1689
/// Returns the attribute dictionary of the `index`-th private attribution,
/// or null if none was set.
DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
  return getAttributionAttrs(op: *this, index, attrName: getPrivateAttribAttrsAttrName());
}
1693
1694static void setAttributionAttrs(GPUFuncOp op, unsigned index,
1695 DictionaryAttr value, StringAttr attrName) {
1696 MLIRContext *ctx = op.getContext();
1697 auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(Val: op->getAttr(name: attrName));
1698 SmallVector<Attribute> elements;
1699 if (allAttrs)
1700 elements.append(in_start: allAttrs.begin(), in_end: allAttrs.end());
1701 while (elements.size() <= index)
1702 elements.push_back(Elt: DictionaryAttr::get(context: ctx));
1703 if (!value)
1704 elements[index] = DictionaryAttr::get(context: ctx);
1705 else
1706 elements[index] = value;
1707 ArrayAttr newValue = ArrayAttr::get(context: ctx, value: elements);
1708 op->setAttr(name: attrName, value: newValue);
1709}
1710
/// Sets (or clears, when `value` is null) the attribute dictionary of the
/// `index`-th workgroup attribution.
void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
                                             DictionaryAttr value) {
  setAttributionAttrs(op: *this, index, value, attrName: getWorkgroupAttribAttrsAttrName());
}
1715
/// Sets (or clears, when `value` is null) the attribute dictionary of the
/// `index`-th private attribution.
void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
                                           DictionaryAttr value) {
  setAttributionAttrs(op: *this, index, value, attrName: getPrivateAttribAttrsAttrName());
}
1720
1721static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
1722 StringAttr name, StringAttr attrsName) {
1723 DictionaryAttr dict = getAttributionAttrs(op, index, attrName: attrsName);
1724 if (!dict)
1725 return Attribute();
1726 return dict.get(name);
1727}
1728
/// Returns the attribute named `name` on the `index`-th workgroup
/// attribution, or null when that attribution has no such attribute.
Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
                                                 StringAttr name) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  return getAttributionAttr(op: *this, index, name,
                            attrsName: getWorkgroupAttribAttrsAttrName());
}
1736
/// Returns the attribute named `name` on the `index`-th private attribution,
/// or null when that attribution has no such attribute.
Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
                                               StringAttr name) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  return getAttributionAttr(op: *this, index, name,
                            attrsName: getPrivateAttribAttrsAttrName());
}
1744
/// Sets (or, when `value` is null, removes) the attribute named `name` on
/// attribution `index`, rebuilding that attribution's dictionary in the
/// op-level array attribute `attrsName`.
static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
                               Attribute value, StringAttr attrsName) {
  MLIRContext *ctx = op.getContext();
  SmallVector<NamedAttribute> elems;
  DictionaryAttr oldDict = getAttributionAttrs(op, index, attrName: attrsName);
  if (oldDict)
    elems.append(in_start: oldDict.getValue().begin(), in_end: oldDict.getValue().end());

  bool found = false;
  bool mustSort = true;
  for (unsigned i = 0, e = elems.size(); i < e; ++i) {
    if (elems[i].getName() == name) {
      found = true;
      if (!value) {
        // Removal: swap-with-back breaks the sorted order, so `mustSort`
        // stays true and the list is re-sorted below.
        std::swap(a&: elems[i], b&: elems[elems.size() - 1]);
        elems.pop_back();
      } else {
        // In-place replacement keeps the existing sorted order intact.
        mustSort = false;
        elems[i] = NamedAttribute(elems[i].getName(), value);
      }
      break;
    }
  }
  if (!found) {
    // Removing an attribute that was never present is a no-op.
    if (!value)
      return;
    elems.emplace_back(Args&: name, Args&: value);
  }
  if (mustSort) {
    // DictionaryAttr requires entries sorted by name; restore the invariant.
    DictionaryAttr::sortInPlace(array&: elems);
  }
  auto newDict = DictionaryAttr::getWithSorted(context: ctx, value: elems);
  setAttributionAttrs(op, index, value: newDict, attrName: attrsName);
}
1779
/// Sets (or removes, when `value` is null) the attribute named `name` on
/// the `index`-th workgroup attribution.
void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
                                            Attribute value) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  setAttributionAttr(op: *this, index, name, value,
                     attrsName: getWorkgroupAttribAttrsAttrName());
}
1787
/// Sets (or removes, when `value` is null) the attribute named `name` on
/// the `index`-th private attribution.
void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
                                          Attribute value) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  setAttributionAttr(op: *this, index, name, value,
                     attrsName: getPrivateAttribAttrsAttrName());
}
1795
1796LogicalResult GPUFuncOp::verifyType() {
1797 if (isKernel() && getFunctionType().getNumResults() != 0)
1798 return emitOpError() << "expected void return type for kernel function";
1799
1800 return success();
1801}
1802
1803/// Verifies the body of the function.
1804LogicalResult GPUFuncOp::verifyBody() {
1805 if (empty())
1806 return emitOpError() << "expected body with at least one block";
1807 unsigned numFuncArguments = getNumArguments();
1808 unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
1809 unsigned numBlockArguments = front().getNumArguments();
1810 if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
1811 return emitOpError() << "expected at least "
1812 << numFuncArguments + numWorkgroupAttributions
1813 << " arguments to body region";
1814
1815 ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
1816 for (unsigned i = 0; i < numFuncArguments; ++i) {
1817 Type blockArgType = front().getArgument(i).getType();
1818 if (funcArgTypes[i] != blockArgType)
1819 return emitOpError() << "expected body region argument #" << i
1820 << " to be of type " << funcArgTypes[i] << ", got "
1821 << blockArgType;
1822 }
1823
1824 if (failed(Result: verifyAttributions(op: getOperation(), attributions: getWorkgroupAttributions(),
1825 memorySpace: GPUDialect::getWorkgroupAddressSpace())) ||
1826 failed(Result: verifyAttributions(op: getOperation(), attributions: getPrivateAttributions(),
1827 memorySpace: GPUDialect::getPrivateAddressSpace())))
1828 return failure();
1829
1830 return success();
1831}
1832
1833//===----------------------------------------------------------------------===//
1834// ReturnOp
1835//===----------------------------------------------------------------------===//
1836
/// Verifies that the operands of gpu.return match the enclosing gpu.func's
/// declared result count and types.
LogicalResult gpu::ReturnOp::verify() {
  GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();

  FunctionType funType = function.getFunctionType();

  // Operand count must equal the function's result count.
  if (funType.getNumResults() != getOperands().size())
    return emitOpError()
        .append(args: "expected ", args: funType.getNumResults(), args: " result operands")
        .attachNote(noteLoc: function.getLoc())
        .append(arg: "return type declared here");

  // Each operand must match the corresponding declared result type.
  for (const auto &pair : llvm::enumerate(
           First: llvm::zip(t: function.getFunctionType().getResults(), u: getOperands()))) {
    auto [type, operand] = pair.value();
    if (type != operand.getType())
      return emitOpError() << "unexpected type `" << operand.getType()
                           << "' for operand #" << pair.index();
  }
  return success();
}
1857
1858//===----------------------------------------------------------------------===//
1859// GPUModuleOp
1860//===----------------------------------------------------------------------===//
1861
/// Builds a gpu.module with symbol `name`, an optional compilation target
/// list, and an optional offloading handler; the body region receives an
/// empty entry block.
void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name, ArrayAttr targets,
                        Attribute offloadingHandler) {
  result.addRegion()->emplaceBlock();
  Properties &props = result.getOrAddProperties<Properties>();
  // A null `targets` leaves the property unset rather than storing null.
  if (targets)
    props.targets = targets;
  props.setSymName(builder.getStringAttr(bytes: name));
  props.offloadingHandler = offloadingHandler;
}
1872
1873void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
1874 StringRef name, ArrayRef<Attribute> targets,
1875 Attribute offloadingHandler) {
1876 build(builder, result, name,
1877 targets: targets.empty() ? ArrayAttr() : builder.getArrayAttr(value: targets),
1878 offloadingHandler);
1879}
1880
1881bool GPUModuleOp::hasTarget(Attribute target) {
1882 if (ArrayAttr targets = getTargetsAttr())
1883 return llvm::count(Range: targets.getValue(), Element: target);
1884 return false;
1885}
1886
/// Replaces the module's compilation target list with `targets`.
void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
  ArrayAttr &targetsAttr = getProperties().targets;
  // Widen the interface handles to plain Attributes for storage.
  SmallVector<Attribute> targetsVector(targets);
  targetsAttr = ArrayAttr::get(context: getContext(), value: targetsVector);
}
1892
1893LogicalResult GPUModuleOp::verify() {
1894 auto targets = getOperation()->getAttrOfType<ArrayAttr>(name: "targets");
1895
1896 if (!targets)
1897 return success();
1898
1899 for (auto target : targets) {
1900 if (auto verifyTargetAttr =
1901 llvm::dyn_cast<TargetAttrVerifyInterface>(Val&: target)) {
1902 if (verifyTargetAttr.verifyTarget(module: getOperation()).failed())
1903 return failure();
1904 }
1905 }
1906 return success();
1907}
1908
1909//===----------------------------------------------------------------------===//
1910// GPUBinaryOp
1911//===----------------------------------------------------------------------===//
/// Builds a gpu.binary with symbol `name` and the given compiled `objects`;
/// when no offloading handler is supplied, defaults to
/// `#gpu.select_object<nullptr>` (select the first object).
void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
                     Attribute offloadingHandler, ArrayAttr objects) {
  auto &properties = result.getOrAddProperties<Properties>();
  result.attributes.push_back(newAttribute: builder.getNamedAttr(
      name: SymbolTable::getSymbolAttrName(), val: builder.getStringAttr(bytes: name)));
  properties.objects = objects;
  if (offloadingHandler)
    properties.offloadingHandler = offloadingHandler;
  else
    properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(args: nullptr);
}
1923
1924void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
1925 Attribute offloadingHandler, ArrayRef<Attribute> objects) {
1926 build(builder, result, name, offloadingHandler,
1927 objects: objects.empty() ? ArrayAttr() : builder.getArrayAttr(value: objects));
1928}
1929
/// Parses an optional offloading handler in angle brackets (`<attr>`).
/// When absent, substitutes the default `#gpu.select_object<nullptr>`
/// handler so the property is never left null.
static ParseResult parseOffloadingHandler(OpAsmParser &parser,
                                          Attribute &offloadingHandler) {
  if (succeeded(Result: parser.parseOptionalLess())) {
    if (parser.parseAttribute(result&: offloadingHandler))
      return failure();
    if (parser.parseGreater())
      return failure();
  }
  // No explicit handler parsed: install the default.
  if (!offloadingHandler)
    offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(args: nullptr);
  return success();
}
1942
1943static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op,
1944 Attribute offloadingHandler) {
1945 if (offloadingHandler != SelectObjectAttr::get(context: op->getContext(), target: nullptr))
1946 printer << '<' << offloadingHandler << '>';
1947}
1948
1949//===----------------------------------------------------------------------===//
1950// GPUMemcpyOp
1951//===----------------------------------------------------------------------===//
1952
1953LogicalResult MemcpyOp::verify() {
1954 auto srcType = getSrc().getType();
1955 auto dstType = getDst().getType();
1956
1957 if (getElementTypeOrSelf(type: srcType) != getElementTypeOrSelf(type: dstType))
1958 return emitOpError(message: "arguments have incompatible element type");
1959
1960 if (failed(Result: verifyCompatibleShape(type1: srcType, type2: dstType)))
1961 return emitOpError(message: "arguments have incompatible shape");
1962
1963 return success();
1964}
1965
1966namespace {
1967
1968/// Erases a common case of copy ops where a destination value is used only by
1969/// the copy op, alloc and dealloc ops.
/// Erases a gpu.memcpy whose destination is a fresh allocation used only by
/// this copy and by dealloc-like ops, i.e. the copied data is never read.
struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
  using OpRewritePattern<MemcpyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MemcpyOp op,
                                PatternRewriter &rewriter) const override {
    Value dest = op.getDst();
    Operation *destDefOp = dest.getDefiningOp();
    // `dest` must be defined by an op having Allocate memory effect in order to
    // perform the folding.
    if (!destDefOp ||
        !hasSingleEffect<MemoryEffects::Allocate>(op: destDefOp, value: dest))
      return failure();
    // We can erase `op` iff `dest` has no other use apart from its
    // use by `op` and dealloc ops.
    if (llvm::any_of(Range: dest.getUsers(), P: [op, dest](Operation *user) {
          return user != op &&
                 !hasSingleEffect<MemoryEffects::Free>(op: user, value: dest);
        }))
      return failure();
    // We can perform the folding if and only if op has a single async
    // dependency and produces an async token as result, or if it does not have
    // any async dependency and does not produce any async token result.
    if (op.getAsyncDependencies().size() > 1 ||
        ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
         (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
      return failure();
    // Forward the (at most one) async dependency to the token's users.
    rewriter.replaceOp(op, newValues: op.getAsyncDependencies());
    return success();
  }
};
2000
2001} // end anonymous namespace
2002
/// Registers memcpy canonicalizations: erase copies into otherwise-unused
/// allocations.
void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<EraseTrivialCopyOp>(arg&: context);
}
2007
2008//===----------------------------------------------------------------------===//
2009// GPU_SubgroupMmaLoadMatrixOp
2010//===----------------------------------------------------------------------===//
2011
2012LogicalResult SubgroupMmaLoadMatrixOp::verify() {
2013 auto srcType = getSrcMemref().getType();
2014 auto resType = getRes().getType();
2015 auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(Val&: resType);
2016 auto operand = resMatrixType.getOperand();
2017 auto srcMemrefType = llvm::cast<MemRefType>(Val&: srcType);
2018
2019 if (!srcMemrefType.isLastDimUnitStride())
2020 return emitError(
2021 message: "expected source memref most minor dim must have unit stride");
2022
2023 if (operand != "AOp" && operand != "BOp" && operand != "COp")
2024 return emitError(message: "only AOp, BOp and COp can be loaded");
2025
2026 return success();
2027}
2028
2029//===----------------------------------------------------------------------===//
2030// GPU_SubgroupMmaStoreMatrixOp
2031//===----------------------------------------------------------------------===//
2032
2033LogicalResult SubgroupMmaStoreMatrixOp::verify() {
2034 auto srcType = getSrc().getType();
2035 auto dstType = getDstMemref().getType();
2036 auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(Val&: srcType);
2037 auto dstMemrefType = llvm::cast<MemRefType>(Val&: dstType);
2038
2039 if (!dstMemrefType.isLastDimUnitStride())
2040 return emitError(
2041 message: "expected destination memref most minor dim must have unit stride");
2042
2043 if (srcMatrixType.getOperand() != "COp")
2044 return emitError(
2045 message: "expected the operand matrix being stored to have 'COp' operand type");
2046
2047 return success();
2048}
2049
2050//===----------------------------------------------------------------------===//
2051// GPU_SubgroupMmaComputeOp
2052//===----------------------------------------------------------------------===//
2053
2054LogicalResult SubgroupMmaComputeOp::verify() {
2055 enum OperandMap { A, B, C };
2056 SmallVector<MMAMatrixType, 3> opTypes;
2057 opTypes.push_back(Elt: llvm::cast<MMAMatrixType>(Val: getOpA().getType()));
2058 opTypes.push_back(Elt: llvm::cast<MMAMatrixType>(Val: getOpB().getType()));
2059 opTypes.push_back(Elt: llvm::cast<MMAMatrixType>(Val: getOpC().getType()));
2060
2061 if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
2062 opTypes[C].getOperand() != "COp")
2063 return emitError(message: "operands must be in the order AOp, BOp, COp");
2064
2065 ArrayRef<int64_t> aShape, bShape, cShape;
2066 aShape = opTypes[A].getShape();
2067 bShape = opTypes[B].getShape();
2068 cShape = opTypes[C].getShape();
2069
2070 if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
2071 bShape[1] != cShape[1])
2072 return emitError(message: "operand shapes do not satisfy matmul constraints");
2073
2074 return success();
2075}
2076
/// Folds no-op memref.cast operands into the memcpy itself.
LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return memref::foldMemRefCast(op: *this);
}
2081
/// Folds no-op memref.cast operands into the memset itself.
LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return memref::foldMemRefCast(op: *this);
}
2086
2087//===----------------------------------------------------------------------===//
2088// GPU_WaitOp
2089//===----------------------------------------------------------------------===//
2090
2091namespace {
2092
2093/// Remove gpu.wait op use of gpu.wait op def without async dependencies.
2094/// %t = gpu.wait async [] // No async dependencies.
2095/// ... gpu.wait ... [%t, ...] // %t can be removed.
/// Drops async dependencies that are produced by a gpu.wait with no
/// dependencies of its own (such tokens impose no ordering).
struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(WaitOp op,
                                PatternRewriter &rewriter) const final {
    // A dependency is redundant when its producer is a gpu.wait that itself
    // waits on nothing.
    auto predicate = [](Value value) {
      auto waitOp = value.getDefiningOp<WaitOp>();
      return waitOp && waitOp->getNumOperands() == 0;
    };
    if (llvm::none_of(Range: op.getAsyncDependencies(), P: predicate))
      return failure();
    // Keep only the operands that carry real ordering information.
    SmallVector<Value> validOperands;
    for (Value operand : op->getOperands()) {
      if (predicate(operand))
        continue;
      validOperands.push_back(Elt: operand);
    }
    rewriter.modifyOpInPlace(root: op, callable: [&]() { op->setOperands(validOperands); });
    return success();
  }
};
2118
2119/// Simplify trivial gpu.wait ops for the following patterns.
2120/// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
2121/// dependencies).
2122/// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with
2123/// %t0.
2124/// 3. gpu.wait [] ops, i.e gpu.wait ops that neither have any async
2125/// dependencies nor return any token.
/// Simplifies trivial gpu.wait ops: erases waits with no dependencies and no
/// token, forwards single-dependency tokens, and erases waits whose token is
/// unused.
struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(WaitOp op,
                                PatternRewriter &rewriter) const final {
    // Erase gpu.wait ops that neither have any async dependencies nor return
    // any async token.
    if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
      rewriter.eraseOp(op);
      return success();
    }
    // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op.
    if (llvm::hasSingleElement(C: op.getAsyncDependencies()) &&
        op.getAsyncToken()) {
      rewriter.replaceOp(op, newValues: op.getAsyncDependencies());
      return success();
    }
    // Erase %t = gpu.wait async ... ops, where %t has no uses.
    if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
      rewriter.eraseOp(op);
      return success();
    }
    return failure();
  }
};
2152
2153} // end anonymous namespace
2154
/// Registers gpu.wait canonicalizations (redundant-pair removal and trivial
/// simplifications).
void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(arg&: context);
}
2159
2160//===----------------------------------------------------------------------===//
2161// GPU_AllocOp
2162//===----------------------------------------------------------------------===//
2163
2164LogicalResult AllocOp::verify() {
2165 auto memRefType = llvm::cast<MemRefType>(Val: getMemref().getType());
2166
2167 if (getDynamicSizes().size() != memRefType.getNumDynamicDims())
2168 return emitOpError(message: "dimension operand count does not equal memref "
2169 "dynamic dimension count");
2170
2171 unsigned numSymbols = 0;
2172 if (!memRefType.getLayout().isIdentity())
2173 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
2174 if (getSymbolOperands().size() != numSymbols) {
2175 return emitOpError(
2176 message: "symbol operand count does not equal memref symbol count");
2177 }
2178
2179 return success();
2180}
2181
2182namespace {
2183
2184/// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
2185/// `memref::AllocOp`.
/// Folds memref.dim(gpu.alloc(%size), %idx) to the %size operand supplied
/// for that dynamic dimension.
struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    // Only constant indices of in-range, dynamic dimensions can be folded.
    std::optional<int64_t> index = dimOp.getConstantIndex();
    if (!index)
      return failure();

    auto memrefType = llvm::dyn_cast<MemRefType>(Val: dimOp.getSource().getType());
    if (!memrefType || index.value() >= memrefType.getRank() ||
        !memrefType.isDynamicDim(idx: index.value()))
      return failure();

    auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
    if (!alloc)
      return failure();

    // The size of the queried dynamic dimension is the corresponding
    // dynamic-size operand of the gpu.alloc.
    Value substituteOp = *(alloc.getDynamicSizes().begin() +
                           memrefType.getDynamicDimIndex(index: index.value()));
    rewriter.replaceOp(op: dimOp, newValues: substituteOp);
    return success();
  }
};
2210
2211} // namespace
2212
/// Registers gpu.alloc canonicalizations (memref.dim folding).
void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyDimOfAllocOp>(arg&: context);
}
2217
2218//===----------------------------------------------------------------------===//
2219// GPU object attribute
2220//===----------------------------------------------------------------------===//
2221
2222LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2223 Attribute target, CompilationTarget format,
2224 StringAttr object, DictionaryAttr properties,
2225 KernelTableAttr kernels) {
2226 if (!target)
2227 return emitError() << "the target attribute cannot be null";
2228 if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
2229 return success();
2230 return emitError() << "the target attribute must implement or promise the "
2231 "`gpu::TargetAttrInterface`";
2232}
2233
2234namespace {
/// Parses the `[format =] "object"` portion of a `#gpu.object` attribute;
/// the format keyword is optional and defaults to `fatbin`.
ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
                        StringAttr &object) {
  std::optional<CompilationTarget> formatResult;
  StringRef enumKeyword;
  auto loc = odsParser.getCurrentLocation();
  // No keyword at all means the default format followed by the object blob.
  if (failed(Result: odsParser.parseOptionalKeyword(keyword: &enumKeyword)))
    formatResult = CompilationTarget::Fatbin;
  // A recognized format keyword must be followed by `=`.
  if (!formatResult &&
      (formatResult =
           gpu::symbolizeEnum<gpu::CompilationTarget>(str: enumKeyword)) &&
      odsParser.parseEqual())
    return odsParser.emitError(loc, message: "expected an equal sign");
  // A keyword that is not a known format is an error.
  if (!formatResult)
    return odsParser.emitError(loc, message: "expected keyword for GPU object format");
  FailureOr<StringAttr> objectResult =
      FieldParser<StringAttr>::parse(parser&: odsParser);
  if (failed(Result: objectResult))
    return odsParser.emitError(loc: odsParser.getCurrentLocation(),
                               message: "failed to parse GPU_ObjectAttr parameter "
                               "'object' which is to be a `StringAttr`");
  format = *formatResult;
  object = *objectResult;
  return success();
}
2259
2260void printObject(AsmPrinter &odsParser, CompilationTarget format,
2261 StringAttr object) {
2262 if (format != CompilationTarget::Fatbin)
2263 odsParser << stringifyEnum(enumValue: format) << " = ";
2264 odsParser << object;
2265}
2266} // namespace
2267
2268//===----------------------------------------------------------------------===//
2269// GPU select object attribute
2270//===----------------------------------------------------------------------===//
2271
/// Verifies a `#gpu.select_object` attribute: the target may be null, an
/// object index (IntegerAttr), or a GPU target attribute.
LogicalResult
gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                              Attribute target) {
  // Check `target`, it can be null, an integer attr or a GPU Target attribute.
  if (target) {
    if (auto intAttr = mlir::dyn_cast<IntegerAttr>(Val&: target)) {
      // NOTE(review): the check accepts index 0 but the diagnostic says
      // "positive"; "non-negative" would describe the actual constraint.
      if (intAttr.getInt() < 0) {
        return emitError() << "the object index must be positive";
      }
    } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
      return emitError()
             << "the target attribute must be a GPU Target attribute";
    }
  }
  return success();
}
2288
2289//===----------------------------------------------------------------------===//
2290// DynamicSharedMemoryOp
2291//===----------------------------------------------------------------------===//
2292
2293LogicalResult gpu::DynamicSharedMemoryOp::verify() {
2294 if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
2295 return emitOpError() << "must be inside an op with symbol table";
2296
2297 MemRefType memrefType = getResultMemref().getType();
2298 // Check address space
2299 if (!GPUDialect::hasWorkgroupMemoryAddressSpace(type: memrefType)) {
2300 return emitOpError() << "address space must be "
2301 << gpu::AddressSpaceAttr::getMnemonic() << "<"
2302 << stringifyEnum(enumValue: gpu::AddressSpace::Workgroup) << ">";
2303 }
2304 if (memrefType.hasStaticShape()) {
2305 return emitOpError() << "result memref type must be memref<?xi8, "
2306 "#gpu.address_space<workgroup>>";
2307 }
2308 return success();
2309}
2310
2311//===----------------------------------------------------------------------===//
2312// GPU WarpExecuteOnLane0Op
2313//===----------------------------------------------------------------------===//
2314
/// Prints `(%laneid)[warp_size] args(...) -> (results) region attr-dict`;
/// mirrors the grammar accepted by parse().
void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
  p << "(" << getLaneid() << ")";

  // The warp size prints in square brackets and is elided from the trailing
  // attribute dictionary below.
  SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
  auto warpSizeAttr = getOperation()->getAttr(name: getWarpSizeAttrName());
  p << "[" << llvm::cast<IntegerAttr>(Val&: warpSizeAttr).getInt() << "]";

  if (!getArgs().empty())
    p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
  if (!getResults().empty())
    p << " -> (" << getResults().getTypes() << ')';
  p << " ";
  // The terminator is elided when the op returns no results.
  p.printRegion(blocks&: getRegion(),
                /*printEntryBlockArgs=*/true,
                /*printBlockTerminators=*/!getResults().empty());
  p.printOptionalAttrDict(attrs: getOperation()->getAttrs(), elidedAttrs: coreAttr);
}
2332
/// Parses `(%laneid)[warp_size] args(...) -> (results) region attr-dict`.
ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
                                        OperationState &result) {
  // Create the region.
  result.regions.reserve(N: 1);
  Region *warpRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand laneId;

  // Parse predicate operand.
  if (parser.parseLParen() ||
      parser.parseOperand(result&: laneId, /*allowResultNumber=*/false) ||
      parser.parseRParen())
    return failure();

  // Parse the warp size in square brackets and store it as an attribute.
  int64_t warpSize;
  if (parser.parseLSquare() || parser.parseInteger(result&: warpSize) ||
      parser.parseRSquare())
    return failure();
  result.addAttribute(name: getWarpSizeAttrName(name: OperationName(getOperationName(),
                                                      builder.getContext())),
                      attr: builder.getI64IntegerAttr(value: warpSize));

  // The lane id is always an index-typed operand.
  if (parser.resolveOperand(operand: laneId, type: builder.getIndexType(), result&: result.operands))
    return failure();

  // Parse the optional `args(...)` clause of forwarded operands.
  llvm::SMLoc inputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
  SmallVector<Type> inputTypes;
  if (succeeded(Result: parser.parseOptionalKeyword(keyword: "args"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(result&: inputsOperands) ||
        parser.parseColonTypeList(result&: inputTypes) || parser.parseRParen())
      return failure();
  }
  if (parser.resolveOperands(operands&: inputsOperands, types&: inputTypes, loc: inputsOperandsLoc,
                             result&: result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result&: result.types))
    return failure();
  // Parse the region.
  if (parser.parseRegion(region&: *warpRegion, /*arguments=*/{},
                         /*argTypes=*/enableNameShadowing: {}))
    return failure();
  // Insert the implicit terminator when the region body omitted it.
  WarpExecuteOnLane0Op::ensureTerminator(region&: *warpRegion, builder, loc: result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result&: result.attributes))
    return failure();
  return success();
}
2389
2390void WarpExecuteOnLane0Op::getSuccessorRegions(
2391 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2392 if (!point.isParent()) {
2393 regions.push_back(Elt: RegionSuccessor(getResults()));
2394 return;
2395 }
2396
2397 // The warp region is always executed
2398 regions.push_back(Elt: RegionSuccessor(&getWarpRegion()));
2399}
2400
/// Builds a WarpExecuteOnLane0Op with no extra warp arguments; delegates to
/// the full builder with empty operand/block-argument-type lists.
void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 TypeRange resultTypes, Value laneId,
                                 int64_t warpSize) {
  build(odsBuilder&: builder, odsState&: result, resultTypes, laneid: laneId, warpSize,
        /*operands=*/args: {}, /*argTypes=*/blockArgTypes: {});
}
2407
2408void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
2409 TypeRange resultTypes, Value laneId,
2410 int64_t warpSize, ValueRange args,
2411 TypeRange blockArgTypes) {
2412 result.addOperands(newOperands: laneId);
2413 result.addAttribute(name: getAttributeNames()[0],
2414 attr: builder.getI64IntegerAttr(value: warpSize));
2415 result.addTypes(newTypes&: resultTypes);
2416 result.addOperands(newOperands: args);
2417 assert(args.size() == blockArgTypes.size());
2418 OpBuilder::InsertionGuard guard(builder);
2419 Region *warpRegion = result.addRegion();
2420 Block *block = builder.createBlock(parent: warpRegion);
2421 for (auto [type, arg] : llvm::zip_equal(t&: blockArgTypes, u&: args))
2422 block->addArgument(type, loc: arg.getLoc());
2423}
2424
2425/// Helper check if the distributed vector type is consistent with the expanded
2426/// type and distributed size.
2427static LogicalResult verifyDistributedType(Type expanded, Type distributed,
2428 int64_t warpSize, Operation *op) {
2429 // If the types matches there is no distribution.
2430 if (expanded == distributed)
2431 return success();
2432 auto expandedVecType = llvm::dyn_cast<VectorType>(Val&: expanded);
2433 auto distributedVecType = llvm::dyn_cast<VectorType>(Val&: distributed);
2434 if (!expandedVecType || !distributedVecType)
2435 return op->emitOpError(message: "expected vector type for distributed operands.");
2436 if (expandedVecType.getRank() != distributedVecType.getRank() ||
2437 expandedVecType.getElementType() != distributedVecType.getElementType())
2438 return op->emitOpError(
2439 message: "expected distributed vectors to have same rank and element type.");
2440
2441 SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
2442 for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
2443 int64_t eDim = expandedVecType.getDimSize(idx: i);
2444 int64_t dDim = distributedVecType.getDimSize(idx: i);
2445 if (eDim == dDim)
2446 continue;
2447 if (eDim % dDim != 0)
2448 return op->emitOpError()
2449 << "expected expanded vector dimension #" << i << " (" << eDim
2450 << ") to be a multipler of the distributed vector dimension ("
2451 << dDim << ")";
2452 scales[i] = eDim / dDim;
2453 }
2454 if (std::accumulate(first: scales.begin(), last: scales.end(), init: 1,
2455 binary_op: std::multiplies<int64_t>()) != warpSize)
2456 return op->emitOpError()
2457 << "incompatible distribution dimensions from " << expandedVecType
2458 << " to " << distributedVecType << " with warp size = " << warpSize;
2459
2460 return success();
2461}
2462
2463LogicalResult WarpExecuteOnLane0Op::verify() {
2464 if (getArgs().size() != getWarpRegion().getNumArguments())
2465 return emitOpError(
2466 message: "expected same number op arguments and block arguments.");
2467 auto yield =
2468 cast<YieldOp>(Val: getWarpRegion().getBlocks().begin()->getTerminator());
2469 if (yield.getNumOperands() != getNumResults())
2470 return emitOpError(
2471 message: "expected same number of yield operands and return values.");
2472 int64_t warpSize = getWarpSize();
2473 for (auto [regionArg, arg] :
2474 llvm::zip_equal(t: getWarpRegion().getArguments(), u: getArgs())) {
2475 if (failed(Result: verifyDistributedType(expanded: regionArg.getType(), distributed: arg.getType(),
2476 warpSize, op: getOperation())))
2477 return failure();
2478 }
2479 for (auto [yieldOperand, result] :
2480 llvm::zip_equal(t: yield.getOperands(), u: getResults())) {
2481 if (failed(Result: verifyDistributedType(expanded: yieldOperand.getType(), distributed: result.getType(),
2482 warpSize, op: getOperation())))
2483 return failure();
2484 }
2485 return success();
2486}
2487bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
2488 return succeeded(
2489 Result: verifyDistributedType(expanded: lhs, distributed: rhs, warpSize: getWarpSize(), op: getOperation()));
2490}
2491
2492//===----------------------------------------------------------------------===//
2493// GPU KernelMetadataAttr
2494//===----------------------------------------------------------------------===//
2495
2496KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
2497 DictionaryAttr metadata) {
2498 assert(kernel && "invalid kernel");
2499 return get(name: kernel.getNameAttr(), functionType: kernel.getFunctionType(),
2500 argAttrs: kernel.getAllArgAttrs(), metadata);
2501}
2502
2503KernelMetadataAttr
2504KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
2505 FunctionOpInterface kernel,
2506 DictionaryAttr metadata) {
2507 assert(kernel && "invalid kernel");
2508 return getChecked(emitError, name: kernel.getNameAttr(), functionType: kernel.getFunctionType(),
2509 argAttrs: kernel.getAllArgAttrs(), metadata);
2510}
2511
2512KernelMetadataAttr
2513KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
2514 if (attrs.empty())
2515 return *this;
2516 NamedAttrList attrList;
2517 if (DictionaryAttr dict = getMetadata())
2518 attrList.append(newAttributes&: dict);
2519 attrList.append(newAttributes&: attrs);
2520 return KernelMetadataAttr::get(name: getName(), functionType: getFunctionType(), argAttrs: getArgAttrs(),
2521 metadata: attrList.getDictionary(context: getContext()));
2522}
2523
2524LogicalResult
2525KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2526 StringAttr name, Type functionType,
2527 ArrayAttr argAttrs, DictionaryAttr metadata) {
2528 if (name.empty())
2529 return emitError() << "the kernel name can't be empty";
2530 if (argAttrs) {
2531 if (llvm::any_of(Range&: argAttrs, P: [](Attribute attr) {
2532 return !llvm::isa<DictionaryAttr>(Val: attr);
2533 }))
2534 return emitError()
2535 << "all attributes in the array must be a dictionary attribute";
2536 }
2537 return success();
2538}
2539
2540//===----------------------------------------------------------------------===//
2541// GPU KernelTableAttr
2542//===----------------------------------------------------------------------===//
2543
2544KernelTableAttr KernelTableAttr::get(MLIRContext *context,
2545 ArrayRef<KernelMetadataAttr> kernels,
2546 bool isSorted) {
2547 // Note that `is_sorted` is always only invoked once even with assertions ON.
2548 assert((!isSorted || llvm::is_sorted(kernels)) &&
2549 "expected a sorted kernel array");
2550 // Immediately return the attribute if the array is sorted.
2551 if (isSorted || llvm::is_sorted(Range&: kernels))
2552 return Base::get(ctx: context, args&: kernels);
2553 // Sort the array.
2554 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2555 llvm::array_pod_sort(Start: kernelsTmp.begin(), End: kernelsTmp.end());
2556 return Base::get(ctx: context, args&: kernelsTmp);
2557}
2558
2559KernelTableAttr KernelTableAttr::getChecked(
2560 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
2561 ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
2562 // Note that `is_sorted` is always only invoked once even with assertions ON.
2563 assert((!isSorted || llvm::is_sorted(kernels)) &&
2564 "expected a sorted kernel array");
2565 // Immediately return the attribute if the array is sorted.
2566 if (isSorted || llvm::is_sorted(Range&: kernels))
2567 return Base::getChecked(emitErrorFn: emitError, ctx: context, args: kernels);
2568 // Sort the array.
2569 SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
2570 llvm::array_pod_sort(Start: kernelsTmp.begin(), End: kernelsTmp.end());
2571 return Base::getChecked(emitErrorFn: emitError, ctx: context, args: kernelsTmp);
2572}
2573
2574LogicalResult
2575KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
2576 ArrayRef<KernelMetadataAttr> kernels) {
2577 if (kernels.size() < 2)
2578 return success();
2579 // Check that the kernels are uniquely named.
2580 if (std::adjacent_find(first: kernels.begin(), last: kernels.end(),
2581 binary_pred: [](KernelMetadataAttr l, KernelMetadataAttr r) {
2582 return l.getName() == r.getName();
2583 }) != kernels.end()) {
2584 return emitError() << "expected all kernels to be uniquely named";
2585 }
2586 return success();
2587}
2588
2589KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
2590 auto [iterator, found] = impl::findAttrSorted(first: begin(), last: end(), name: key);
2591 return found ? *iterator : KernelMetadataAttr();
2592}
2593
2594KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
2595 auto [iterator, found] = impl::findAttrSorted(first: begin(), last: end(), name: key);
2596 return found ? *iterator : KernelMetadataAttr();
2597}
2598
2599//===----------------------------------------------------------------------===//
2600// GPU target options
2601//===----------------------------------------------------------------------===//
2602
/// Public constructor: delegates to the TypeID-taking constructor using the
/// TypeID of `TargetOptions` itself.
TargetOptions::TargetOptions(
    StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
    StringRef cmdOptions, StringRef elfSection,
    CompilationTarget compilationTarget,
    function_ref<SymbolTable *()> getSymbolTableCallback,
    function_ref<void(llvm::Module &)> initialLlvmIRCallback,
    function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
    function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
    function_ref<void(StringRef)> isaCallback)
    : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
                    cmdOptions, elfSection, compilationTarget,
                    getSymbolTableCallback, initialLlvmIRCallback,
                    linkedLlvmIRCallback, optimizedLlvmIRCallback,
                    isaCallback) {}
2617
/// Base constructor used by derived option classes: `typeID` identifies the
/// concrete class for isa-style queries (see getTypeID). String parameters
/// are copied into owned std::strings; callbacks are stored as non-owning
/// function_refs, so callers must keep the callables alive for the lifetime
/// of this object.
TargetOptions::TargetOptions(
    TypeID typeID, StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
    StringRef cmdOptions, StringRef elfSection,
    CompilationTarget compilationTarget,
    function_ref<SymbolTable *()> getSymbolTableCallback,
    function_ref<void(llvm::Module &)> initialLlvmIRCallback,
    function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
    function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
    function_ref<void(StringRef)> isaCallback)
    : toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
      cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
      compilationTarget(compilationTarget),
      getSymbolTableCallback(getSymbolTableCallback),
      initialLlvmIRCallback(initialLlvmIRCallback),
      linkedLlvmIRCallback(linkedLlvmIRCallback),
      optimizedLlvmIRCallback(optimizedLlvmIRCallback),
      isaCallback(isaCallback), typeID(typeID) {}
2635
/// Returns the TypeID recorded at construction (identifies the concrete
/// options class).
TypeID TargetOptions::getTypeID() const { return typeID; }
2637
/// Returns the stored toolkit path.
StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }
2639
/// Returns the attributes describing libraries to link against.
ArrayRef<Attribute> TargetOptions::getLibrariesToLink() const {
  return librariesToLink;
}
2643
/// Returns the raw (untokenized) command-line options string.
StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }
2645
/// Returns the ELF section name stored at construction.
StringRef TargetOptions::getELFSection() const { return elfSection; }
2647
2648SymbolTable *TargetOptions::getSymbolTable() const {
2649 return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
2650}
2651
/// Returns the callback invoked on the initial LLVM IR module (may be null).
function_ref<void(llvm::Module &)>
TargetOptions::getInitialLlvmIRCallback() const {
  return initialLlvmIRCallback;
}
2656
/// Returns the callback invoked on the linked LLVM IR module (may be null).
function_ref<void(llvm::Module &)>
TargetOptions::getLinkedLlvmIRCallback() const {
  return linkedLlvmIRCallback;
}
2661
/// Returns the callback invoked on the optimized LLVM IR module (may be null).
function_ref<void(llvm::Module &)>
TargetOptions::getOptimizedLlvmIRCallback() const {
  return optimizedLlvmIRCallback;
}
2666
/// Returns the callback invoked with the generated ISA string (may be null).
function_ref<void(StringRef)> TargetOptions::getISACallback() const {
  return isaCallback;
}
2670
/// Returns the compilation target stored at construction.
CompilationTarget TargetOptions::getCompilationTarget() const {
  return compilationTarget;
}
2674
/// Returns the default compilation target: a fat binary.
CompilationTarget TargetOptions::getDefaultCompilationTarget() {
  return CompilationTarget::Fatbin;
}
2678
2679std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2680TargetOptions::tokenizeCmdOptions(const std::string &cmdOptions) {
2681 std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
2682 llvm::StringSaver stringSaver(options.first);
2683 StringRef opts = cmdOptions;
2684 // For a correct tokenization of the command line options `opts` must be
2685 // unquoted, otherwise the tokenization function returns a single string: the
2686 // unquoted `cmdOptions` -which is not the desired behavior.
2687 // Remove any quotes if they are at the beginning and end of the string:
2688 if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
2689 opts.consume_front(Prefix: "\""), opts.consume_back(Suffix: "\"");
2690 if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
2691 opts.consume_front(Prefix: "'"), opts.consume_back(Suffix: "'");
2692#ifdef _WIN32
2693 llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
2694 /*MarkEOLs=*/false);
2695#else
2696 llvm::cl::TokenizeGNUCommandLine(Source: opts, Saver&: stringSaver, NewArgv&: options.second,
2697 /*MarkEOLs=*/false);
2698#endif // _WIN32
2699 return options;
2700}
2701
/// Tokenizes this object's stored command-line options string.
std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
TargetOptions::tokenizeCmdOptions() const {
  return tokenizeCmdOptions(cmdOptions);
}
2706
2707std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
2708TargetOptions::tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith) {
2709 size_t startPos = cmdOptions.find(svt: startsWith);
2710 if (startPos == std::string::npos)
2711 return {llvm::BumpPtrAllocator(), SmallVector<const char *>()};
2712
2713 auto tokenized =
2714 tokenizeCmdOptions(cmdOptions: cmdOptions.substr(pos: startPos + startsWith.size()));
2715 cmdOptions.resize(n: startPos);
2716 return tokenized;
2717}
2718
2719MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::gpu::TargetOptions)
2720
2721#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
2722#include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"
2723
2724#define GET_ATTRDEF_CLASSES
2725#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
2726
2727#define GET_OP_CLASSES
2728#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
2729
2730#include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"
2731

source code of mlir/lib/Dialect/GPU/IR/GPUDialect.cpp