//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/GPUDialect.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/StringSaver.h"
#include <cassert>
#include <numeric>

using namespace mlir;
using namespace mlir::gpu;

#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// GPU Device Mapping Attributes
//===----------------------------------------------------------------------===//

int64_t GPUBlockMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getBlock());
}

bool GPUBlockMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUBlockMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}
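
// Note on the arithmetic above: `MappingId::LinearDim0` is the base of the
// linear ids, so e.g. an attribute using `LinearDim2` (two past the base) has
// relative index 2, while non-linear ids such as `DimY` map to themselves.
// (Illustrative; the concrete enum values live in the generated `MappingId`
// definition.)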

int64_t GPUWarpgroupMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarpgroup());
}

bool GPUWarpgroupMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarp());
}

bool GPUWarpMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUThreadMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getThread());
}

bool GPUThreadMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUThreadMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getAddressSpace());
}

bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
}

int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
}

//===----------------------------------------------------------------------===//
// MMAMatrixType
//===----------------------------------------------------------------------===//
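//
// As an illustration, this type prints in the dialect syntax as, e.g. (shape
// and operand string are examples only; see the parser/printer below):
//   !gpu.mma_matrix<16x16xf16, "AOp">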

MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
                                 StringRef operand) {
  return Base::get(elementType.getContext(), shape, elementType, operand);
}

MMAMatrixType
MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                          ArrayRef<int64_t> shape, Type elementType,
                          StringRef operand) {
  return Base::getChecked(emitError, elementType.getContext(), shape,
                          elementType, operand);
}

unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }

ArrayRef<int64_t> MMAMatrixType::getShape() const {
  return getImpl()->getShape();
}

Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }

StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }

bool MMAMatrixType::isValidElementType(Type elementType) {
  return elementType.isF16() || elementType.isF32() ||
         elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
         elementType.isInteger(32);
}

LogicalResult
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
                                ArrayRef<int64_t> shape, Type elementType,
                                StringRef operand) {
  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError() << "operand expected to be one of AOp, BOp or COp";

  if (shape.size() != 2)
    return emitError() << "MMAMatrixType must have exactly two dimensions";

  if (!MMAMatrixType::isValidElementType(elementType))
    return emitError()
           << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";

  return success();
}

//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == getWorkgroupAddressSpace();
  return false;
}

bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isWorkgroupMemoryAddressSpace(memorySpace);
}

bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

namespace {
/// This class defines the interface for handling inlining with gpu
/// operations.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All gpu dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace

void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
                           TerminatorOp>();
  declarePromisedInterfaces<
      ValueBoundsOpInterface, ClusterDimOp, ClusterDimBlocksOp, ClusterIdOp,
      ClusterBlockIdOp, BlockDimOp, BlockIdOp, GridDimOp, ThreadIdOp, LaneIdOp,
      SubgroupIdOp, GlobalIdOp, NumSubgroupsOp, SubgroupSizeOp, LaunchOp>();
}

static std::string getSparseHandleKeyword(SparseHandleKind kind) {
  switch (kind) {
  case SparseHandleKind::DnTensor:
    return "sparse.dntensor_handle";
  case SparseHandleKind::SpMat:
    return "sparse.spmat_handle";
  case SparseHandleKind::SpGEMMOp:
    return "sparse.spgemmop_handle";
  }
  llvm_unreachable("unknown sparse handle kind");
  return "";
}

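// For reference, the custom `gpu` types parsed below use the standard
// dialect-type syntax; illustrative spellings derived from the keywords
// handled here include `!gpu.async.token` and `!gpu.sparse.spmat_handle`.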
Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  if (keyword == "mma_matrix") {
    SMLoc beginLoc = parser.getNameLoc();

    // Parse '<'.
    if (parser.parseLess())
      return nullptr;

    // Parse the size and elementType.
    SmallVector<int64_t> shape;
    Type elementType;
    if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
        parser.parseType(elementType))
      return nullptr;

    // Parse ','
    if (parser.parseComma())
      return nullptr;

    // Parse operand.
    std::string operand;
    if (failed(parser.parseOptionalString(&operand)))
      return nullptr;

    // Parse '>'.
    if (parser.parseGreater())
      return nullptr;

    return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
                                         parser.getEncodedSourceLoc(beginLoc)),
                                     shape, elementType, operand);
  }

  if (keyword == getSparseHandleKeyword(SparseHandleKind::DnTensor))
    return SparseDnTensorHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpMat))
    return SparseSpMatHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpGEMMOp))
    return SparseSpGEMMOpHandleType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
  return Type();
}

// TODO: Print a refined type here; it should correspond to what the parser
// above accepts.
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Case<SparseDnTensorHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::DnTensor);
      })
      .Case<SparseSpMatHandleType>(
          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
      .Case<SparseSpGEMMOpHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
      })
      .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
        os << "mma_matrix<";
        auto shape = fragTy.getShape();
        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          os << *dim << 'x';
        os << shape.back() << 'x' << fragTy.getElementType();
        os << ", \"" << fragTy.getOperand() << "\"" << '>';
      })
      .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
}

static LogicalResult verifyKnownLaunchSizeAttr(Operation *op,
                                               NamedAttribute attr) {
  auto array = dyn_cast<DenseI32ArrayAttr>(attr.getValue());
  if (!array)
    return op->emitOpError(Twine(attr.getName()) +
                           " must be a dense i32 array");
  if (array.size() != 3)
    return op->emitOpError(Twine(attr.getName()) +
                           " must contain exactly 3 elements");
  return success();
}
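
// As an illustration, a well-formed known-launch-size annotation would look
// like `gpu.known_block_size = array<i32: 128, 1, 1>`; the exact attribute
// names come from the attribute helpers checked below.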

LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (attr.getName() == getKnownBlockSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (attr.getName() == getKnownGridSizeAttrHelper().getName())
    return verifyKnownLaunchSizeAttr(op, attr);
  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deeply than functions in
    // the module we are currently checking.
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName(launchOp->getName())))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel container.
    StringAttr kernelContainerName = launchOp.getKernelModuleName();
    Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
    if (!kernelContainer)
      return launchOp.emitOpError()
             << "kernel container '" << kernelContainerName.getValue()
             << "' is undefined";

    // If the container is a GPU binary op return success.
    if (isa<BinaryOp>(kernelContainer))
      return success();

    auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelContainerName.getValue()
             << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
    if (!kernelFunc)
      return launchOp.emitOpError("kernel function '")
             << launchOp.getKernel() << "' is undefined";
    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
    if (!kernelConvertedFunction) {
      InFlightDiagnostic diag = launchOp.emitOpError()
                                << "referenced kernel '" << launchOp.getKernel()
                                << "' is not a function";
      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
      return diag;
    }

    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // TODO: If the kernel isn't a GPU function (which happens during separate
    // compilation), do not check type correspondence as it would require the
    // verifier to be aware of the type conversion.
    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
    if (!kernelGPUFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getFunctionType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}

/// Parses an optional list of async operands with an optional leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
///
/// This method is used by the tablegen assembly format for async ops as well.
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}
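
// For example (illustrative), an async op with dependencies reads:
//   %token = gpu.wait async [%dep0, %dep1]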

/// Prints optional async dependencies with its leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
// Used by the tablegen assembly format for several async ops.
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async";
  if (asyncDependencies.empty())
    return;
  if (asyncTokenType)
    printer << ' ';
  printer << llvm::interleaved_array(asyncDependencies);
}

// GPU Memory attributions functions shared by LaunchOp and GPUFuncOp.
/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  return parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                                  /*allowType=*/true);
}
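
// For example (illustrative), a workgroup attribution list parses as:
//   workgroup(%buf : memref<32xf32, #gpu.address_space<workgroup>>)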

/// Prints a GPU function memory attribution.
static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values) {
  if (values.empty())
    return;

  auto printBlockArg = [](BlockArgument v) {
    return llvm::formatv("{} : {}", v, v.getType());
  };
  p << ' ' << keyword << '('
    << llvm::interleaved(llvm::map_range(values, printBlockArg)) << ')';
}

/// Verifies a GPU function memory attribution.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        gpu::AddressSpace memorySpace) {
  for (Value v : attributions) {
    auto type = llvm::dyn_cast<MemRefType>(v.getType());
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    // We can only verify the address space if it hasn't already been lowered
    // from the AddressSpaceAttr to a target-specific numeric value.
    auto addressSpace =
        llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
    if (!addressSpace)
      continue;
    if (addressSpace.getValue() != memorySpace)
      return op->emitOpError()
             << "expected memory space " << stringifyAddressSpace(memorySpace)
             << " in attribution";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//

static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
                                           Type resType) {
  using Kind = gpu::AllReduceOperation;
  if (llvm::is_contained(
          {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
          opName)) {
    if (!isa<FloatType>(resType))
      return failure();
  }

  if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
                          Kind::AND, Kind::OR, Kind::XOR},
                         opName)) {
    if (!isa<IntegerType>(resType))
      return failure();
  }

  return success();
}
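
// For example (illustrative), the two forms accepted by the verifier below,
// with either an op attribute or a non-empty body:
//   %sum = gpu.all_reduce add %value uniform {} : (f32) -> (f32)
//   %max = gpu.all_reduce %value {
//   ^bb(%lhs : f32, %rhs : f32):
//     %m = arith.maximumf %lhs, %rhs : f32
//     gpu.yield %m : f32
//   } : (f32) -> (f32)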

LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *getOp();
    if (failed(verifyReduceOpAndType(opName, getType()))) {
      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                         << "` reduction operation is not compatible with type "
                         << getType();
    }
  }

  return success();
}

static bool canMakeGroupOpUniform(Operation *op) {
  auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
  if (!launchOp)
    return false;

  Region &body = launchOp.getBody();
  assert(!body.empty() && "Invalid region");

  // Only convert ops in gpu::launch entry block for now.
  return op->getBlock() == &body.front();
}

OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

// TODO: Support optional custom attributes (without dialect prefix).
static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {
  StringRef enumStr;
  if (!parser.parseOptionalKeyword(&enumStr)) {
    std::optional<AllReduceOperation> op =
        gpu::symbolizeAllReduceOperation(enumStr);
    if (!op)
      return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
    attr = AllReduceOperationAttr::get(parser.getContext(), *op);
  }
  return success();
}

static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
                                    AllReduceOperationAttr attr) {
  if (attr)
    attr.print(printer);
}

//===----------------------------------------------------------------------===//
// SubgroupReduceOp
//===----------------------------------------------------------------------===//

LogicalResult gpu::SubgroupReduceOp::verify() {
  Type elemType = getType();
  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
    if (vecTy.isScalable())
      return emitOpError() << "is not compatible with scalable vector types";

    elemType = vecTy.getElementType();
  }

  gpu::AllReduceOperation opName = getOp();
  if (failed(verifyReduceOpAndType(opName, elemType))) {
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` reduction operation is not compatible with type "
                       << getType();
  }

  auto clusterSize = getClusterSize();
  if (clusterSize) {
    uint32_t size = *clusterSize;
    if (!llvm::isPowerOf2_32(size)) {
      return emitOpError() << "cluster size " << size
                           << " is not a power of two";
    }
  }

  uint32_t stride = getClusterStride();
  if (stride != 1 && !clusterSize) {
    return emitOpError() << "cluster stride can only be specified if cluster "
                            "size is specified";
  }
  if (!llvm::isPowerOf2_32(stride)) {
    return emitOpError() << "cluster stride " << stride
                         << " is not a power of two";
  }

  return success();
}

OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (getClusterSize() == 1)
    return getValue();

  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}
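
// For example (illustrative), including the optional clustered form checked
// above:
//   %sum = gpu.subgroup_reduce add %value : (f32) -> (f32)
//   %vsum = gpu.subgroup_reduce add %vec cluster(size = 4)
//       : (vector<4xf16>) -> (vector<4xf16>)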

//===----------------------------------------------------------------------===//
// AsyncOpInterface
//===----------------------------------------------------------------------===//

void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);

  // Async dependencies are the only variadic operand.
  if (!sizeAttr)
    return;

  SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
}

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value getBlockSizeX, Value getBlockSizeY,
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     Type asyncTokenType, ValueRange asyncDependencies,
                     TypeRange workgroupAttributions,
                     TypeRange privateAttributions, Value clusterSizeX,
                     Value clusterSizeY, Value clusterSizeZ) {
  OpBuilder::InsertionGuard g(builder);

  // Add a WorkGroup attribution attribute. This attribute is required to
  // identify private attributions in the list of block arguments.
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));

  // Add Op operands.
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  if (clusterSizeX)
    result.addOperands(clusterSizeX);
  if (clusterSizeY)
    result.addOperands(clusterSizeY);
  if (clusterSizeZ)
    result.addOperands(clusterSizeZ);
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  // Create a kernel body region with kNumConfigRegionAttributes + N memory
  // attributions, where the first kNumConfigRegionAttributes arguments have
  // `index` type and the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = builder.createBlock(kernelRegion);
  // TODO: Allow passing in proper locations here.
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  // Add WorkGroup & Private attributions to the region arguments.
  for (Type argTy : workgroupAttributions)
    body->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    body->addArgument(argTy, result.location);
  // Fill OperandSegmentSize Attribute.
  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes[7] = clusterSizeX ? 1 : 0;
  segmentSizes[8] = clusterSizeY ? 1 : 0;
  segmentSizes[9] = clusterSizeZ ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}

KernelDim3 LaunchOp::getBlockIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

std::optional<KernelDim3> LaunchOp::getClusterIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[12], args[13], args[14]};
}

std::optional<KernelDim3> LaunchOp::getClusterSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[15], args[16], args[17]};
}

KernelDim3 LaunchOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  if (!hasClusterSize())
    return std::nullopt;
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchOp::verify() {
  if (!(hasClusterSize()) &&
      (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
    return emitOpError() << "cluster size dimensions must either all be "
                            "present or all be absent";
  return success();
}

LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!getBody().empty()) {
    if (getBody().getNumArguments() <
        kNumConfigRegionAttributes + getNumWorkgroupAttributions())
      return emitOpError("unexpected number of region arguments");
  }

  // Verify Attributions Address Spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}

void LaunchOp::print(OpAsmPrinter &p) {
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  if (hasClusterSize()) {
    p << ' ' << getClustersKeyword();
    printSizeAssignment(p, getClusterSize().value(),
                        getClusterSizeOperandValues().value(),
                        getClusterIds().value());
  }
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());

  p << ' ';

  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr(),
                              getNumWorkgroupAttributionsAttrName()});
}

// Parse the size assignment blocks for blocks and threads. These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
                              /*allowResultNumber=*/false) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
        parser.parseEqual() || parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

/// Parses a Launch operation.
/// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
///               (`clusters` `(` ssa-id-list `)` `in` ssa-reassignment)?
///               `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
///               `threads` `(` ssa-id-list `)` `in` ssa-reassignment
///               memory-attribution
///               region attr-dict?
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
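///
/// For example (illustrative):
///   gpu.launch blocks(%bx, %by, %bz) in (%gx = %0, %gy = %1, %gz = %2)
///              threads(%tx, %ty, %tz) in (%sx = %3, %sy = %4, %sz = %5) {
///     gpu.terminator
///   }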
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
      sizes(LaunchOp::kNumConfigOperands);

  // Region arguments to be created.
  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);

  // Parse optional async dependencies.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
  Type asyncTokenType;
  if (failed(
          parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
      parser.resolveOperands(asyncDependencies, asyncTokenType,
                             result.operands))
    return failure();
  if (parser.getNumResults() > 0)
    result.types.push_back(asyncTokenType);

  bool hasCluster = false;
  if (succeeded(
          parser.parseOptionalKeyword(LaunchOp::getClustersKeyword().data()))) {
    hasCluster = true;
    sizes.resize(9);
    regionArgs.resize(18);
  }
  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);

  // The last three size operands assign the cluster size. In the region
  // argument list, these correspond to the last six arguments.
  if (hasCluster) {
    if (parseSizeAssignment(parser, sizesRef.drop_front(6),
                            regionArgsRef.slice(15, 3),
                            regionArgsRef.slice(12, 3)))
      return failure();
  }
  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
  bool hasDynamicSharedMemorySize = false;
  if (!parser.parseOptionalKeyword(
          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
    hasDynamicSharedMemorySize = true;
    if (parser.parseOperand(dynamicSharedMemorySize) ||
        parser.resolveOperand(dynamicSharedMemorySize,
                              parser.getBuilder().getI32Type(),
                              result.operands))
      return failure();
  }

  // Create the region arguments: kNumConfigRegionAttributes arguments that
  // correspond to block/thread identifiers and grid/block sizes, all of
  // `index` type, followed by a variadic number of WorkGroup attributions and
  // a variadic number of Private attributions. The number of WorkGroup
  // attributions is stored in the attribute named
  // LaunchOp::getNumWorkgroupAttributionsAttrName().
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes + 6, index);

  SmallVector<OpAsmParser::Argument> regionArguments;
  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
    OpAsmParser::Argument arg;
    arg.ssaName = std::get<0>(ssaValueAndType);
    arg.type = std::get<1>(ssaValueAndType);
    regionArguments.push_back(arg);
  }

  Builder &builder = parser.getBuilder();
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
                               regionArguments)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = regionArguments.size() -
                               LaunchOp::kNumConfigRegionAttributes -
                               (hasCluster ? 6 : 0);
  result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
                               regionArguments)))
    return failure();

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes arguments that correspond to
  // block/thread identifiers and grid/block sizes, all having `index` type.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArguments) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();

  if (!hasCluster) {
    segmentSizes[7] = 0;
    segmentSizes[8] = 0;
    segmentSizes[9] = 0;
  }
  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
  return success();
}

/// Simplify the gpu.launch when the range of a thread or block ID is
/// trivially known to be one.
struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter &rewriter) const override {
    // If the range implies a single value for `id`, replace `id`'s uses by
    // zero.
    Value zero;
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {
      // Check if size is trivially one.
      if (!matchPattern(size, m_One()))
        return;
      if (id.getUses().empty())
        return;
      if (!simplified) {
        // Create a zero value the first time.
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(&op.getBody().front());
        zero =
            rewriter.create<arith::ConstantIndexOp>(op.getLoc(), /*value=*/0);
      }
      rewriter.replaceAllUsesWith(id, zero);
      simplified = true;
    };
    constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
    constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
    constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
    constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
    constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());

    return success(simplified);
  }
};

void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
                                           MLIRContext *context) {
  rewrites.add<FoldLaunchArguments>(context);
}

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernelSymbol, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  assert(kernelSymbol.getNestedReferences().size() == 1 &&
         "expected a symbol reference with a single nested reference");
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);

  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernelSymbol;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  for (auto &sz : prop.operandSegmentSizes)
    sz = 1;
  prop.operandSegmentSizes[0] = asyncDependencies.size();
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
  auto kernelSymbol =
      SymbolRefAttr::get(kernelModule.getNameAttr(),
                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});
  build(builder, result, kernelSymbol, gridSize, getBlockSize,
        dynamicSharedMemorySize, kernelOperands, asyncTokenType,
        asyncDependencies, clusterSize);
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernel, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Value asyncObject,
                         std::optional<KernelDim3> clusterSize) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  if (asyncObject)
    result.addOperands(asyncObject);
  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernel;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  for (auto &sz : prop.operandSegmentSizes)
    sz = 1;
  prop.operandSegmentSizes[0] = 0;
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
}

StringAttr LaunchFuncOp::getKernelModuleName() {
  return getKernel().getRootReference();
}

StringAttr LaunchFuncOp::getKernelName() {
  return getKernel().getLeafReference();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getKernelOperands().size();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getKernelOperands()[i];
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
  assert(hasClusterSize() &&
         "cluster size is not set, check hasClusterSize() first");
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchFuncOp::verify() {
  auto module = (*this)->getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module->getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  if (hasClusterSize()) {
    if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
        getClusterSizeZ().getType() != getClusterSizeX().getType())
      return emitOpError()
             << "expects the types of the cluster dimensions to be the same";
  }

  return success();
}

static ParseResult
parseLaunchDimType(OpAsmParser &parser, Type &dimTy,
                   std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
                   Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
  if (succeeded(parser.parseOptionalColon())) {
    if (parser.parseType(dimTy))
      return failure();
  } else {
    dimTy = IndexType::get(parser.getContext());
  }
  if (clusterValue.has_value()) {
    clusterXTy = clusterYTy = clusterZTy = dimTy;
  }
  return success();
}

static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
                               Value clusterValue, Type clusterXTy,
                               Type clusterYTy, Type clusterZTy) {
  if (!dimTy.isIndex())
    printer << ": " << dimTy;
}

static ParseResult parseLaunchFuncOperands(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
    SmallVectorImpl<Type> &argTypes) {
  if (parser.parseOptionalKeyword("args"))
    return success();

  auto parseElement = [&]() -> ParseResult {
    return failure(parser.parseOperand(argNames.emplace_back()) ||
                   parser.parseColonType(argTypes.emplace_back()));
  };

  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                        parseElement, " in argument list");
}

static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
                                    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  printer << "args(";
  llvm::interleaveComma(llvm::zip_equal(operands, types), printer,
                        [&](const auto &pair) {
                          auto [operand, type] = pair;
                          printer << operand << " : " << type;
                        });
  printer << ")";
}
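
// Putting these helpers together, a full launch site reads like
// (illustrative):
//   gpu.launch_func @kernels::@my_kernel
//       blocks in (%gx, %gy, %gz) threads in (%bx, %by, %bz)
//       args(%arg0 : f32, %arg1 : memref<?xf32, 1>)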

//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//

void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
                      int32_t offset, int32_t width, ShuffleMode mode) {
  build(builder, result, value,
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(offset)),
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(width)),
        mode);
}
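
// This builder materializes the offset and width as i32 constants, so the
// resulting IR reads like (illustrative):
//   %shfl, %valid = gpu.shuffle xor %value, %offset, %width : f32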

//===----------------------------------------------------------------------===//
// BarrierOp
//===----------------------------------------------------------------------===//

namespace {

1340 | /// Remove gpu.barrier after gpu.barrier, the threads are already synchronized! |
1341 | LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op, |
1342 | PatternRewriter &rewriter) { |
1343 | if (isa_and_nonnull<BarrierOp>(op->getNextNode())) { |
1344 | rewriter.eraseOp(op: op); |
1345 | return success(); |
1346 | } |
1347 | return failure(); |
1348 | } |
1349 | |
1350 | } // end anonymous namespace |
1351 | |
1352 | void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results, |
1353 | MLIRContext *context) { |
1354 | results.add(eraseRedundantGpuBarrierOps); |
1355 | } |
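
// For illustration, the canonicalization turns the hypothetical IR
//
//   gpu.barrier
//   gpu.barrier
//
// into a single gpu.barrier by erasing the first op of the pair.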

//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      getFunctionType().getNumInputs() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers in private memory always come after buffers in workgroup memory.
  return getBody().addArgument(type, loc);
}

void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
                      StringRef name, FunctionType type,
                      TypeRange workgroupAttributions,
                      TypeRange privateAttributions,
                      ArrayRef<NamedAttribute> attrs) {
  OpBuilder::InsertionGuard g(builder);

  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));
  result.addAttributes(attrs);
  Region *body = result.addRegion();
  Block *entryBlock = builder.createBlock(body);

  // TODO: Allow passing in proper locations here.
  for (Type argTy : type.getInputs())
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : workgroupAttributions)
    entryBlock->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    entryBlock->addArgument(argTy, result.location);
}
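
// A minimal usage sketch of the builder above (names are illustrative, and
// the trailing attribute list is assumed to default to empty):
//
//   auto fnType = builder.getFunctionType({builder.getF32Type()}, {});
//   auto func = builder.create<gpu::GPUFuncOp>(
//       loc, "kernel", fnType, /*workgroupAttributions=*/TypeRange(),
//       /*privateAttributions=*/TypeRange());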

/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args,
                  Attribute &attributionAttrs) {
  // If we could not parse the keyword, just assume an empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  size_t existingArgs = args.size();
  ParseResult result =
      parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                               /*allowType=*/true, /*allowAttrs=*/true);
  if (failed(result))
    return result;

  bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs),
                               [](const OpAsmParser::Argument &arg) -> bool {
                                 return arg.attrs && !arg.attrs.empty();
                               });
  if (!hadAttrs) {
    attributionAttrs = nullptr;
    return result;
  }

  Builder &builder = parser.getBuilder();
  SmallVector<Attribute> attributionAttrsVec;
  for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) {
    if (!argument.attrs)
      attributionAttrsVec.push_back(builder.getDictionaryAttr({}));
    else
      attributionAttrsVec.push_back(argument.attrs);
  }
  attributionAttrs = builder.getArrayAttr(attributionAttrsVec);
  return result;
}

/// Parses a GPU function.
///
/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
///                 (`->` function-result-list)? memory-attribution `kernel`?
///                 function-attributes? region
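///
/// For illustration, a function this grammar accepts could look like the
/// following (hypothetical IR):
///
///   gpu.func @kernel(%arg0 : memref<?xf32>)
///       workgroup(%wg : memref<32xf32, #gpu.address_space<workgroup>>)
///       kernel {
///     gpu.return
///   }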
ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  bool isVariadic;

  // Parse the function name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  auto signatureLocation = parser.getCurrentLocation();
  if (failed(function_interface_impl::parseFunctionSignatureWithArguments(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs)))
    return failure();

  if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
    return parser.emitError(signatureLocation)
           << "gpu.func requires named arguments";

  // Construct the function type. More types will be added to the region, but
  // not to the function type.
  Builder &builder = parser.getBuilder();

  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto type = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));

  call_interface_impl::addArgAndResultAttrs(
      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));

  Attribute workgroupAttributionAttrs;
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
                               entryArgs, workgroupAttributionAttrs)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
  result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));
  if (workgroupAttributionAttrs)
    result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name),
                        workgroupAttributionAttrs);

  Attribute privateAttributionAttrs;
  // Parse private memory attributions.
  if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
                               entryArgs, privateAttributionAttrs)))
    return failure();
  if (privateAttributionAttrs)
    result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name),
                        privateAttributionAttrs);

  // Parse the kernel attribute if present.
  if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword())))
    result.addAttribute(GPUDialect::getKernelFuncAttrName(),
                        builder.getUnitAttr());

  // Parse attributes.
  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();

  // Parse the region. If no argument names were provided, take all names
  // (including those of attributions) from the entry block.
  auto *body = result.addRegion();
  return parser.parseRegion(*body, entryArgs);
}

static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values,
                              ArrayAttr attributes) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      llvm::enumerate(values), p, [&p, attributes](auto pair) {
        BlockArgument v = pair.value();
        p << v << " : " << v.getType();

        size_t attributionIndex = pair.index();
        DictionaryAttr attrs;
        if (attributes && attributionIndex < attributes.size())
          attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]);
        if (attrs)
          p.printOptionalAttrDict(attrs.getValue());
      });
  p << ')';
}

void GPUFuncOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printSymbolName(getName());

  FunctionType type = getFunctionType();
  function_interface_impl::printFunctionSignature(p, *this, type.getInputs(),
                                                  /*isVariadic=*/false,
                                                  type.getResults());

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(),
                    getWorkgroupAttribAttrs().value_or(nullptr));
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions(),
                    getPrivateAttribAttrs().value_or(nullptr));
  if (isKernel())
    p << ' ' << getKernelKeyword();

  function_interface_impl::printFunctionAttributes(
      p, *this,
      {getNumWorkgroupAttributionsAttrName(),
       GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(),
       getArgAttrsAttrName(), getResAttrsAttrName(),
       getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()});
  p << ' ';
  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}

static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index,
                                          StringAttr attrName) {
  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
  if (!allAttrs || index >= allAttrs.size())
    return DictionaryAttr();
  return llvm::cast<DictionaryAttr>(allAttrs[index]);
}

DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) {
  return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName());
}

DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) {
  return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName());
}

static void setAttributionAttrs(GPUFuncOp op, unsigned index,
                                DictionaryAttr value, StringAttr attrName) {
  MLIRContext *ctx = op.getContext();
  auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName));
  SmallVector<Attribute> elements;
  if (allAttrs)
    elements.append(allAttrs.begin(), allAttrs.end());
  while (elements.size() <= index)
    elements.push_back(DictionaryAttr::get(ctx));
  if (!value)
    elements[index] = DictionaryAttr::get(ctx);
  else
    elements[index] = value;
  ArrayAttr newValue = ArrayAttr::get(ctx, elements);
  op->setAttr(attrName, newValue);
}

void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index,
                                             DictionaryAttr value) {
  setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName());
}

void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index,
                                           DictionaryAttr value) {
  setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName());
}

static Attribute getAttributionAttr(GPUFuncOp op, unsigned index,
                                    StringAttr name, StringAttr attrsName) {
  DictionaryAttr dict = getAttributionAttrs(op, index, attrsName);
  if (!dict)
    return Attribute();
  return dict.get(name);
}

Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index,
                                                 StringAttr name) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  return getAttributionAttr(*this, index, name,
                            getWorkgroupAttribAttrsAttrName());
}

Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index,
                                               StringAttr name) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  return getAttributionAttr(*this, index, name,
                            getPrivateAttribAttrsAttrName());
}

static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name,
                               Attribute value, StringAttr attrsName) {
  MLIRContext *ctx = op.getContext();
  SmallVector<NamedAttribute> elems;
  DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName);
  if (oldDict)
    elems.append(oldDict.getValue().begin(), oldDict.getValue().end());

  bool found = false;
  bool mustSort = true;
  for (unsigned i = 0, e = elems.size(); i < e; ++i) {
    if (elems[i].getName() == name) {
      found = true;
      if (!value) {
        // Erase the attribute by swapping it to the end and popping it off;
        // this breaks the sorted order, so a re-sort is required below.
        std::swap(elems[i], elems[elems.size() - 1]);
        elems.pop_back();
      } else {
        // Replacing the value in place preserves the sorted order.
        mustSort = false;
        elems[i] = NamedAttribute(elems[i].getName(), value);
      }
      break;
    }
  }
  if (!found) {
    if (!value)
      return;
    elems.emplace_back(name, value);
  }
  if (mustSort) {
    DictionaryAttr::sortInPlace(elems);
  }
  auto newDict = DictionaryAttr::getWithSorted(ctx, elems);
  setAttributionAttrs(op, index, newDict, attrsName);
}

void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name,
                                            Attribute value) {
  assert(index < getNumWorkgroupAttributions() &&
         "index must map to a workgroup attribution");
  setAttributionAttr(*this, index, name, value,
                     getWorkgroupAttribAttrsAttrName());
}

void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name,
                                          Attribute value) {
  assert(index < getNumPrivateAttributions() &&
         "index must map to a private attribution");
  setAttributionAttr(*this, index, name, value,
                     getPrivateAttribAttrsAttrName());
}
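
// A minimal usage sketch of the setters above (the attribute name and value
// are illustrative): attach an attribute to workgroup attribution #0 of
// `func`.
//
//   Builder b(func.getContext());
//   func.setWorkgroupAttributionAttr(0, b.getStringAttr("my.alignment"),
//                                    b.getI64IntegerAttr(16));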

LogicalResult GPUFuncOp::verifyType() {
  if (isKernel() && getFunctionType().getNumResults() != 0)
    return emitOpError() << "expected void return type for kernel function";

  return success();
}

/// Verifies the body of the function.
LogicalResult GPUFuncOp::verifyBody() {
  if (empty())
    return emitOpError() << "expected body with at least one block";
  unsigned numFuncArguments = getNumArguments();
  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
  unsigned numBlockArguments = front().getNumArguments();
  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
    return emitOpError() << "expected at least "
                         << numFuncArguments + numWorkgroupAttributions
                         << " arguments to body region";

  ArrayRef<Type> funcArgTypes = getFunctionType().getInputs();
  for (unsigned i = 0; i < numFuncArguments; ++i) {
    Type blockArgType = front().getArgument(i).getType();
    if (funcArgTypes[i] != blockArgType)
      return emitOpError() << "expected body region argument #" << i
                           << " to be of type " << funcArgTypes[i] << ", got "
                           << blockArgType;
  }

  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  return success();
}

//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//

LogicalResult gpu::ReturnOp::verify() {
  GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>();

  FunctionType funType = function.getFunctionType();

  if (funType.getNumResults() != getOperands().size())
    return emitOpError()
        .append("expected ", funType.getNumResults(), " result operands")
        .attachNote(function.getLoc())
        .append("return type declared here");

  for (const auto &pair : llvm::enumerate(
           llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
    auto [type, operand] = pair.value();
    if (type != operand.getType())
      return emitOpError() << "unexpected type `" << operand.getType()
                           << "' for operand #" << pair.index();
  }
  return success();
}

//===----------------------------------------------------------------------===//
// GPUModuleOp
//===----------------------------------------------------------------------===//

void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name, ArrayAttr targets,
                        Attribute offloadingHandler) {
  result.addRegion()->emplaceBlock();
  Properties &props = result.getOrAddProperties<Properties>();
  if (targets)
    props.targets = targets;
  props.setSymName(builder.getStringAttr(name));
  props.offloadingHandler = offloadingHandler;
}

void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name, ArrayRef<Attribute> targets,
                        Attribute offloadingHandler) {
  build(builder, result, name,
        targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets),
        offloadingHandler);
}

bool GPUModuleOp::hasTarget(Attribute target) {
  if (ArrayAttr targets = getTargetsAttr())
    return llvm::count(targets.getValue(), target);
  return false;
}

void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
  ArrayAttr &targetsAttr = getProperties().targets;
  SmallVector<Attribute> targetsVector(targets);
  targetsAttr = ArrayAttr::get(getContext(), targetsVector);
}

LogicalResult GPUModuleOp::verify() {
  auto targets = getOperation()->getAttrOfType<ArrayAttr>("targets");

  if (!targets)
    return success();

  for (auto target : targets) {
    if (auto verifyTargetAttr =
            llvm::dyn_cast<TargetAttrVerifyInterface>(target)) {
      if (verifyTargetAttr.verifyTarget(getOperation()).failed())
        return failure();
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// GPUBinaryOp
//===----------------------------------------------------------------------===//

void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
                     Attribute offloadingHandler, ArrayAttr objects) {
  auto &properties = result.getOrAddProperties<Properties>();
  result.attributes.push_back(builder.getNamedAttr(
      SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
  properties.objects = objects;
  if (offloadingHandler)
    properties.offloadingHandler = offloadingHandler;
  else
    properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr);
}

void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name,
                     Attribute offloadingHandler, ArrayRef<Attribute> objects) {
  build(builder, result, name, offloadingHandler,
        objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects));
}

static ParseResult parseOffloadingHandler(OpAsmParser &parser,
                                          Attribute &offloadingHandler) {
  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(offloadingHandler))
      return failure();
    if (parser.parseGreater())
      return failure();
  }
  if (!offloadingHandler)
    offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr);
  return success();
}

static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op,
                                   Attribute offloadingHandler) {
  if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr))
    printer << '<' << offloadingHandler << '>';
}

//===----------------------------------------------------------------------===//
// GPUMemcpyOp
//===----------------------------------------------------------------------===//

LogicalResult MemcpyOp::verify() {
  auto srcType = getSrc().getType();
  auto dstType = getDst().getType();

  if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType))
    return emitOpError("arguments have incompatible element type");

  if (failed(verifyCompatibleShape(srcType, dstType)))
    return emitOpError("arguments have incompatible shape");

  return success();
}

namespace {

/// Erases a common case of copy ops where a destination value is used only by
/// the copy op, alloc and dealloc ops.
struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
  using OpRewritePattern<MemcpyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(MemcpyOp op,
                                PatternRewriter &rewriter) const override {
    Value dest = op.getDst();
    Operation *destDefOp = dest.getDefiningOp();
    // `dest` must be defined by an op having Allocate memory effect in order
    // to perform the folding.
    if (!destDefOp ||
        !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest))
      return failure();
    // We can erase `op` iff `dest` has no other use apart from its
    // use by `op` and dealloc ops.
    if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) {
          return user != op &&
                 !hasSingleEffect<MemoryEffects::Free>(user, dest);
        }))
      return failure();
    // We can perform the folding if and only if op has a single async
    // dependency and produces an async token as result, or if it does not have
    // any async dependency and does not produce any async token result.
    if (op.getAsyncDependencies().size() > 1 ||
        ((op.getAsyncDependencies().empty() && op.getAsyncToken()) ||
         (!op.getAsyncDependencies().empty() && !op.getAsyncToken())))
      return failure();
    rewriter.replaceOp(op, op.getAsyncDependencies());
    return success();
  }
};

} // end anonymous namespace

void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<EraseTrivialCopyOp>(context);
}
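
// For illustration, the pattern erases a copy like the following, where the
// destination is only allocated, copied into, and deallocated (hypothetical
// IR):
//
//   %dst = memref.alloc() : memref<16xf32>
//   gpu.memcpy %dst, %src : memref<16xf32>, memref<16xf32>
//   memref.dealloc %dst : memref<16xf32>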

//===----------------------------------------------------------------------===//
// GPU_SubgroupMmaLoadMatrixOp
//===----------------------------------------------------------------------===//

LogicalResult SubgroupMmaLoadMatrixOp::verify() {
  auto srcType = getSrcMemref().getType();
  auto resType = getRes().getType();
  auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType);
  auto operand = resMatrixType.getOperand();
  auto srcMemrefType = llvm::cast<MemRefType>(srcType);

  if (!srcMemrefType.isLastDimUnitStride())
    return emitError(
        "expected the most minor dimension of the source memref to have unit "
        "stride");

  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return emitError("only AOp, BOp and COp can be loaded");

  return success();
}

//===----------------------------------------------------------------------===//
// GPU_SubgroupMmaStoreMatrixOp
//===----------------------------------------------------------------------===//

LogicalResult SubgroupMmaStoreMatrixOp::verify() {
  auto srcType = getSrc().getType();
  auto dstType = getDstMemref().getType();
  auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
  auto dstMemrefType = llvm::cast<MemRefType>(dstType);

  if (!dstMemrefType.isLastDimUnitStride())
    return emitError(
        "expected the most minor dimension of the destination memref to have "
        "unit stride");

  if (srcMatrixType.getOperand() != "COp")
    return emitError(
        "expected the operand matrix being stored to have 'COp' operand type");

  return success();
}

//===----------------------------------------------------------------------===//
// GPU_SubgroupMmaComputeOp
//===----------------------------------------------------------------------===//

LogicalResult SubgroupMmaComputeOp::verify() {
  enum OperandMap { A, B, C };
  SmallVector<MMAMatrixType, 3> opTypes;
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType()));
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
  opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));

  if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
      opTypes[C].getOperand() != "COp")
    return emitError("operands must be in the order AOp, BOp, COp");

  ArrayRef<int64_t> aShape, bShape, cShape;
  aShape = opTypes[A].getShape();
  bShape = opTypes[B].getShape();
  cShape = opTypes[C].getShape();

  if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
      bShape[1] != cShape[1])
    return emitError("operand shapes do not satisfy matmul constraints");

  return success();
}

LogicalResult MemcpyOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}

LogicalResult MemsetOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<::mlir::OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// GPU_WaitOp
//===----------------------------------------------------------------------===//

namespace {

/// Remove a gpu.wait op use of a gpu.wait op def that has no async
/// dependencies.
/// %t = gpu.wait async []       // No async dependencies.
/// ... gpu.wait ... [%t, ...]   // %t can be removed.
struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(WaitOp op,
                                PatternRewriter &rewriter) const final {
    auto predicate = [](Value value) {
      auto waitOp = value.getDefiningOp<WaitOp>();
      return waitOp && waitOp->getNumOperands() == 0;
    };
    if (llvm::none_of(op.getAsyncDependencies(), predicate))
      return failure();
    SmallVector<Value> validOperands;
    for (Value operand : op->getOperands()) {
      if (predicate(operand))
        continue;
      validOperands.push_back(operand);
    }
    rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
    return success();
  }
};

/// Simplify trivial gpu.wait ops for the following patterns.
/// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async
///    dependencies).
/// 2. %t1 = gpu.wait async [%t0], in which case we can replace uses of %t1
///    with %t0.
/// 3. gpu.wait [] ops, i.e. gpu.wait ops that neither have any async
///    dependencies nor return any token.
struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(WaitOp op,
                                PatternRewriter &rewriter) const final {
    // Erase gpu.wait ops that neither have any async dependencies nor return
    // any async token.
    if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) {
      rewriter.eraseOp(op);
      return success();
    }
    // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the
    // op.
    if (llvm::hasSingleElement(op.getAsyncDependencies()) &&
        op.getAsyncToken()) {
      rewriter.replaceOp(op, op.getAsyncDependencies());
      return success();
    }
    // Erase %t = gpu.wait async ... ops, where %t has no uses.
    if (op.getAsyncToken() && op.getAsyncToken().use_empty()) {
      rewriter.eraseOp(op);
      return success();
    }
    return failure();
  }
};

} // end anonymous namespace

void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
}

//===----------------------------------------------------------------------===//
// GPU_AllocOp
//===----------------------------------------------------------------------===//

LogicalResult AllocOp::verify() {
  auto memRefType = llvm::cast<MemRefType>(getMemref().getType());

  if (getDynamicSizes().size() != memRefType.getNumDynamicDims())
    return emitOpError("dimension operand count does not equal memref "
                       "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (getSymbolOperands().size() != numSymbols) {
    return emitOpError(
        "symbol operand count does not equal memref symbol count");
  }

  return success();
}

namespace {

/// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to
/// `memref::AllocOp`.
struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> index = dimOp.getConstantIndex();
    if (!index)
      return failure();

    auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType());
    if (!memrefType || index.value() >= memrefType.getRank() ||
        !memrefType.isDynamicDim(index.value()))
      return failure();

    auto alloc = dimOp.getSource().getDefiningOp<AllocOp>();
    if (!alloc)
      return failure();

    Value substituteOp = *(alloc.getDynamicSizes().begin() +
                           memrefType.getDynamicDimIndex(index.value()));
    rewriter.replaceOp(dimOp, substituteOp);
    return success();
  }
};

} // namespace

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyDimOfAllocOp>(context);
}
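
// For illustration, the folding rewrites the hypothetical IR
//
//   %m = gpu.alloc(%size) : memref<?xf32>
//   %d = memref.dim %m, %c0 : memref<?xf32>
//
// so that uses of %d become uses of %size directly.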

//===----------------------------------------------------------------------===//
// GPU object attribute
//===----------------------------------------------------------------------===//

LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 Attribute target, CompilationTarget format,
                                 StringAttr object, DictionaryAttr properties,
                                 KernelTableAttr kernels) {
  if (!target)
    return emitError() << "the target attribute cannot be null";
  if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>())
    return success();
  return emitError() << "the target attribute must implement or promise the "
                        "`gpu::TargetAttrInterface`";
}

namespace {
ParseResult parseObject(AsmParser &odsParser, CompilationTarget &format,
                        StringAttr &object) {
  std::optional<CompilationTarget> formatResult;
  StringRef enumKeyword;
  auto loc = odsParser.getCurrentLocation();
  if (failed(odsParser.parseOptionalKeyword(&enumKeyword)))
    formatResult = CompilationTarget::Fatbin;
  if (!formatResult &&
      (formatResult =
           gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) &&
      odsParser.parseEqual())
    return odsParser.emitError(loc, "expected an equal sign");
  if (!formatResult)
    return odsParser.emitError(loc, "expected keyword for GPU object format");
  FailureOr<StringAttr> objectResult =
      FieldParser<StringAttr>::parse(odsParser);
  if (failed(objectResult))
    return odsParser.emitError(odsParser.getCurrentLocation(),
                               "failed to parse GPU_ObjectAttr parameter "
                               "'object' which is to be a `StringAttr`");
  format = *formatResult;
  object = *objectResult;
  return success();
}

void printObject(AsmPrinter &odsParser, CompilationTarget format,
                 StringAttr object) {
  if (format != CompilationTarget::Fatbin)
    odsParser << stringifyEnum(format) << " = ";
  odsParser << object;
}
} // namespace

//===----------------------------------------------------------------------===//
// GPU select object attribute
//===----------------------------------------------------------------------===//

LogicalResult
gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                              Attribute target) {
  // Check `target`; it can be null, an integer attr or a GPU target attribute.
  if (target) {
    if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) {
      if (intAttr.getInt() < 0) {
        return emitError() << "the object index must be non-negative";
      }
    } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) {
      return emitError()
             << "the target attribute must be a GPU Target attribute";
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// DynamicSharedMemoryOp
//===----------------------------------------------------------------------===//

LogicalResult gpu::DynamicSharedMemoryOp::verify() {
  if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
    return emitOpError() << "must be inside an op with symbol table";

  MemRefType memrefType = getResultMemref().getType();
  // Check the address space.
  if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
    return emitOpError() << "address space must be "
                         << gpu::AddressSpaceAttr::getMnemonic() << "<"
                         << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
  }
  if (memrefType.hasStaticShape()) {
    return emitOpError() << "result memref type must be memref<?xi8, "
                            "#gpu.address_space<workgroup>>";
  }
  return success();
}
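
// For illustration, a well-formed use of the op has exactly the dynamic
// workgroup-memory result type required above (hypothetical IR):
//
//   %shmem = gpu.dynamic_shared_memory
//       : memref<?xi8, #gpu.address_space<workgroup>>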

//===----------------------------------------------------------------------===//
// GPU WarpExecuteOnLane0Op
//===----------------------------------------------------------------------===//

void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
  p << "(" << getLaneid() << ")";

  SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
  auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
  p << "[" << llvm::cast<IntegerAttr>(warpSizeAttr).getInt() << "]";

  if (!getArgs().empty())
    p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
  if (!getResults().empty())
    p << " -> (" << getResults().getTypes() << ')';
  p << " ";
  p.printRegion(getRegion(),
                /*printEntryBlockArgs=*/true,
                /*printBlockTerminators=*/!getResults().empty());
  p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
}

ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
                                        OperationState &result) {
  // Create the region.
  result.regions.reserve(1);
  Region *warpRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand laneId;

  // Parse predicate operand.
  if (parser.parseLParen() ||
      parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
      parser.parseRParen())
    return failure();

  int64_t warpSize;
  if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
      parser.parseRSquare())
    return failure();
  result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
                                                        builder.getContext())),
                      builder.getI64IntegerAttr(warpSize));

  if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
    return failure();

  llvm::SMLoc inputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
  SmallVector<Type> inputTypes;
  if (succeeded(parser.parseOptionalKeyword("args"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }
  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();
  // Parse the region.
  if (parser.parseRegion(*warpRegion, /*arguments=*/{},
                         /*argTypes=*/{}))
    return failure();
  WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
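
// For illustration, the custom syntax printed and parsed above looks roughly
// like the following (hypothetical IR; with warp size 32, the vector<4xf32>
// argument expands to vector<128xf32> inside the region):
//
//   %r = gpu.warp_execute_on_lane_0(%laneid)[32]
//       args(%v : vector<4xf32>) -> (vector<1xf32>) {
//   ^bb0(%arg : vector<128xf32>):
//     ...
//     gpu.yield %val : vector<32xf32>
//   }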

void WarpExecuteOnLane0Op::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  if (!point.isParent()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  // The warp region is always executed.
  regions.push_back(RegionSuccessor(&getWarpRegion()));
}

void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 TypeRange resultTypes, Value laneId,
                                 int64_t warpSize) {
  build(builder, result, resultTypes, laneId, warpSize,
        /*operands=*/std::nullopt, /*argTypes=*/std::nullopt);
}

void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 TypeRange resultTypes, Value laneId,
                                 int64_t warpSize, ValueRange args,
                                 TypeRange blockArgTypes) {
  result.addOperands(laneId);
  result.addAttribute(getAttributeNames()[0],
                      builder.getI64IntegerAttr(warpSize));
  result.addTypes(resultTypes);
  result.addOperands(args);
  assert(args.size() == blockArgTypes.size());
  OpBuilder::InsertionGuard guard(builder);
  Region *warpRegion = result.addRegion();
  Block *block = builder.createBlock(warpRegion);
  for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
    block->addArgument(type, arg.getLoc());
}

/// Helper to check that the distributed vector type is consistent with the
/// expanded type and distributed size.
static LogicalResult verifyDistributedType(Type expanded, Type distributed,
                                           int64_t warpSize, Operation *op) {
  // If the types match, there is no distribution.
  if (expanded == distributed)
    return success();
  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
  if (!expandedVecType || !distributedVecType)
    return op->emitOpError("expected vector type for distributed operands.");
  if (expandedVecType.getRank() != distributedVecType.getRank() ||
      expandedVecType.getElementType() != distributedVecType.getElementType())
    return op->emitOpError(
        "expected distributed vectors to have same rank and element type.");

  SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
  for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
    int64_t eDim = expandedVecType.getDimSize(i);
    int64_t dDim = distributedVecType.getDimSize(i);
    if (eDim == dDim)
      continue;
    if (eDim % dDim != 0)
      return op->emitOpError()
             << "expected expanded vector dimension #" << i << " (" << eDim
             << ") to be a multiple of the distributed vector dimension ("
             << dDim << ")";
    scales[i] = eDim / dDim;
  }
  if (std::accumulate(scales.begin(), scales.end(), 1,
                      std::multiplies<int64_t>()) != warpSize)
    return op->emitOpError()
           << "incompatible distribution dimensions from " << expandedVecType
           << " to " << distributedVecType << " with warp size = " << warpSize;

  return success();
}
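
// For illustration: with warp size 32, an expanded type vector<64x4xf32> may
// distribute to vector<2x4xf32>; the per-dimension scales are 32 and 1, and
// their product (32) matches the warp size, so verification succeeds.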

LogicalResult WarpExecuteOnLane0Op::verify() {
  if (getArgs().size() != getWarpRegion().getNumArguments())
    return emitOpError(
        "expected the same number of op arguments and block arguments.");
  auto yield =
      cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
  if (yield.getNumOperands() != getNumResults())
    return emitOpError(
        "expected the same number of yield operands and return values.");
  int64_t warpSize = getWarpSize();
  for (auto [regionArg, arg] :
       llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
    if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
                                     warpSize, getOperation())))
      return failure();
  }
  for (auto [yieldOperand, result] :
       llvm::zip_equal(yield.getOperands(), getResults())) {
    if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
                                     warpSize, getOperation())))
      return failure();
  }
  return success();
}

bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
  return succeeded(
      verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
}

//===----------------------------------------------------------------------===//
// GPU KernelMetadataAttr
//===----------------------------------------------------------------------===//

KernelMetadataAttr KernelMetadataAttr::get(FunctionOpInterface kernel,
                                           DictionaryAttr metadata) {
  assert(kernel && "invalid kernel");
  return get(kernel.getNameAttr(), kernel.getFunctionType(),
             kernel.getAllArgAttrs(), metadata);
}

KernelMetadataAttr
KernelMetadataAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
                               FunctionOpInterface kernel,
                               DictionaryAttr metadata) {
  assert(kernel && "invalid kernel");
  return getChecked(emitError, kernel.getNameAttr(), kernel.getFunctionType(),
                    kernel.getAllArgAttrs(), metadata);
}

KernelMetadataAttr
KernelMetadataAttr::appendMetadata(ArrayRef<NamedAttribute> attrs) const {
  if (attrs.empty())
    return *this;
  NamedAttrList attrList;
  if (DictionaryAttr dict = getMetadata())
    attrList.append(dict);
  attrList.append(attrs);
  return KernelMetadataAttr::get(getName(), getFunctionType(), getArgAttrs(),
                                 attrList.getDictionary(getContext()));
}

LogicalResult
KernelMetadataAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                           StringAttr name, Type functionType,
                           ArrayAttr argAttrs, DictionaryAttr metadata) {
  if (name.empty())
    return emitError() << "the kernel name can't be empty";
  if (argAttrs) {
    if (llvm::any_of(argAttrs, [](Attribute attr) {
          return !llvm::isa<DictionaryAttr>(attr);
        }))
      return emitError()
             << "all attributes in the array must be dictionary attributes";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// GPU KernelTableAttr
//===----------------------------------------------------------------------===//

KernelTableAttr KernelTableAttr::get(MLIRContext *context,
                                     ArrayRef<KernelMetadataAttr> kernels,
                                     bool isSorted) {
  // Note that `is_sorted` is only invoked once even with assertions ON.
  assert((!isSorted || llvm::is_sorted(kernels)) &&
         "expected a sorted kernel array");
  // Immediately return the attribute if the array is sorted.
  if (isSorted || llvm::is_sorted(kernels))
    return Base::get(context, kernels);
  // Sort the array.
  SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
  return Base::get(context, kernelsTmp);
}

KernelTableAttr KernelTableAttr::getChecked(
    function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
    ArrayRef<KernelMetadataAttr> kernels, bool isSorted) {
  // Note that `is_sorted` is only invoked once even with assertions ON.
  assert((!isSorted || llvm::is_sorted(kernels)) &&
         "expected a sorted kernel array");
  // Immediately return the attribute if the array is sorted.
  if (isSorted || llvm::is_sorted(kernels))
    return Base::getChecked(emitError, context, kernels);
  // Sort the array.
  SmallVector<KernelMetadataAttr> kernelsTmp(kernels);
  llvm::array_pod_sort(kernelsTmp.begin(), kernelsTmp.end());
  return Base::getChecked(emitError, context, kernelsTmp);
}

LogicalResult
KernelTableAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                        ArrayRef<KernelMetadataAttr> kernels) {
  if (kernels.size() < 2)
    return success();
  // Check that the kernels are uniquely named.
  if (std::adjacent_find(kernels.begin(), kernels.end(),
                         [](KernelMetadataAttr l, KernelMetadataAttr r) {
                           return l.getName() == r.getName();
                         }) != kernels.end()) {
    return emitError() << "expected all kernels to be uniquely named";
  }
  return success();
}

KernelMetadataAttr KernelTableAttr::lookup(StringRef key) const {
  auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
  return found ? *iterator : KernelMetadataAttr();
}

KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
  auto [iterator, found] = impl::findAttrSorted(begin(), end(), key);
  return found ? *iterator : KernelMetadataAttr();
}

//===----------------------------------------------------------------------===//
// GPU target options
//===----------------------------------------------------------------------===//

TargetOptions::TargetOptions(
    StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
    StringRef cmdOptions, StringRef elfSection,
    CompilationTarget compilationTarget,
    function_ref<SymbolTable *()> getSymbolTableCallback,
    function_ref<void(llvm::Module &)> initialLlvmIRCallback,
    function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
    function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
    function_ref<void(StringRef)> isaCallback)
    : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, librariesToLink,
                    cmdOptions, elfSection, compilationTarget,
                    getSymbolTableCallback, initialLlvmIRCallback,
                    linkedLlvmIRCallback, optimizedLlvmIRCallback,
                    isaCallback) {}

TargetOptions::TargetOptions(
    TypeID typeID, StringRef toolkitPath, ArrayRef<Attribute> librariesToLink,
    StringRef cmdOptions, StringRef elfSection,
    CompilationTarget compilationTarget,
    function_ref<SymbolTable *()> getSymbolTableCallback,
    function_ref<void(llvm::Module &)> initialLlvmIRCallback,
    function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
    function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
    function_ref<void(StringRef)> isaCallback)
    : toolkitPath(toolkitPath.str()), librariesToLink(librariesToLink),
      cmdOptions(cmdOptions.str()), elfSection(elfSection.str()),
      compilationTarget(compilationTarget),
      getSymbolTableCallback(getSymbolTableCallback),
      initialLlvmIRCallback(initialLlvmIRCallback),
      linkedLlvmIRCallback(linkedLlvmIRCallback),
      optimizedLlvmIRCallback(optimizedLlvmIRCallback),
      isaCallback(isaCallback), typeID(typeID) {}

TypeID TargetOptions::getTypeID() const { return typeID; }

StringRef TargetOptions::getToolkitPath() const { return toolkitPath; }

ArrayRef<Attribute> TargetOptions::getLibrariesToLink() const {
  return librariesToLink;
}

StringRef TargetOptions::getCmdOptions() const { return cmdOptions; }

StringRef TargetOptions::getELFSection() const { return elfSection; }

SymbolTable *TargetOptions::getSymbolTable() const {
  return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
}

function_ref<void(llvm::Module &)>
TargetOptions::getInitialLlvmIRCallback() const {
  return initialLlvmIRCallback;
}

function_ref<void(llvm::Module &)>
TargetOptions::getLinkedLlvmIRCallback() const {
  return linkedLlvmIRCallback;
}

function_ref<void(llvm::Module &)>
TargetOptions::getOptimizedLlvmIRCallback() const {
  return optimizedLlvmIRCallback;
}

function_ref<void(StringRef)> TargetOptions::getISACallback() const {
  return isaCallback;
}

CompilationTarget TargetOptions::getCompilationTarget() const {
  return compilationTarget;
}

CompilationTarget TargetOptions::getDefaultCompilationTarget() {
  return CompilationTarget::Fatbin;
}

std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
TargetOptions::tokenizeCmdOptions(const std::string &cmdOptions) {
  std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
  llvm::StringSaver stringSaver(options.first);
  StringRef opts = cmdOptions;
  // For a correct tokenization of the command line options, `opts` must be
  // unquoted; otherwise the tokenization function returns a single string,
  // the unquoted `cmdOptions`, which is not the desired behavior.
  // Remove any quotes if they are at the beginning and end of the string:
  if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
    opts.consume_front("\""), opts.consume_back("\"");
  if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
    opts.consume_front("'"), opts.consume_back("'");
#ifdef _WIN32
  llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second,
                                       /*MarkEOLs=*/false);
#else
  llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second,
                                   /*MarkEOLs=*/false);
#endif // _WIN32
  return options;
}
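
// For illustration (hypothetical input): tokenizing "\"-O3 --verbose\""
// first strips the surrounding quotes and then yields the two tokens "-O3"
// and "--verbose".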

std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
TargetOptions::tokenizeCmdOptions() const {
  return tokenizeCmdOptions(cmdOptions);
}

std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>>
TargetOptions::tokenizeAndRemoveSuffixCmdOptions(llvm::StringRef startsWith) {
  size_t startPos = cmdOptions.find(startsWith);
  if (startPos == std::string::npos)
    return {llvm::BumpPtrAllocator(), SmallVector<const char *>()};

  auto tokenized =
      tokenizeCmdOptions(cmdOptions.substr(startPos + startsWith.size()));
  cmdOptions.resize(startPos);
  return tokenized;
}
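
// For illustration (hypothetical values): with `cmdOptions` equal to
// "-O3 ptxas -v" and `startsWith` equal to "ptxas", the call returns the
// token "-v" and truncates `cmdOptions` to "-O3 ".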

MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::gpu::TargetOptions)

#include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc"
#include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"

#include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc"