//===- GPUDialect.cpp - MLIR Dialect for GPU Kernels implementation ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the GPU kernel-related dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/GPUDialect.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/StringSaver.h"
#include <cassert>

using namespace mlir;
using namespace mlir::gpu;

#include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// GPU Device Mapping Attributes
//===----------------------------------------------------------------------===//

int64_t GPUBlockMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getBlock());
}

bool GPUBlockMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUBlockMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpgroupMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarpgroup());
}

bool GPUWarpgroupMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpgroupMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUWarpMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getWarp());
}

bool GPUWarpMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUWarpMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUThreadMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getThread());
}

bool GPUThreadMappingAttr::isLinearMapping() const {
  return getMappingId() >= static_cast<int64_t>(MappingId::LinearDim0);
}

int64_t GPUThreadMappingAttr::getRelativeIndex() const {
  return isLinearMapping()
             ? getMappingId() - static_cast<int64_t>(MappingId::LinearDim0)
             : getMappingId();
}

int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
  return static_cast<int64_t>(getAddressSpace());
}

bool GPUMemorySpaceMappingAttr::isLinearMapping() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support linear mapping");
}

int64_t GPUMemorySpaceMappingAttr::getRelativeIndex() const {
  llvm_unreachable("GPUMemorySpaceMappingAttr does not support relative index");
}

//===----------------------------------------------------------------------===//
// MMAMatrixType
//===----------------------------------------------------------------------===//

MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
                                 StringRef operand) {
  return Base::get(elementType.getContext(), shape, elementType, operand);
}

MMAMatrixType
MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
                          ArrayRef<int64_t> shape, Type elementType,
                          StringRef operand) {
  return Base::getChecked(emitError, elementType.getContext(), shape,
                          elementType, operand);
}

unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }

ArrayRef<int64_t> MMAMatrixType::getShape() const {
  return getImpl()->getShape();
}

Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }

StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }

bool MMAMatrixType::isValidElementType(Type elementType) {
  return elementType.isF16() || elementType.isF32() ||
         elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
         elementType.isInteger(32);
}

LogicalResult
MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
                      ArrayRef<int64_t> shape, Type elementType,
                      StringRef operand) {
  if (!operand.equals("AOp") && !operand.equals("BOp") &&
      !operand.equals("COp"))
    return emitError() << "operand expected to be one of AOp, BOp or COp";

  if (shape.size() != 2)
    return emitError() << "MMAMatrixType must have exactly two dimensions";

  if (!MMAMatrixType::isValidElementType(elementType))
    return emitError()
           << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";

  return success();
}
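
// For reference, a well-formed MMA matrix type prints as, e.g.,
// `!gpu.mma_matrix<16x16xf16, "AOp">` (see parseType/printType below).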

//===----------------------------------------------------------------------===//
// GPUDialect
//===----------------------------------------------------------------------===//

bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuAttr.getValue() == getWorkgroupAddressSpace();
  return false;
}

bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
  Attribute memorySpace = type.getMemorySpace();
  return isWorkgroupMemoryAddressSpace(memorySpace);
}

bool GPUDialect::isKernel(Operation *op) {
  UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
  return static_cast<bool>(isKernelAttr);
}

namespace {
/// This class defines the interface for handling inlining with gpu
/// operations.
struct GPUInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All gpu dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace

void GPUDialect::initialize() {
  addTypes<AsyncTokenType>();
  addTypes<MMAMatrixType>();
  addTypes<SparseDnTensorHandleType>();
  addTypes<SparseSpMatHandleType>();
  addTypes<SparseSpGEMMOpHandleType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc"
      >();
  addInterfaces<GPUInlinerInterface>();
  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
                           TerminatorOp>();
}

static std::string getSparseHandleKeyword(SparseHandleKind kind) {
  switch (kind) {
  case SparseHandleKind::DnTensor:
    return "sparse.dntensor_handle";
  case SparseHandleKind::SpMat:
    return "sparse.spmat_handle";
  case SparseHandleKind::SpGEMMOp:
    return "sparse.spgemmop_handle";
  }
  llvm_unreachable("unknown sparse handle kind");
  return "";
}
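
// These keywords surface in the IR prefixed with the dialect namespace,
// e.g. `!gpu.sparse.spmat_handle` (see the type parser/printer below).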

Type GPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'async token' types.
  if (keyword == "async.token")
    return AsyncTokenType::get(context);

  if (keyword == "mma_matrix") {
    SMLoc beginLoc = parser.getNameLoc();

    // Parse '<'.
    if (parser.parseLess())
      return nullptr;

    // Parse the size and elementType.
    SmallVector<int64_t> shape;
    Type elementType;
    if (parser.parseDimensionList(shape, /*allowDynamic=*/false) ||
        parser.parseType(elementType))
      return nullptr;

    // Parse ','.
    if (parser.parseComma())
      return nullptr;

    // Parse operand.
    std::string operand;
    if (failed(parser.parseOptionalString(&operand)))
      return nullptr;

    // Parse '>'.
    if (parser.parseGreater())
      return nullptr;

    return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
                                         parser.getEncodedSourceLoc(beginLoc)),
                                     shape, elementType, operand);
  }

  if (keyword == getSparseHandleKeyword(SparseHandleKind::DnTensor))
    return SparseDnTensorHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpMat))
    return SparseSpMatHandleType::get(context);
  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpGEMMOp))
    return SparseSpGEMMOpHandleType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
  return Type();
}
// TODO: Print a refined type here. Note that it should correspond to the
// parser.
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
      .Case<SparseDnTensorHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::DnTensor);
      })
      .Case<SparseSpMatHandleType>(
          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
      .Case<SparseSpGEMMOpHandleType>([&](Type) {
        os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
      })
      .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
        os << "mma_matrix<";
        auto shape = fragTy.getShape();
        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
          os << *dim << 'x';
        os << shape.back() << 'x' << fragTy.getElementType();
        os << ", \"" << fragTy.getOperand() << "\"" << '>';
      })
      .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
}

LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
                                                   NamedAttribute attr) {
  if (!llvm::isa<UnitAttr>(attr.getValue()) ||
      attr.getName() != getContainerModuleAttrName())
    return success();

  auto module = dyn_cast<ModuleOp>(op);
  if (!module)
    return op->emitError("expected '")
           << getContainerModuleAttrName() << "' attribute to be attached to '"
           << ModuleOp::getOperationName() << '\'';

  auto walkResult = module.walk([&module](LaunchFuncOp launchOp) -> WalkResult {
    // Ignore launches that are nested more or less deeply than functions in
    // the module we are currently checking.
    if (!launchOp->getParentOp() ||
        launchOp->getParentOp()->getParentOp() != module)
      return success();

    // Ignore launch ops with missing attributes here. The errors will be
    // reported by the verifiers of those ops.
    if (!launchOp->getAttrOfType<SymbolRefAttr>(
            LaunchFuncOp::getKernelAttrName(launchOp->getName())))
      return success();

    // Check that `launch_func` refers to a well-formed GPU kernel container.
    StringAttr kernelContainerName = launchOp.getKernelModuleName();
    Operation *kernelContainer = module.lookupSymbol(kernelContainerName);
    if (!kernelContainer)
      return launchOp.emitOpError()
             << "kernel container '" << kernelContainerName.getValue()
             << "' is undefined";

    // If the container is a GPU binary op return success.
    if (isa<BinaryOp>(kernelContainer))
      return success();

    auto kernelModule = dyn_cast<GPUModuleOp>(kernelContainer);
    if (!kernelModule)
      return launchOp.emitOpError()
             << "kernel module '" << kernelContainerName.getValue()
             << "' is undefined";

    // Check that `launch_func` refers to a well-formed kernel function.
    Operation *kernelFunc = module.lookupSymbol(launchOp.getKernelAttr());
    if (!kernelFunc)
      return launchOp.emitOpError("kernel function '")
             << launchOp.getKernel() << "' is undefined";
    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
    if (!kernelConvertedFunction) {
      InFlightDiagnostic diag = launchOp.emitOpError()
                                << "referenced kernel '" << launchOp.getKernel()
                                << "' is not a function";
      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
      return diag;
    }

    if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
            GPUDialect::getKernelFuncAttrName()))
      return launchOp.emitOpError("kernel function is missing the '")
             << GPUDialect::getKernelFuncAttrName() << "' attribute";

    // TODO: If the kernel isn't a GPU function (which happens during separate
    // compilation), do not check type correspondence as it would require the
    // verifier to be aware of the type conversion.
    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
    if (!kernelGPUFunction)
      return success();

    unsigned actualNumArguments = launchOp.getNumKernelOperands();
    unsigned expectedNumArguments = kernelGPUFunction.getNumArguments();
    if (expectedNumArguments != actualNumArguments)
      return launchOp.emitOpError("got ")
             << actualNumArguments << " kernel operands but expected "
             << expectedNumArguments;

    auto functionType = kernelGPUFunction.getFunctionType();
    for (unsigned i = 0; i < expectedNumArguments; ++i) {
      if (launchOp.getKernelOperand(i).getType() != functionType.getInput(i)) {
        return launchOp.emitOpError("type of function argument ")
               << i << " does not match";
      }
    }

    return success();
  });

  return walkResult.wasInterrupted() ? failure() : success();
}

/// Parses an optional list of async operands with an optional leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
///
/// This method is used by the tablegen assembly format for async ops as well.
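/// For illustration, an accepted form looks like (operand names hypothetical):
///   async [%token0, %token1]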
static ParseResult parseAsyncDependencies(
    OpAsmParser &parser, Type &asyncTokenType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
  auto loc = parser.getCurrentLocation();
  if (succeeded(parser.parseOptionalKeyword("async"))) {
    if (parser.getNumResults() == 0)
      return parser.emitError(loc, "needs to be named when marked 'async'");
    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
  }
  return parser.parseOperandList(asyncDependencies,
                                 OpAsmParser::Delimiter::OptionalSquare);
}

/// Prints optional async dependencies with their leading keyword.
/// (`async`)? (`[` ssa-id-list `]`)?
// Used by the tablegen assembly format for several async ops.
static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
                                   Type asyncTokenType,
                                   OperandRange asyncDependencies) {
  if (asyncTokenType)
    printer << "async";
  if (asyncDependencies.empty())
    return;
  if (asyncTokenType)
    printer << ' ';
  printer << '[';
  llvm::interleaveComma(asyncDependencies, printer);
  printer << ']';
}

// Memory attribution functions shared by LaunchOp and GPUFuncOp.
/// Parses a GPU function memory attribution.
///
/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
///                        (`private` `(` ssa-id-and-type-list `)`)?
///
/// Note that this function parses only one of the two similar parts, with the
/// keyword provided as argument.
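/// A hedged sketch of one such part (names and types hypothetical):
///   workgroup(%buf : memref<32xf32, #gpu.address_space<workgroup>>)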
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
                  SmallVectorImpl<OpAsmParser::Argument> &args) {
  // If we could not parse the keyword, just assume empty list and succeed.
  if (failed(parser.parseOptionalKeyword(keyword)))
    return success();

  return parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
                                  /*allowType=*/true);
}

/// Prints a GPU function memory attribution.
static void printAttributions(OpAsmPrinter &p, StringRef keyword,
                              ArrayRef<BlockArgument> values) {
  if (values.empty())
    return;

  p << ' ' << keyword << '(';
  llvm::interleaveComma(
      values, p, [&p](BlockArgument v) { p << v << " : " << v.getType(); });
  p << ')';
}

/// Verifies a GPU function memory attribution.
static LogicalResult verifyAttributions(Operation *op,
                                        ArrayRef<BlockArgument> attributions,
                                        gpu::AddressSpace memorySpace) {
  for (Value v : attributions) {
    auto type = llvm::dyn_cast<MemRefType>(v.getType());
    if (!type)
      return op->emitOpError() << "expected memref type in attribution";

    // We can only verify the address space if it hasn't already been lowered
    // from the AddressSpaceAttr to a target-specific numeric value.
    auto addressSpace =
        llvm::dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
    if (!addressSpace)
      continue;
    if (addressSpace.getValue() != memorySpace)
      return op->emitOpError()
             << "expected memory space " << stringifyAddressSpace(memorySpace)
             << " in attribution";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//

static LogicalResult verifyReduceOpAndType(gpu::AllReduceOperation opName,
                                           Type resType) {
  using Kind = gpu::AllReduceOperation;
  if (llvm::is_contained(
          {Kind::MINNUMF, Kind::MAXNUMF, Kind::MINIMUMF, Kind::MAXIMUMF},
          opName)) {
    if (!isa<FloatType>(resType))
      return failure();
  }

  if (llvm::is_contained({Kind::MINSI, Kind::MINUI, Kind::MAXSI, Kind::MAXUI,
                          Kind::AND, Kind::OR, Kind::XOR},
                         opName)) {
    if (!isa<IntegerType>(resType))
      return failure();
  }

  return success();
}

LogicalResult gpu::AllReduceOp::verifyRegions() {
  if (getBody().empty() != getOp().has_value())
    return emitError("expected either an op attribute or a non-empty body");
  if (!getBody().empty()) {
    if (getBody().getNumArguments() != 2)
      return emitError("expected two region arguments");
    for (auto argument : getBody().getArguments()) {
      if (argument.getType() != getType())
        return emitError("incorrect region argument type");
    }
    unsigned yieldCount = 0;
    for (Block &block : getBody()) {
      if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
        if (yield.getNumOperands() != 1)
          return emitError("expected one gpu.yield operand");
        if (yield.getOperand(0).getType() != getType())
          return emitError("incorrect gpu.yield type");
        ++yieldCount;
      }
    }
    if (yieldCount == 0)
      return emitError("expected gpu.yield op in region");
  } else {
    gpu::AllReduceOperation opName = *getOp();
    if (failed(verifyReduceOpAndType(opName, getType()))) {
      return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                         << "` reduction operation is not compatible with type "
                         << getType();
    }
  }

  return success();
}

static bool canMakeGroupOpUniform(Operation *op) {
  auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp());
  if (!launchOp)
    return false;

  Region &body = launchOp.getBody();
  assert(!body.empty() && "Invalid region");

  // Only convert ops in the gpu::launch entry block for now.
  return op->getBlock() == &body.front();
}

OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

// TODO: Support optional custom attributes (without dialect prefix).
static ParseResult parseAllReduceOperation(AsmParser &parser,
                                           AllReduceOperationAttr &attr) {
  StringRef enumStr;
  if (!parser.parseOptionalKeyword(&enumStr)) {
    std::optional<AllReduceOperation> op =
        gpu::symbolizeAllReduceOperation(enumStr);
    if (!op)
      return parser.emitError(parser.getCurrentLocation(), "invalid op kind");
    attr = AllReduceOperationAttr::get(parser.getContext(), *op);
  }
  return success();
}

static void printAllReduceOperation(AsmPrinter &printer, Operation *op,
                                    AllReduceOperationAttr attr) {
  if (attr)
    attr.print(printer);
}

//===----------------------------------------------------------------------===//
// SubgroupReduceOp
//===----------------------------------------------------------------------===//

LogicalResult gpu::SubgroupReduceOp::verify() {
  Type elemType = getType();
  if (auto vecTy = dyn_cast<VectorType>(elemType)) {
    if (vecTy.isScalable())
      return emitOpError() << "is not compatible with scalable vector types";

    elemType = vecTy.getElementType();
  }

  gpu::AllReduceOperation opName = getOp();
  if (failed(verifyReduceOpAndType(opName, elemType))) {
    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
                       << "` reduction operation is not compatible with type "
                       << getType();
  }
  return success();
}

OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
  if (!getUniform() && canMakeGroupOpUniform(*this)) {
    setUniform(true);
    return getResult();
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AsyncOpInterface
//===----------------------------------------------------------------------===//

void gpu::addAsyncDependency(Operation *op, Value token) {
  op->insertOperands(0, {token});
  if (!op->template hasTrait<OpTrait::AttrSizedOperandSegments>())
    return;
  auto attrName =
      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr();
  auto sizeAttr = op->template getAttrOfType<DenseI32ArrayAttr>(attrName);

  // Async dependencies is the only variadic operand.
  if (!sizeAttr)
    return;

  SmallVector<int32_t, 8> sizes(sizeAttr.asArrayRef());
  ++sizes.front();
  op->setAttr(attrName, Builder(op->getContext()).getDenseI32ArrayAttr(sizes));
}

//===----------------------------------------------------------------------===//
// LaunchOp
//===----------------------------------------------------------------------===//

void LaunchOp::build(OpBuilder &builder, OperationState &result,
                     Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                     Value getBlockSizeX, Value getBlockSizeY,
                     Value getBlockSizeZ, Value dynamicSharedMemorySize,
                     Type asyncTokenType, ValueRange asyncDependencies,
                     TypeRange workgroupAttributions,
                     TypeRange privateAttributions, Value clusterSizeX,
                     Value clusterSizeY, Value clusterSizeZ) {
  OpBuilder::InsertionGuard g(builder);

  // Add a WorkGroup attribution attribute. This attribute is required to
  // identify private attributions in the list of block arguments.
  result.addAttribute(getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(workgroupAttributions.size()));

  // Add Op operands.
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSizeX, gridSizeY, gridSizeZ, getBlockSizeX,
                      getBlockSizeY, getBlockSizeZ});
  if (clusterSizeX)
    result.addOperands(clusterSizeX);
  if (clusterSizeY)
    result.addOperands(clusterSizeY);
  if (clusterSizeZ)
    result.addOperands(clusterSizeZ);
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);

  // Create a kernel body region with kNumConfigRegionAttributes + N memory
  // attributions, where the first kNumConfigRegionAttributes arguments have
  // `index` type and the rest have the same types as the data operands.
  Region *kernelRegion = result.addRegion();
  Block *body = builder.createBlock(kernelRegion);
  // TODO: Allow passing in proper locations here.
  for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
    body->addArgument(builder.getIndexType(), result.location);
  // Add WorkGroup & Private attributions to the region arguments.
  for (Type argTy : workgroupAttributions)
    body->addArgument(argTy, result.location);
  for (Type argTy : privateAttributions)
    body->addArgument(argTy, result.location);
  // Fill the OperandSegmentSize attribute.
  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();
  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
  segmentSizes[7] = clusterSizeX ? 1 : 0;
  segmentSizes[8] = clusterSizeY ? 1 : 0;
  segmentSizes[9] = clusterSizeZ ? 1 : 0;
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));
}

KernelDim3 LaunchOp::getBlockIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[0], args[1], args[2]};
}

KernelDim3 LaunchOp::getThreadIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[3], args[4], args[5]};
}

KernelDim3 LaunchOp::getGridSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[6], args[7], args[8]};
}

KernelDim3 LaunchOp::getBlockSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  auto args = getBody().getArguments();
  return KernelDim3{args[9], args[10], args[11]};
}

std::optional<KernelDim3> LaunchOp::getClusterIds() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[12], args[13], args[14]};
}

std::optional<KernelDim3> LaunchOp::getClusterSize() {
  assert(!getBody().empty() && "LaunchOp body must not be empty.");
  if (!hasClusterSize())
    return std::nullopt;
  auto args = getBody().getArguments();
  return KernelDim3{args[15], args[16], args[17]};
}

KernelDim3 LaunchOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

std::optional<KernelDim3> LaunchOp::getClusterSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  if (!hasClusterSize())
    return std::nullopt;
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchOp::verify() {
  if (!(hasClusterSize()) &&
      (getClusterSizeX() || getClusterSizeY() || getClusterSizeZ()))
    return emitOpError() << "cluster size must be all present";
  return success();
}

LogicalResult LaunchOp::verifyRegions() {
  // Kernel launch takes kNumConfigOperands leading operands for grid/block
  // sizes and transforms them into kNumConfigRegionAttributes region arguments
  // for block/thread identifiers and grid/block sizes.
  if (!getBody().empty()) {
    if (getBody().getNumArguments() <
        kNumConfigRegionAttributes + getNumWorkgroupAttributions())
      return emitOpError("unexpected number of region arguments");
  }

  // Verify Attributions Address Spaces.
  if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(),
                                GPUDialect::getWorkgroupAddressSpace())) ||
      failed(verifyAttributions(getOperation(), getPrivateAttributions(),
                                GPUDialect::getPrivateAddressSpace())))
    return failure();

  // Block terminators without successors are expected to exit the kernel
  // region and must be `gpu.terminator`.
  for (Block &block : getBody()) {
    if (block.empty())
      continue;
    if (block.back().getNumSuccessors() != 0)
      continue;
    if (!isa<gpu::TerminatorOp>(&block.back())) {
      return block.back()
          .emitError()
          .append("expected '", gpu::TerminatorOp::getOperationName(),
                  "' or a terminator with successors")
          .attachNote(getLoc())
          .append("in '", LaunchOp::getOperationName(), "' body region");
    }
  }

  if (getNumResults() == 0 && getAsyncToken())
    return emitOpError("needs to be named when async keyword is specified");

  return success();
}

// Pretty-print the kernel grid/block size assignment as
//   (%iter-x, %iter-y, %iter-z) in
//   (%size-x = %ssa-use, %size-y = %ssa-use, %size-z = %ssa-use)
// where %size-* and %iter-* will correspond to the body region arguments.
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
                                KernelDim3 operands, KernelDim3 ids) {
  p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
  p << size.x << " = " << operands.x << ", ";
  p << size.y << " = " << operands.y << ", ";
  p << size.z << " = " << operands.z << ')';
}

void LaunchOp::print(OpAsmPrinter &p) {
  if (getAsyncToken()) {
    p << " async";
    if (!getAsyncDependencies().empty())
      p << " [" << getAsyncDependencies() << ']';
  }
  // Print the launch configuration.
  if (hasClusterSize()) {
    p << ' ' << getClustersKeyword();
    printSizeAssignment(p, getClusterSize().value(),
                        getClusterSizeOperandValues().value(),
                        getClusterIds().value());
  }
  p << ' ' << getBlocksKeyword();
  printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
                      getBlockIds());
  p << ' ' << getThreadsKeyword();
  printSizeAssignment(p, getBlockSize(), getBlockSizeOperandValues(),
                      getThreadIds());
  if (getDynamicSharedMemorySize())
    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
      << getDynamicSharedMemorySize();

  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());

  p << ' ';

  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                              LaunchOp::getOperandSegmentSizeAttr(),
                              getNumWorkgroupAttributionsAttrName()});
}

// Parse the size assignment blocks for blocks and threads. These have the form
//   (%region_arg, %region_arg, %region_arg) in
//   (%region_arg = %operand, %region_arg = %operand, %region_arg = %operand)
// where %region_arg are percent-identifiers for the region arguments to be
// introduced further (SSA defs), and %operand are percent-identifiers for the
// SSA value uses.
static ParseResult
parseSizeAssignment(OpAsmParser &parser,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> sizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> regionSizes,
                    MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
  assert(indices.size() == 3 && "space for three indices expected");
  SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
  if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
                              /*allowResultNumber=*/false) ||
      parser.parseKeyword("in") || parser.parseLParen())
    return failure();
  std::move(args.begin(), args.end(), indices.begin());

  for (int i = 0; i < 3; ++i) {
    if (i != 0 && parser.parseComma())
      return failure();
    if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
        parser.parseEqual() || parser.parseOperand(sizes[i]))
      return failure();
  }

  return parser.parseRParen();
}

/// Parses a Launch operation.
/// operation ::= `gpu.launch` (`async` `[` ssa-id-list `]`)?
///               (`clusters` `(` ssa-id-list `)` `in` ssa-reassignment)?
///               `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
///               `threads` `(` ssa-id-list `)` `in` ssa-reassignment
///               memory-attribution
///               region attr-dict?
/// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
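/// A hedged sketch of the surface syntax (all SSA names hypothetical):
///   %t = gpu.launch async blocks(%bx, %by, %bz)
///            in (%gx = %0, %gy = %1, %gz = %2)
///            threads(%tx, %ty, %tz) in (%sx = %3, %sy = %4, %sz = %5) {
///     gpu.terminator
///   }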
ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Sizes of the grid and block.
  SmallVector<OpAsmParser::UnresolvedOperand, LaunchOp::kNumConfigOperands>
      sizes(LaunchOp::kNumConfigOperands);

  // Actual (data) operands passed to the kernel.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> dataOperands;

  // Region arguments to be created.
  SmallVector<OpAsmParser::UnresolvedOperand, 16> regionArgs(
      LaunchOp::kNumConfigRegionAttributes);

  // Parse optional async dependencies.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
  Type asyncTokenType;
  if (failed(
          parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
      parser.resolveOperands(asyncDependencies, asyncTokenType,
                             result.operands))
    return failure();
  if (parser.getNumResults() > 0)
    result.types.push_back(asyncTokenType);

  bool hasCluster = false;
  if (succeeded(
          parser.parseOptionalKeyword(LaunchOp::getClustersKeyword().data()))) {
    hasCluster = true;
    sizes.resize(9);
    regionArgs.resize(18);
  }
  MutableArrayRef<OpAsmParser::UnresolvedOperand> sizesRef(sizes);
  MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);

  // The last three size segments assign the cluster size; in the region
  // argument list, these correspond to the last six arguments.
  if (hasCluster) {
    if (parseSizeAssignment(parser, sizesRef.drop_front(6),
                            regionArgsRef.slice(15, 3),
                            regionArgsRef.slice(12, 3)))
      return failure();
  }
  // Parse the size assignment segments: the first segment assigns grid sizes
  // and defines values for block identifiers; the second segment assigns block
  // sizes and defines values for thread identifiers. In the region argument
  // list, identifiers precede sizes, and block-related values precede
  // thread-related values.
  if (parser.parseKeyword(LaunchOp::getBlocksKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.take_front(3),
                          regionArgsRef.slice(6, 3),
                          regionArgsRef.slice(0, 3)) ||
      parser.parseKeyword(LaunchOp::getThreadsKeyword().data()) ||
      parseSizeAssignment(parser, sizesRef.drop_front(3),
                          regionArgsRef.slice(9, 3),
                          regionArgsRef.slice(3, 3)) ||
      parser.resolveOperands(sizes, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
  bool hasDynamicSharedMemorySize = false;
  if (!parser.parseOptionalKeyword(
          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
    hasDynamicSharedMemorySize = true;
    if (parser.parseOperand(dynamicSharedMemorySize) ||
        parser.resolveOperand(dynamicSharedMemorySize,
                              parser.getBuilder().getI32Type(),
                              result.operands))
      return failure();
  }

  // Create the region arguments. The region has kNumConfigRegionAttributes
  // arguments that correspond to block/thread identifiers and grid/block
  // sizes, all having `index` type, a variadic number of WorkGroup
  // attributions and a variadic number of Private attributions. The number of
  // WorkGroup attributions is stored in the attr with name:
  // LaunchOp::getNumWorkgroupAttributionsAttrName().
  Type index = parser.getBuilder().getIndexType();
  SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
      LaunchOp::kNumConfigRegionAttributes + 6, index);

  SmallVector<OpAsmParser::Argument> regionArguments;
  for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
    OpAsmParser::Argument arg;
    arg.ssaName = std::get<0>(ssaValueAndType);
    arg.type = std::get<1>(ssaValueAndType);
    regionArguments.push_back(arg);
  }

  Builder &builder = parser.getBuilder();
  // Parse workgroup memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getWorkgroupKeyword(),
                               regionArguments)))
    return failure();

  // Store the number of operands we just parsed as the number of workgroup
  // memory attributions.
  unsigned numWorkgroupAttrs = regionArguments.size() -
                               LaunchOp::kNumConfigRegionAttributes -
                               (hasCluster ? 6 : 0);
  result.addAttribute(LaunchOp::getNumWorkgroupAttributionsAttrName(),
                      builder.getI64IntegerAttr(numWorkgroupAttrs));

  // Parse private memory attributions.
  if (failed(parseAttributions(parser, LaunchOp::getPrivateKeyword(),
                               regionArguments)))
    return failure();

  // Introduce the body region and parse it. The region has
  // kNumConfigRegionAttributes arguments that correspond to
  // block/thread identifiers and grid/block sizes, all having `index` type.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArguments) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  SmallVector<int32_t, 11> segmentSizes(11, 1);
  segmentSizes.front() = asyncDependencies.size();

  if (!hasCluster) {
    segmentSizes[7] = 0;
    segmentSizes[8] = 0;
    segmentSizes[9] = 0;
  }
  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
  return success();
}

/// Simplify the gpu.launch when the range of a thread or block ID is
/// trivially known to be one.
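/// For example (a hedged sketch, SSA names hypothetical): given
///   %one = arith.constant 1 : index
/// used as the x block size in `threads(%tx, ...) in (%sx = %one, ...)`,
/// every use of %tx inside the body can be replaced by a constant 0.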
struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
  using OpRewritePattern<LaunchOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(LaunchOp op,
                                PatternRewriter &rewriter) const override {
    // If the range implies a single value for `id`, replace `id`'s uses by
    // zero.
    Value zero;
    bool simplified = false;
    auto constPropIdUses = [&](Value id, Value size) {
      // Check if size is trivially one.
      if (!matchPattern(size, m_One()))
        return;
      if (id.getUses().empty())
        return;
      if (!simplified) {
        // Create a zero value the first time.
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(&op.getBody().front());
        zero =
            rewriter.create<arith::ConstantIndexOp>(op.getLoc(), /*value=*/0);
      }
      rewriter.replaceAllUsesWith(id, zero);
      simplified = true;
    };
    constPropIdUses(op.getBlockIds().x, op.getGridSizeX());
    constPropIdUses(op.getBlockIds().y, op.getGridSizeY());
    constPropIdUses(op.getBlockIds().z, op.getGridSizeZ());
    constPropIdUses(op.getThreadIds().x, op.getBlockSizeX());
    constPropIdUses(op.getThreadIds().y, op.getBlockSizeY());
    constPropIdUses(op.getThreadIds().z, op.getBlockSizeZ());

    return success(simplified);
  }
};

void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
                                           MLIRContext *context) {
  rewrites.add<FoldLaunchArguments>(context);
}

/// Adds a new block argument that corresponds to buffers located in
/// workgroup memory.
BlockArgument LaunchOp::addWorkgroupAttribution(Type type, Location loc) {
  auto attrName = getNumWorkgroupAttributionsAttrName();
  auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName);
  (*this)->setAttr(attrName,
                   IntegerAttr::get(attr.getType(), attr.getValue() + 1));
  return getBody().insertArgument(
      LaunchOp::getNumConfigRegionAttributes() + attr.getInt(), type, loc);
}

/// Adds a new block argument that corresponds to buffers located in
/// private memory.
BlockArgument LaunchOp::addPrivateAttribution(Type type, Location loc) {
  // Buffers on the private memory always come after buffers on the workgroup
  // memory.
  return getBody().addArgument(type, loc);
}

//===----------------------------------------------------------------------===//
// LaunchFuncOp
//===----------------------------------------------------------------------===//

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         GPUFuncOp kernelFunc, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Type asyncTokenType,
                         ValueRange asyncDependencies,
                         std::optional<KernelDim3> clusterSize) {
  result.addOperands(asyncDependencies);
  if (asyncTokenType)
    result.types.push_back(builder.getType<AsyncTokenType>());

  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
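  // The kernel symbol is a nested reference rooted at the surrounding GPU
  // module, which prints as, e.g., @kernel_module::@kernel_func.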
  auto kernelSymbol =
      SymbolRefAttr::get(kernelModule.getNameAttr(),
                         {SymbolRefAttr::get(kernelFunc.getNameAttr())});

  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernelSymbol;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  for (auto &sz : prop.operandSegmentSizes)
    sz = 1;
  prop.operandSegmentSizes[0] = asyncDependencies.size();
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = 0;
}

void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                         SymbolRefAttr kernel, KernelDim3 gridSize,
                         KernelDim3 getBlockSize, Value dynamicSharedMemorySize,
                         ValueRange kernelOperands, Value asyncObject,
                         std::optional<KernelDim3> clusterSize) {
  // Add grid and block sizes as op operands, followed by the data operands.
  result.addOperands({gridSize.x, gridSize.y, gridSize.z, getBlockSize.x,
                      getBlockSize.y, getBlockSize.z});
  if (clusterSize.has_value())
    result.addOperands({clusterSize->x, clusterSize->y, clusterSize->z});
  if (dynamicSharedMemorySize)
    result.addOperands(dynamicSharedMemorySize);
  result.addOperands(kernelOperands);
  if (asyncObject)
    result.addOperands(asyncObject);
  Properties &prop = result.getOrAddProperties<Properties>();
  prop.kernel = kernel;
  size_t segmentSizesLen = std::size(prop.operandSegmentSizes);
  // Initialize the segment sizes to 1.
  for (auto &sz : prop.operandSegmentSizes)
    sz = 1;
  prop.operandSegmentSizes[0] = 0;
  if (!clusterSize.has_value()) {
    prop.operandSegmentSizes[segmentSizesLen - 4] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 5] = 0;
    prop.operandSegmentSizes[segmentSizesLen - 6] = 0;
  }
  prop.operandSegmentSizes[segmentSizesLen - 3] =
      dynamicSharedMemorySize ? 1 : 0;
  prop.operandSegmentSizes[segmentSizesLen - 2] =
      static_cast<int32_t>(kernelOperands.size());
  prop.operandSegmentSizes[segmentSizesLen - 1] = asyncObject ? 1 : 0;
}

StringAttr LaunchFuncOp::getKernelModuleName() {
  return getKernel().getRootReference();
}

StringAttr LaunchFuncOp::getKernelName() {
  return getKernel().getLeafReference();
}

unsigned LaunchFuncOp::getNumKernelOperands() {
  return getKernelOperands().size();
}

Value LaunchFuncOp::getKernelOperand(unsigned i) {
  return getKernelOperands()[i];
}

KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[0], operands[1], operands[2]};
}

KernelDim3 LaunchFuncOp::getBlockSizeOperandValues() {
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[3], operands[4], operands[5]};
}

KernelDim3 LaunchFuncOp::getClusterSizeOperandValues() {
  assert(hasClusterSize() &&
         "cluster size is not set, check hasClusterSize() first");
  auto operands = getOperands().drop_front(getAsyncDependencies().size());
  return KernelDim3{operands[6], operands[7], operands[8]};
}

LogicalResult LaunchFuncOp::verify() {
  auto module = (*this)->getParentOfType<ModuleOp>();
  if (!module)
    return emitOpError("expected to belong to a module");

  if (!module->getAttrOfType<UnitAttr>(
          GPUDialect::getContainerModuleAttrName()))
    return emitOpError("expected the closest surrounding module to have the '" +
                       GPUDialect::getContainerModuleAttrName() +
                       "' attribute");

  if (hasClusterSize()) {
    if (getClusterSizeY().getType() != getClusterSizeX().getType() ||
        getClusterSizeZ().getType() != getClusterSizeX().getType())
      return emitOpError()
             << "expects types of the cluster dimensions must be the same";
  }

  return success();
}

static ParseResult
parseLaunchDimType(OpAsmParser &parser, Type &dimTy,
                   std::optional<OpAsmParser::UnresolvedOperand> clusterValue,
                   Type &clusterXTy, Type &clusterYTy, Type &clusterZTy) {
  if (succeeded(parser.parseOptionalColon())) {
    if (parser.parseType(dimTy))
      return failure();
  } else {
    dimTy = IndexType::get(parser.getContext());
  }
  if (clusterValue.has_value()) {
    clusterXTy = clusterYTy = clusterZTy = dimTy;
  }
  return success();
}

static void printLaunchDimType(OpAsmPrinter &printer, Operation *op, Type dimTy,
                               Value clusterValue, Type clusterXTy,
                               Type clusterYTy, Type clusterZTy) {
  if (!dimTy.isIndex())
    printer << ": " << dimTy;
}

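/// Parses the optional `args` clause of gpu.launch_func. A hedged sketch of
/// the accepted form (operand names and types hypothetical):
///   args(%arg0 : f32, %arg1 : memref<?xf32>)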
static ParseResult parseLaunchFuncOperands(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
    SmallVectorImpl<Type> &argTypes) {
  if (parser.parseOptionalKeyword("args"))
    return success();

  auto parseElement = [&]() -> ParseResult {
    return failure(parser.parseOperand(argNames.emplace_back()) ||
                   parser.parseColonType(argTypes.emplace_back()));
  };

  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                        parseElement, " in argument list");
}

static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
                                    OperandRange operands, TypeRange types) {
  if (operands.empty())
    return;
  printer << "args(";
  llvm::interleaveComma(llvm::zip(operands, types), printer,
                        [&](const auto &pair) {
                          printer.printOperand(std::get<0>(pair));
                          printer << " : ";
                          printer.printType(std::get<1>(pair));
                        });
  printer << ")";
}

//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//

void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
                      int32_t offset, int32_t width, ShuffleMode mode) {
  build(builder, result, value,
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(offset)),
        builder.create<arith::ConstantOp>(result.location,
                                          builder.getI32IntegerAttr(width)),
        mode);
}

//===----------------------------------------------------------------------===//
// BarrierOp
//===----------------------------------------------------------------------===//

namespace {

/// Erase a gpu.barrier that is immediately followed by another gpu.barrier;
/// the remaining barrier already synchronizes the threads.
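/// For example, in
///   gpu.barrier
///   gpu.barrier
/// the first barrier is redundant and can be removed.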
1288 | LogicalResult eraseRedundantGpuBarrierOps(BarrierOp op, |
1289 | PatternRewriter &rewriter) { |
1290 | if (isa_and_nonnull<BarrierOp>(op->getNextNode())) { |
1291 | rewriter.eraseOp(op: op); |
1292 | return success(); |
1293 | } |
1294 | return failure(); |
1295 | } |
1296 | |
1297 | } // end anonymous namespace |
1298 | |
1299 | void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results, |
1300 | MLIRContext *context) { |
1301 | results.add(eraseRedundantGpuBarrierOps); |
1302 | } |
1303 | |
1304 | //===----------------------------------------------------------------------===// |
1305 | // GPUFuncOp |
1306 | //===----------------------------------------------------------------------===// |
1307 | |
1308 | /// Adds a new block argument that corresponds to buffers located in |
1309 | /// workgroup memory. |
1310 | BlockArgument GPUFuncOp::addWorkgroupAttribution(Type type, Location loc) { |
1311 | auto attrName = getNumWorkgroupAttributionsAttrName(); |
1312 | auto attr = (*this)->getAttrOfType<IntegerAttr>(attrName); |
1313 | (*this)->setAttr(attrName, |
1314 | IntegerAttr::get(attr.getType(), attr.getValue() + 1)); |
1315 | return getBody().insertArgument( |
1316 | getFunctionType().getNumInputs() + attr.getInt(), type, loc); |
1317 | } |
1318 | |
1319 | /// Adds a new block argument that corresponds to buffers located in |
1320 | /// private memory. |
1321 | BlockArgument GPUFuncOp::addPrivateAttribution(Type type, Location loc) { |
1322 | // Buffers on the private memory always come after buffers on the workgroup |
1323 | // memory. |
1324 | return getBody().addArgument(type, loc); |
1325 | } |
1326 | |
1327 | void GPUFuncOp::build(OpBuilder &builder, OperationState &result, |
1328 | StringRef name, FunctionType type, |
1329 | TypeRange workgroupAttributions, |
1330 | TypeRange privateAttributions, |
1331 | ArrayRef<NamedAttribute> attrs) { |
1332 | OpBuilder::InsertionGuard g(builder); |
1333 | |
1334 | result.addAttribute(SymbolTable::getSymbolAttrName(), |
1335 | builder.getStringAttr(name)); |
1336 | result.addAttribute(getFunctionTypeAttrName(result.name), |
1337 | TypeAttr::get(type)); |
1338 | result.addAttribute(getNumWorkgroupAttributionsAttrName(), |
1339 | builder.getI64IntegerAttr(workgroupAttributions.size())); |
1340 | result.addAttributes(attrs); |
1341 | Region *body = result.addRegion(); |
1342 | Block *entryBlock = builder.createBlock(body); |
1343 | |
1344 | // TODO: Allow passing in proper locations here. |
1345 | for (Type argTy : type.getInputs()) |
1346 | entryBlock->addArgument(argTy, result.location); |
1347 | for (Type argTy : workgroupAttributions) |
1348 | entryBlock->addArgument(argTy, result.location); |
1349 | for (Type argTy : privateAttributions) |
1350 | entryBlock->addArgument(argTy, result.location); |
1351 | } |
1352 | |
1353 | /// Parses a GPU function memory attribution. |
1354 | /// |
1355 | /// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)? |
1356 | /// (`private` `(` ssa-id-and-type-list `)`)? |
1357 | /// |
1358 | /// Note that this function parses only one of the two similar parts, with the |
1359 | /// keyword provided as argument. |
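 | /// As a sketch, the workgroup part may look like (illustrative types): |
 | /// |
 | ///   workgroup(%buf : memref<32xf32, #gpu.address_space<workgroup>>) |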
1360 | static ParseResult |
1361 | parseAttributions(OpAsmParser &parser, StringRef keyword, |
1362 | SmallVectorImpl<OpAsmParser::Argument> &args, |
1363 | Attribute &attributionAttrs) { |
1364 | // If we could not parse the keyword, just assume empty list and succeed. |
1365 | if (failed(parser.parseOptionalKeyword(keyword))) |
1366 | return success(); |
1367 | |
1368 | size_t existingArgs = args.size(); |
1369 | ParseResult result = |
1370 | parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren, |
1371 | /*allowType=*/true, /*allowAttrs=*/true); |
1372 | if (failed(result)) |
1373 | return result; |
1374 | |
1375 | bool hadAttrs = llvm::any_of(ArrayRef(args).drop_front(existingArgs), |
1376 | [](const OpAsmParser::Argument &arg) -> bool { |
1377 | return arg.attrs && !arg.attrs.empty(); |
1378 | }); |
1379 | if (!hadAttrs) { |
1380 | attributionAttrs = nullptr; |
1381 | return result; |
1382 | } |
1383 | |
1384 | Builder &builder = parser.getBuilder(); |
1385 | SmallVector<Attribute> attributionAttrsVec; |
1386 | for (const auto &argument : ArrayRef(args).drop_front(existingArgs)) { |
1387 | if (!argument.attrs) |
1388 | attributionAttrsVec.push_back(builder.getDictionaryAttr({})); |
1389 | else |
1390 | attributionAttrsVec.push_back(argument.attrs); |
1391 | } |
1392 | attributionAttrs = builder.getArrayAttr(attributionAttrsVec); |
1393 | return result; |
1394 | } |
1395 | |
1396 | /// Parses a GPU function. |
1397 | /// |
1398 | /// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)` |
1399 | /// (`->` function-result-list)? memory-attribution `kernel`? |
1400 | /// function-attributes? region |
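 | /// As a sketch (illustrative names and types): |
 | /// |
 | ///   gpu.func @kernel(%arg0 : f32) |
 | ///       workgroup(%buf : memref<32xf32, #gpu.address_space<workgroup>>) |
 | ///       kernel { |
 | ///     gpu.return |
 | ///   } |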
1401 | ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) { |
1402 | SmallVector<OpAsmParser::Argument> entryArgs; |
1403 | SmallVector<DictionaryAttr> resultAttrs; |
1404 | SmallVector<Type> resultTypes; |
1405 | bool isVariadic; |
1406 | |
1407 | // Parse the function name. |
1408 | StringAttr nameAttr; |
1409 | if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), |
1410 | result.attributes)) |
1411 | return failure(); |
1412 | |
1413 | auto signatureLocation = parser.getCurrentLocation(); |
1414 | if (failed(function_interface_impl::parseFunctionSignature( |
1415 | parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes, |
1416 | resultAttrs))) |
1417 | return failure(); |
1418 | |
1419 | if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty()) |
1420 | return parser.emitError(signatureLocation) |
1421 | << "gpu.func requires named arguments"; |
1422 | |
1423 | // Construct the function type. More types will be added to the region, but |
1424 | // not to the function type. |
1425 | Builder &builder = parser.getBuilder(); |
1426 | |
1427 | SmallVector<Type> argTypes; |
1428 | for (auto &arg : entryArgs) |
1429 | argTypes.push_back(arg.type); |
1430 | auto type = builder.getFunctionType(argTypes, resultTypes); |
1431 | result.addAttribute(getFunctionTypeAttrName(result.name), |
1432 | TypeAttr::get(type)); |
1433 | |
1434 | function_interface_impl::addArgAndResultAttrs( |
1435 | builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name), |
1436 | getResAttrsAttrName(result.name)); |
1437 | |
1438 | Attribute workgroupAttributionAttrs; |
1439 | // Parse workgroup memory attributions. |
1440 | if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(), |
1441 | entryArgs, workgroupAttributionAttrs))) |
1442 | return failure(); |
1443 | |
1444 | // Store the number of operands we just parsed as the number of workgroup |
1445 | // memory attributions. |
1446 | unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs(); |
1447 | result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(), |
1448 | builder.getI64IntegerAttr(numWorkgroupAttrs)); |
1449 | if (workgroupAttributionAttrs) |
1450 | result.addAttribute(GPUFuncOp::getWorkgroupAttribAttrsAttrName(result.name), |
1451 | workgroupAttributionAttrs); |
1452 | |
1453 | Attribute privateAttributionAttrs; |
1454 | // Parse private memory attributions. |
1455 | if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(), |
1456 | entryArgs, privateAttributionAttrs))) |
1457 | return failure(); |
1458 | if (privateAttributionAttrs) |
1459 | result.addAttribute(GPUFuncOp::getPrivateAttribAttrsAttrName(result.name), |
1460 | privateAttributionAttrs); |
1461 | |
1462 | // Parse the kernel attribute if present. |
1463 | if (succeeded(parser.parseOptionalKeyword(GPUFuncOp::getKernelKeyword()))) |
1464 | result.addAttribute(GPUDialect::getKernelFuncAttrName(), |
1465 | builder.getUnitAttr()); |
1466 | |
1467 | // Parse attributes. |
1468 | if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) |
1469 | return failure(); |
1470 | |
1471 | // Parse the region. If no argument names were provided, take all names |
1472 | // (including those of attributions) from the entry block. |
1473 | auto *body = result.addRegion(); |
1474 | return parser.parseRegion(*body, entryArgs); |
1475 | } |
1476 | |
1477 | static void printAttributions(OpAsmPrinter &p, StringRef keyword, |
1478 | ArrayRef<BlockArgument> values, |
1479 | ArrayAttr attributes) { |
1480 | if (values.empty()) |
1481 | return; |
1482 | |
1483 | p << ' ' << keyword << '('; |
1484 | llvm::interleaveComma( |
1485 | llvm::enumerate(values), p, [&p, attributes](auto pair) { |
1486 | BlockArgument v = pair.value(); |
1487 | p << v << " : " << v.getType(); |
1488 | |
1489 | size_t attributionIndex = pair.index(); |
1490 | DictionaryAttr attrs; |
1491 | if (attributes && attributionIndex < attributes.size()) |
1492 | attrs = llvm::cast<DictionaryAttr>(attributes[attributionIndex]); |
1493 | if (attrs) |
1494 | p.printOptionalAttrDict(attrs.getValue()); |
1495 | }); |
1496 | p << ')'; |
1497 | } |
1498 | |
1499 | void GPUFuncOp::print(OpAsmPrinter &p) { |
1500 | p << ' '; |
1501 | p.printSymbolName(getName()); |
1502 | |
1503 | FunctionType type = getFunctionType(); |
1504 | function_interface_impl::printFunctionSignature(p, *this, type.getInputs(), |
1505 | /*isVariadic=*/false, |
1506 | type.getResults()); |
1507 | |
1508 | printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions(), |
1509 | getWorkgroupAttribAttrs().value_or(nullptr)); |
1510 | printAttributions(p, getPrivateKeyword(), getPrivateAttributions(), |
1511 | getPrivateAttribAttrs().value_or(nullptr)); |
1512 | if (isKernel()) |
1513 | p << ' ' << getKernelKeyword(); |
1514 | |
1515 | function_interface_impl::printFunctionAttributes( |
1516 | p, *this, |
1517 | {getNumWorkgroupAttributionsAttrName(), |
1518 | GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(), |
1519 | getArgAttrsAttrName(), getResAttrsAttrName(), |
1520 | getWorkgroupAttribAttrsAttrName(), getPrivateAttribAttrsAttrName()}); |
1521 | p << ' '; |
1522 | p.printRegion(getBody(), /*printEntryBlockArgs=*/false); |
1523 | } |
1524 | |
1525 | static DictionaryAttr getAttributionAttrs(GPUFuncOp op, unsigned index, |
1526 | StringAttr attrName) { |
1527 | auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName)); |
1528 | if (!allAttrs || index >= allAttrs.size()) |
1529 | return DictionaryAttr(); |
1530 | return llvm::cast<DictionaryAttr>(allAttrs[index]); |
1531 | } |
1532 | |
1533 | DictionaryAttr GPUFuncOp::getworkgroupAttributionAttrs(unsigned index) { |
1534 | return getAttributionAttrs(*this, index, getWorkgroupAttribAttrsAttrName()); |
1535 | } |
1536 | |
1537 | DictionaryAttr GPUFuncOp::getPrivateAttributionAttrs(unsigned index) { |
1538 | return getAttributionAttrs(*this, index, getPrivateAttribAttrsAttrName()); |
1539 | } |
1540 | |
1541 | static void setAttributionAttrs(GPUFuncOp op, unsigned index, |
1542 | DictionaryAttr value, StringAttr attrName) { |
1543 | MLIRContext *ctx = op.getContext(); |
1544 | auto allAttrs = llvm::dyn_cast_or_null<ArrayAttr>(op->getAttr(attrName)); |
1545 | SmallVector<Attribute> elements; |
1546 | if (allAttrs) |
1547 | elements.append(allAttrs.begin(), allAttrs.end()); |
1548 | while (elements.size() <= index) |
1549 | elements.push_back(DictionaryAttr::get(ctx)); |
1550 | if (!value) |
1551 | elements[index] = DictionaryAttr::get(ctx); |
1552 | else |
1553 | elements[index] = value; |
1554 | ArrayAttr newValue = ArrayAttr::get(ctx, elements); |
1555 | op->setAttr(attrName, newValue); |
1556 | } |
1557 | |
1558 | void GPUFuncOp::setworkgroupAttributionAttrs(unsigned index, |
1559 | DictionaryAttr value) { |
1560 | setAttributionAttrs(*this, index, value, getWorkgroupAttribAttrsAttrName()); |
1561 | } |
1562 | |
1563 | void GPUFuncOp::setPrivateAttributionAttrs(unsigned int index, |
1564 | DictionaryAttr value) { |
1565 | setAttributionAttrs(*this, index, value, getPrivateAttribAttrsAttrName()); |
1566 | } |
1567 | |
1568 | static Attribute getAttributionAttr(GPUFuncOp op, unsigned index, |
1569 | StringAttr name, StringAttr attrsName) { |
1570 | DictionaryAttr dict = getAttributionAttrs(op, index, attrsName); |
1571 | if (!dict) |
1572 | return Attribute(); |
1573 | return dict.get(name); |
1574 | } |
1575 | |
1576 | Attribute GPUFuncOp::getWorkgroupAttributionAttr(unsigned index, |
1577 | StringAttr name) { |
1578 | assert(index < getNumWorkgroupAttributions() && |
1579 | "index must map to a workgroup attribution"); |
1580 | return getAttributionAttr(*this, index, name, |
1581 | getWorkgroupAttribAttrsAttrName()); |
1582 | } |
1583 | |
1584 | Attribute GPUFuncOp::getPrivateAttributionAttr(unsigned index, |
1585 | StringAttr name) { |
1586 | assert(index < getNumPrivateAttributions() && |
1587 | "index must map to a private attribution"); |
1588 | return getAttributionAttr(*this, index, name, |
1589 | getPrivateAttribAttrsAttrName()); |
1590 | } |
1591 | |
1592 | static void setAttributionAttr(GPUFuncOp op, unsigned index, StringAttr name, |
1593 | Attribute value, StringAttr attrsName) { |
1594 | MLIRContext *ctx = op.getContext(); |
1595 | SmallVector<NamedAttribute> elems; |
1596 | DictionaryAttr oldDict = getAttributionAttrs(op, index, attrsName); |
1597 | if (oldDict) |
1598 | elems.append(oldDict.getValue().begin(), oldDict.getValue().end()); |
1599 | |
1600 | bool found = false; |
1601 | bool mustSort = true; |
1602 | for (unsigned i = 0, e = elems.size(); i < e; ++i) { |
1603 | if (elems[i].getName() == name) { |
1604 | found = true; |
1605 | if (!value) { |
1606 | std::swap(elems[i], elems[elems.size() - 1]); |
1607 | elems.pop_back(); |
1608 | } else { |
1609 | mustSort = false; |
1610 | elems[i] = NamedAttribute(elems[i].getName(), value); |
1611 | } |
1612 | break; |
1613 | } |
1614 | } |
1615 | if (!found) { |
1616 | if (!value) |
1617 | return; |
1618 | elems.emplace_back(name, value); |
1619 | } |
1620 | if (mustSort) { |
1621 | DictionaryAttr::sortInPlace(elems); |
1622 | } |
1623 | auto newDict = DictionaryAttr::getWithSorted(ctx, elems); |
1624 | setAttributionAttrs(op, index, newDict, attrsName); |
1625 | } |
1626 | |
1627 | void GPUFuncOp::setWorkgroupAttributionAttr(unsigned index, StringAttr name, |
1628 | Attribute value) { |
1629 | assert(index < getNumWorkgroupAttributions() && |
1630 | "index must map to a workgroup attribution"); |
1631 | setAttributionAttr(*this, index, name, value, |
1632 | getWorkgroupAttribAttrsAttrName()); |
1633 | } |
1634 | |
1635 | void GPUFuncOp::setPrivateAttributionAttr(unsigned index, StringAttr name, |
1636 | Attribute value) { |
1637 | assert(index < getNumPrivateAttributions() && |
1638 | "index must map to a private attribution"); |
1639 | setAttributionAttr(*this, index, name, value, |
1640 | getPrivateAttribAttrsAttrName()); |
1641 | } |
1642 | |
1643 | LogicalResult GPUFuncOp::verifyType() { |
1644 | if (isKernel() && getFunctionType().getNumResults() != 0) |
1645 | return emitOpError() << "expected void return type for kernel function"; |
1646 | |
1647 | return success(); |
1648 | } |
1649 | |
1650 | /// Verifies the body of the function. |
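 | /// The entry block is expected to carry, in this order: the function |
 | /// arguments, the workgroup attributions, then the private attributions. |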
1651 | LogicalResult GPUFuncOp::verifyBody() { |
1652 | if (empty()) |
1653 | return emitOpError() << "expected body with at least one block"; |
1654 | unsigned numFuncArguments = getNumArguments(); |
1655 | unsigned numWorkgroupAttributions = getNumWorkgroupAttributions(); |
1656 | unsigned numBlockArguments = front().getNumArguments(); |
1657 | if (numBlockArguments < numFuncArguments + numWorkgroupAttributions) |
1658 | return emitOpError() << "expected at least " |
1659 | << numFuncArguments + numWorkgroupAttributions |
1660 | << " arguments to body region"; |
1661 | |
1662 | ArrayRef<Type> funcArgTypes = getFunctionType().getInputs(); |
1663 | for (unsigned i = 0; i < numFuncArguments; ++i) { |
1664 | Type blockArgType = front().getArgument(i).getType(); |
1665 | if (funcArgTypes[i] != blockArgType) |
1666 | return emitOpError() << "expected body region argument #"<< i |
1667 | << " to be of type "<< funcArgTypes[i] << ", got " |
1668 | << blockArgType; |
1669 | } |
1670 | |
1671 | if (failed(verifyAttributions(getOperation(), getWorkgroupAttributions(), |
1672 | GPUDialect::getWorkgroupAddressSpace())) || |
1673 | failed(verifyAttributions(getOperation(), getPrivateAttributions(), |
1674 | GPUDialect::getPrivateAddressSpace()))) |
1675 | return failure(); |
1676 | |
1677 | return success(); |
1678 | } |
1679 | |
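 | /// Verifies an optional known-launch-size attribute, e.g. (assumed spelling |
 | /// of the attribute name): |
 | /// |
 | ///   gpu.known_block_size = array<i32: 128, 1, 1> |
 | /// |
 | /// When present, the attribute must be a dense i32 array with exactly three |
 | /// elements. |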
1680 | static LogicalResult verifyKnownLaunchSizeAttr(gpu::GPUFuncOp op, |
1681 | StringRef attrName) { |
1682 | auto maybeAttr = op->getAttr(attrName); |
1683 | if (!maybeAttr) |
1684 | return success(); |
1685 | auto array = llvm::dyn_cast<DenseI32ArrayAttr>(maybeAttr); |
1686 | if (!array) |
1687 | return op.emitOpError(attrName + " must be a dense i32 array"); |
1688 | if (array.size() != 3) |
1689 | return op.emitOpError(attrName + " must contain exactly 3 elements"); |
1690 | return success(); |
1691 | } |
1692 | |
1693 | LogicalResult GPUFuncOp::verify() { |
1694 | if (failed(verifyKnownLaunchSizeAttr(*this, getKnownBlockSizeAttrName()))) |
1695 | return failure(); |
1696 | if (failed(verifyKnownLaunchSizeAttr(*this, getKnownGridSizeAttrName()))) |
1697 | return failure(); |
1698 | return success(); |
1699 | } |
1700 | |
1701 | //===----------------------------------------------------------------------===// |
1702 | // ReturnOp |
1703 | //===----------------------------------------------------------------------===// |
1704 | |
1705 | LogicalResult gpu::ReturnOp::verify() { |
1706 | GPUFuncOp function = (*this)->getParentOfType<GPUFuncOp>(); |
1707 | |
1708 | FunctionType funType = function.getFunctionType(); |
1709 | |
1710 | if (funType.getNumResults() != getOperands().size()) |
1711 | return emitOpError() |
1712 | .append("expected ", funType.getNumResults(), " result operands") |
1713 | .attachNote(function.getLoc()) |
1714 | .append("return type declared here"); |
1715 | |
1716 | for (const auto &pair : llvm::enumerate( |
1717 | llvm::zip(function.getFunctionType().getResults(), getOperands()))) { |
1718 | auto [type, operand] = pair.value(); |
1719 | if (type != operand.getType()) |
1720 | return emitOpError() << "unexpected type `"<< operand.getType() |
1721 | << "' for operand #"<< pair.index(); |
1722 | } |
1723 | return success(); |
1724 | } |
1725 | |
1726 | //===----------------------------------------------------------------------===// |
1727 | // GPUModuleOp |
1728 | //===----------------------------------------------------------------------===// |
1729 | |
1730 | void GPUModuleOp::build(OpBuilder &builder, OperationState &result, |
1731 | StringRef name, ArrayAttr targets, |
1732 | Attribute offloadingHandler) { |
1733 | ensureTerminator(*result.addRegion(), builder, result.location); |
1734 | result.attributes.push_back(builder.getNamedAttr( |
1735 | ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); |
1736 | |
1737 | Properties &props = result.getOrAddProperties<Properties>(); |
1738 | if (targets) |
1739 | props.targets = targets; |
1740 | props.offloadingHandler = offloadingHandler; |
1741 | } |
1742 | |
1743 | void GPUModuleOp::build(OpBuilder &builder, OperationState &result, |
1744 | StringRef name, ArrayRef<Attribute> targets, |
1745 | Attribute offloadingHandler) { |
1746 | build(builder, result, name, |
1747 | targets.empty() ? ArrayAttr() : builder.getArrayAttr(targets), |
1748 | offloadingHandler); |
1749 | } |
1750 | |
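 | /// Parses a GPU module; the offloading handler and the target list are both |
 | /// optional, e.g. (sketch): |
 | /// |
 | ///   gpu.module @kernels [#nvvm.target] { |
 | ///     ... |
 | ///   } |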
1751 | ParseResult GPUModuleOp::parse(OpAsmParser &parser, OperationState &result) { |
1752 | StringAttr nameAttr; |
1753 | ArrayAttr targetsAttr; |
1754 | |
1755 | if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(), |
1756 | result.attributes)) |
1757 | return failure(); |
1758 | |
1759 | Properties &props = result.getOrAddProperties<Properties>(); |
1760 | |
1761 | // Parse the optional offloading handler. |
1762 | if (succeeded(parser.parseOptionalLess())) { |
1763 | if (parser.parseAttribute(props.offloadingHandler)) |
1764 | return failure(); |
1765 | if (parser.parseGreater()) |
1766 | return failure(); |
1767 | } |
1768 | |
1769 | // Parse the optional array of target attributes. |
1770 | OptionalParseResult targetsAttrResult = |
1771 | parser.parseOptionalAttribute(targetsAttr, Type{}); |
1772 | if (targetsAttrResult.has_value()) { |
1773 | if (failed(*targetsAttrResult)) { |
1774 | return failure(); |
1775 | } |
1776 | props.targets = targetsAttr; |
1777 | } |
1778 | |
1779 | // If module attributes are present, parse them. |
1780 | if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) |
1781 | return failure(); |
1782 | |
1783 | // Parse the module body. |
1784 | auto *body = result.addRegion(); |
1785 | if (parser.parseRegion(*body, {})) |
1786 | return failure(); |
1787 | |
1788 | // Ensure that this module has a valid terminator. |
1789 | GPUModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location); |
1790 | return success(); |
1791 | } |
1792 | |
1793 | void GPUModuleOp::print(OpAsmPrinter &p) { |
1794 | p << ' '; |
1795 | p.printSymbolName(getName()); |
1796 | |
1797 | if (Attribute attr = getOffloadingHandlerAttr()) { |
1798 | p << " <"; |
1799 | p.printAttribute(attr); |
1800 | p << ">"; |
1801 | } |
1802 | |
1803 | if (Attribute attr = getTargetsAttr()) { |
1804 | p << ' '; |
1805 | p.printAttribute(attr); |
1806 | p << ' '; |
1807 | } |
1808 | |
1809 | p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), |
1810 | {mlir::SymbolTable::getSymbolAttrName(), |
1811 | getTargetsAttrName(), |
1812 | getOffloadingHandlerAttrName()}); |
1813 | p << ' '; |
1814 | p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, |
1815 | /*printBlockTerminators=*/false); |
1816 | } |
1817 | |
1818 | bool GPUModuleOp::hasTarget(Attribute target) { |
1819 | if (ArrayAttr targets = getTargetsAttr()) |
1820 | return llvm::count(targets.getValue(), target); |
1821 | return false; |
1822 | } |
1823 | |
1824 | void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) { |
1825 | ArrayAttr &targetsAttr = getProperties().targets; |
1826 | SmallVector<Attribute> targetsVector(targets); |
1827 | targetsAttr = ArrayAttr::get(getContext(), targetsVector); |
1828 | } |
1829 | |
1830 | //===----------------------------------------------------------------------===// |
1831 | // GPUBinaryOp |
1832 | //===----------------------------------------------------------------------===// |
1833 | void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name, |
1834 | Attribute offloadingHandler, ArrayAttr objects) { |
1835 | auto &properties = result.getOrAddProperties<Properties>(); |
1836 | result.attributes.push_back(builder.getNamedAttr( |
1837 | SymbolTable::getSymbolAttrName(), builder.getStringAttr(name))); |
1838 | properties.objects = objects; |
1839 | if (offloadingHandler) |
1840 | properties.offloadingHandler = offloadingHandler; |
1841 | else |
1842 | properties.offloadingHandler = builder.getAttr<SelectObjectAttr>(nullptr); |
1843 | } |
1844 | |
1845 | void BinaryOp::build(OpBuilder &builder, OperationState &result, StringRef name, |
1846 | Attribute offloadingHandler, ArrayRef<Attribute> objects) { |
1847 | build(builder, result, name, offloadingHandler, |
1848 | objects.empty() ? ArrayAttr() : builder.getArrayAttr(objects)); |
1849 | } |
1850 | |
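 | /// Parses the optional offloading handler of a gpu.binary op, printed in |
 | /// angle brackets before the object list, e.g. (sketch): |
 | /// |
 | ///   gpu.binary @kernels <#gpu.select_object<1>> [#gpu.object<...>] |
 | /// |
 | /// When absent, a default `SelectObjectAttr` with a null target is created. |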
1851 | static ParseResult parseOffloadingHandler(OpAsmParser &parser, |
1852 | Attribute &offloadingHandler) { |
1853 | if (succeeded(parser.parseOptionalLess())) { |
1854 | if (parser.parseAttribute(offloadingHandler)) |
1855 | return failure(); |
1856 | if (parser.parseGreater()) |
1857 | return failure(); |
1858 | } |
1859 | if (!offloadingHandler) |
1860 | offloadingHandler = parser.getBuilder().getAttr<SelectObjectAttr>(nullptr); |
1861 | return success(); |
1862 | } |
1863 | |
1864 | static void printOffloadingHandler(OpAsmPrinter &printer, Operation *op, |
1865 | Attribute offloadingHandler) { |
1866 | if (offloadingHandler != SelectObjectAttr::get(op->getContext(), nullptr)) |
1867 | printer << '<' << offloadingHandler << '>'; |
1868 | } |
1869 | |
1870 | //===----------------------------------------------------------------------===// |
1871 | // GPUMemcpyOp |
1872 | //===----------------------------------------------------------------------===// |
1873 | |
1874 | LogicalResult MemcpyOp::verify() { |
1875 | auto srcType = getSrc().getType(); |
1876 | auto dstType = getDst().getType(); |
1877 | |
1878 | if (getElementTypeOrSelf(srcType) != getElementTypeOrSelf(dstType)) |
1879 | return emitOpError("arguments have incompatible element type"); |
1880 | |
1881 | if (failed(verifyCompatibleShape(srcType, dstType))) |
1882 | return emitOpError("arguments have incompatible shape"); |
1883 | |
1884 | return success(); |
1885 | } |
1886 | |
1887 | namespace { |
1888 | |
1889 | /// Erases a common case of copy ops where a destination value is used only by |
1890 | /// the copy op, alloc and dealloc ops. |
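 | /// As a sketch (illustrative IR), the gpu.memcpy below is erasable because |
 | /// %dst is only used by the copy, its allocation, and its deallocation: |
 | /// |
 | ///   %dst = gpu.alloc () : memref<16xf32> |
 | ///   gpu.memcpy %dst, %src : memref<16xf32>, memref<16xf32> |
 | ///   gpu.dealloc %dst : memref<16xf32> |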
1891 | struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> { |
1892 | using OpRewritePattern<MemcpyOp>::OpRewritePattern; |
1893 | |
1894 | LogicalResult matchAndRewrite(MemcpyOp op, |
1895 | PatternRewriter &rewriter) const override { |
1896 | Value dest = op.getDst(); |
1897 | Operation *destDefOp = dest.getDefiningOp(); |
1898 | // `dest` must be defined by an op having Allocate memory effect in order to |
1899 | // perform the folding. |
1900 | if (!destDefOp || |
1901 | !hasSingleEffect<MemoryEffects::Allocate>(destDefOp, dest)) |
1902 | return failure(); |
1903 | // We can erase `op` iff `dest` has no other use apart from its |
1904 | // use by `op` and dealloc ops. |
1905 | if (llvm::any_of(dest.getUsers(), [op, dest](Operation *user) { |
1906 | return user != op && |
1907 | !hasSingleEffect<MemoryEffects::Free>(user, dest); |
1908 | })) |
1909 | return failure(); |
1910 | // We can perform the folding if and only if op has a single async |
1911 | // dependency and produces an async token as result, or if it does not have |
1912 | // any async dependency and does not produce any async token result. |
1913 | if (op.getAsyncDependencies().size() > 1 || |
1914 | ((op.getAsyncDependencies().empty() && op.getAsyncToken()) || |
1915 | (!op.getAsyncDependencies().empty() && !op.getAsyncToken()))) |
1916 | return failure(); |
1917 | rewriter.replaceOp(op, op.getAsyncDependencies()); |
1918 | return success(); |
1919 | } |
1920 | }; |
1921 | |
1922 | } // end anonymous namespace |
1923 | |
1924 | void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results, |
1925 | MLIRContext *context) { |
1926 | results.add<EraseTrivialCopyOp>(context); |
1927 | } |
1928 | |
1929 | //===----------------------------------------------------------------------===// |
1930 | // GPU_SubgroupMmaLoadMatrixOp |
1931 | //===----------------------------------------------------------------------===// |
1932 | |
1933 | LogicalResult SubgroupMmaLoadMatrixOp::verify() { |
1934 | auto srcType = getSrcMemref().getType(); |
1935 | auto resType = getRes().getType(); |
1936 | auto resMatrixType = llvm::cast<gpu::MMAMatrixType>(resType); |
1937 | auto operand = resMatrixType.getOperand(); |
1938 | auto srcMemrefType = llvm::cast<MemRefType>(srcType); |
1939 | |
1940 | if (!isLastMemrefDimUnitStride(srcMemrefType)) |
1941 | return emitError( |
1942 | "expected source memref most minor dim must have unit stride"); |
1943 | |
1944 | if (!operand.equals("AOp") && !operand.equals("BOp") && |
1945 | !operand.equals("COp")) |
1946 | return emitError("only AOp, BOp and COp can be loaded"); |
1947 | |
1948 | return success(); |
1949 | } |
1950 | |
1951 | //===----------------------------------------------------------------------===// |
1952 | // GPU_SubgroupMmaStoreMatrixOp |
1953 | //===----------------------------------------------------------------------===// |
1954 | |
1955 | LogicalResult SubgroupMmaStoreMatrixOp::verify() { |
1956 | auto srcType = getSrc().getType(); |
1957 | auto dstType = getDstMemref().getType(); |
1958 | auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType); |
1959 | auto dstMemrefType = llvm::cast<MemRefType>(dstType); |
1960 | |
1961 | if (!isLastMemrefDimUnitStride(dstMemrefType)) |
1962 | return emitError( |
1963 | "expected destination memref most minor dim must have unit stride"); |
1964 | |
1965 | if (!srcMatrixType.getOperand().equals("COp")) |
1966 | return emitError( |
1967 | "expected the operand matrix being stored to have 'COp' operand type"); |
1968 | |
1969 | return success(); |
1970 | } |
1971 | |
1972 | //===----------------------------------------------------------------------===// |
1973 | // GPU_SubgroupMmaComputeOp |
1974 | //===----------------------------------------------------------------------===// |
1975 | |
1976 | LogicalResult SubgroupMmaComputeOp::verify() { |
1977 | enum OperandMap { A, B, C }; |
1978 | SmallVector<MMAMatrixType, 3> opTypes; |
1979 | opTypes.push_back(llvm::cast<MMAMatrixType>(getOpA().getType())); |
1980 | opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType())); |
1981 | opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType())); |
1982 | |
1983 | if (!opTypes[A].getOperand().equals("AOp") || |
1984 | !opTypes[B].getOperand().equals("BOp") || |
1985 | !opTypes[C].getOperand().equals("COp")) |
1986 | return emitError("operands must be in the order AOp, BOp, COp"); |
1987 | |
1988 | ArrayRef<int64_t> aShape, bShape, cShape; |
1989 | aShape = opTypes[A].getShape(); |
1990 | bShape = opTypes[B].getShape(); |
1991 | cShape = opTypes[C].getShape(); |
1992 | |
1993 | if (aShape[1] != bShape[0] || aShape[0] != cShape[0] || |
1994 | bShape[1] != cShape[1]) |
1995 | return emitError("operand shapes do not satisfy matmul constraints"); |
1996 | |
1997 | return success(); |
1998 | } |
1999 | |
2000 | LogicalResult MemcpyOp::fold(FoldAdaptor adaptor, |
2001 | SmallVectorImpl<::mlir::OpFoldResult> &results) { |
2002 | return memref::foldMemRefCast(*this); |
2003 | } |
2004 | |
2005 | LogicalResult MemsetOp::fold(FoldAdaptor adaptor, |
2006 | SmallVectorImpl<::mlir::OpFoldResult> &results) { |
2007 | return memref::foldMemRefCast(*this); |
2008 | } |
2009 | |
2010 | //===----------------------------------------------------------------------===// |
2011 | // GPU_WaitOp |
2012 | //===----------------------------------------------------------------------===// |
2013 | |
2014 | namespace { |
2015 | |
2016 | /// Drop async dependencies on a gpu.wait op whose defining gpu.wait has no |
2017 | /// %t = gpu.wait async [] // No async dependencies. |
2018 | /// ... gpu.wait ... [%t, ...] // %t can be removed. |
2019 | struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> { |
2020 | public: |
2021 | using OpRewritePattern::OpRewritePattern; |
2022 | |
2023 | LogicalResult matchAndRewrite(WaitOp op, |
2024 | PatternRewriter &rewriter) const final { |
2025 | auto predicate = [](Value value) { |
2026 | auto waitOp = value.getDefiningOp<WaitOp>(); |
2027 | return waitOp && waitOp->getNumOperands() == 0; |
2028 | }; |
2029 | if (llvm::none_of(op.getAsyncDependencies(), predicate)) |
2030 | return failure(); |
2031 | SmallVector<Value> validOperands; |
2032 | for (Value operand : op->getOperands()) { |
2033 | if (predicate(operand)) |
2034 | continue; |
2035 | validOperands.push_back(operand); |
2036 | } |
2037 | rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); }); |
2038 | return success(); |
2039 | } |
2040 | }; |
2041 | |
2042 | /// Simplify trivial gpu.wait ops for the following patterns. |
2043 | /// 1. %t = gpu.wait async ... ops, where %t has no uses (regardless of async |
2044 | /// dependencies). |
2045 | /// 2. %t1 = gpu.wait async [%t0], in this case, we can replace uses of %t1 with |
2046 | /// %t0. |
2047 | /// 3. gpu.wait [] ops, i.e. gpu.wait ops that neither have any async |
2048 | /// dependencies nor return any token. |
2049 | struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> { |
2050 | public: |
2051 | using OpRewritePattern::OpRewritePattern; |
2052 | |
2053 | LogicalResult matchAndRewrite(WaitOp op, |
2054 | PatternRewriter &rewriter) const final { |
2055 | // Erase gpu.wait ops that neither have any async dependencies nor return |
2056 | // any async token. |
2057 | if (op.getAsyncDependencies().empty() && !op.getAsyncToken()) { |
2058 | rewriter.eraseOp(op); |
2059 | return success(); |
2060 | } |
2061 | // Replace uses of %t1 = gpu.wait async [%t0] ops with %t0 and erase the op. |
2062 | if (llvm::hasSingleElement(op.getAsyncDependencies()) && |
2063 | op.getAsyncToken()) { |
2064 | rewriter.replaceOp(op, op.getAsyncDependencies()); |
2065 | return success(); |
2066 | } |
2067 | // Erase %t = gpu.wait async ... ops, where %t has no uses. |
2068 | if (op.getAsyncToken() && op.getAsyncToken().use_empty()) { |
2069 | rewriter.eraseOp(op); |
2070 | return success(); |
2071 | } |
2072 | return failure(); |
2073 | } |
2074 | }; |
2075 | |
2076 | } // end anonymous namespace |
2077 | |
2078 | void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results, |
2079 | MLIRContext *context) { |
2080 | results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context); |
2081 | } |
2082 | |
2083 | //===----------------------------------------------------------------------===// |
2084 | // GPU_AllocOp |
2085 | //===----------------------------------------------------------------------===// |
2086 | |
2087 | LogicalResult AllocOp::verify() { |
2088 | auto memRefType = llvm::cast<MemRefType>(getMemref().getType()); |
2089 | |
2090 | if (static_cast<int64_t>(getDynamicSizes().size()) != |
2091 | memRefType.getNumDynamicDims()) |
2092 | return emitOpError("dimension operand count does not equal memref " |
2093 | "dynamic dimension count"); |
2094 | |
2095 | unsigned numSymbols = 0; |
2096 | if (!memRefType.getLayout().isIdentity()) |
2097 | numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols(); |
2098 | if (getSymbolOperands().size() != numSymbols) { |
2099 | return emitOpError( |
2100 | "symbol operand count does not equal memref symbol count"); |
2101 | } |
2102 | |
2103 | return success(); |
2104 | } |
2105 | |
2106 | namespace { |
2107 | |
2108 | /// Folding of memref.dim(gpu.alloc(%size), %idx) -> %size similar to |
2109 | /// `memref::AllocOp`. |
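 | /// As a sketch: |
 | /// |
 | ///   %m = gpu.alloc (%sz) : memref<?xf32> |
 | ///   %d = memref.dim %m, %c0 : memref<?xf32>  // Folds to %sz. |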
2110 | struct SimplifyDimOfAllocOp : public OpRewritePattern<memref::DimOp> { |
2111 | using OpRewritePattern<memref::DimOp>::OpRewritePattern; |
2112 | |
2113 | LogicalResult matchAndRewrite(memref::DimOp dimOp, |
2114 | PatternRewriter &rewriter) const override { |
2115 | std::optional<int64_t> index = dimOp.getConstantIndex(); |
2116 | if (!index) |
2117 | return failure(); |
2118 | |
2119 | auto memrefType = llvm::dyn_cast<MemRefType>(dimOp.getSource().getType()); |
2120 | if (!memrefType || !memrefType.isDynamicDim(index.value())) |
2121 | return failure(); |
2122 | |
2123 | auto alloc = dimOp.getSource().getDefiningOp<AllocOp>(); |
2124 | if (!alloc) |
2125 | return failure(); |
2126 | |
2127 | Value substituteOp = *(alloc.getDynamicSizes().begin() + |
2128 | memrefType.getDynamicDimIndex(index.value())); |
2129 | rewriter.replaceOp(dimOp, substituteOp); |
2130 | return success(); |
2131 | } |
2132 | }; |
2133 | |
2134 | } // namespace |
2135 | |
2136 | void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results, |
2137 | MLIRContext *context) { |
2138 | results.add<SimplifyDimOfAllocOp>(context); |
2139 | } |
2140 | |
2141 | //===----------------------------------------------------------------------===// |
2142 | // GPU object attribute |
2143 | //===----------------------------------------------------------------------===// |
2144 | |
2145 | LogicalResult ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError, |
2146 | Attribute target, CompilationTarget format, |
2147 | StringAttr object, DictionaryAttr properties) { |
2148 | if (!target) |
2149 | return emitError() << "the target attribute cannot be null"; |
2150 | if (target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) |
2151 | return success(); |
2152 | return emitError() << "the target attribute must implement or promise the " |
2153 | "`gpu::TargetAttrInterface`"; |
2154 | } |
2155 | |
2156 | namespace { |
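 | /// Parses the format/object pair of a #gpu.object attribute; the format |
 | /// keyword defaults to fatbin when omitted, e.g. (sketch): |
 | /// |
 | ///   #gpu.object<#nvvm.target, "BLOB">            // Implicit fatbin format. |
 | ///   #gpu.object<#nvvm.target, assembly = "ptx">  // Explicit format. |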
2157 | LogicalResult parseObject(AsmParser &odsParser, CompilationTarget &format, |
2158 | StringAttr &object) { |
2159 | std::optional<CompilationTarget> formatResult; |
2160 | StringRef enumKeyword; |
2161 | auto loc = odsParser.getCurrentLocation(); |
2162 | if (failed(odsParser.parseOptionalKeyword(&enumKeyword))) |
2163 | formatResult = CompilationTarget::Fatbin; |
2164 | if (!formatResult && |
2165 | (formatResult = |
2166 | gpu::symbolizeEnum<gpu::CompilationTarget>(enumKeyword)) && |
2167 | odsParser.parseEqual()) |
2168 | return odsParser.emitError(loc, "expected an equal sign"); |
2169 | if (!formatResult) |
2170 | return odsParser.emitError(loc, "expected keyword for GPU object format"); |
2171 | FailureOr<StringAttr> objectResult = |
2172 | FieldParser<StringAttr>::parse(odsParser); |
2173 | if (failed(objectResult)) |
2174 | return odsParser.emitError(odsParser.getCurrentLocation(), |
2175 | "failed to parse GPU_ObjectAttr parameter " |
2176 | "'object' which is to be a `StringAttr`"); |
2177 | format = *formatResult; |
2178 | object = *objectResult; |
2179 | return success(); |
2180 | } |
2181 | |
2182 | void printObject(AsmPrinter &odsParser, CompilationTarget format, |
2183 | StringAttr object) { |
2184 | if (format != CompilationTarget::Fatbin) |
2185 | odsParser << stringifyEnum(format) << " = "; |
2186 | odsParser << object; |
2187 | } |
2188 | } // namespace |
2189 | |
2190 | //===----------------------------------------------------------------------===// |
2191 | // GPU select object attribute |
2192 | //===----------------------------------------------------------------------===// |
2193 | |
2194 | LogicalResult |
2195 | gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError, |
2196 | Attribute target) { |
2197 | // Check `target`, it can be null, an integer attr or a GPU Target attribute. |
2198 | if (target) { |
2199 | if (auto intAttr = mlir::dyn_cast<IntegerAttr>(target)) { |
2200 | if (intAttr.getInt() < 0) { |
2201 | return emitError() << "the object index must be positive"; |
2202 | } |
2203 | } else if (!target.hasPromiseOrImplementsInterface<TargetAttrInterface>()) { |
2204 | return emitError() |
2205 | << "the target attribute must be a GPU Target attribute"; |
2206 | } |
2207 | } |
2208 | return success(); |
2209 | } |
2210 | |
2211 | //===----------------------------------------------------------------------===// |
2212 | // DynamicSharedMemoryOp |
2213 | //===----------------------------------------------------------------------===// |
2214 | |
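 | /// A well-formed op yields a dynamically sized i8 memref in workgroup |
 | /// memory, e.g. (sketch): |
 | /// |
 | ///   %shmem = gpu.dynamic_shared_memory |
 | ///       : memref<?xi8, #gpu.address_space<workgroup>> |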
2215 | LogicalResult gpu::DynamicSharedMemoryOp::verify() { |
2216 | if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>()) |
2217 | return emitOpError() << "must be inside an op with symbol table"; |
2218 | |
2219 | MemRefType memrefType = getResultMemref().getType(); |
2220 | // Check address space |
2221 | if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) { |
2222 | return emitOpError() << "address space must be " |
2223 | << gpu::AddressSpaceAttr::getMnemonic() << "<" |
2224 | << stringifyEnum(gpu::AddressSpace::Workgroup) << ">"; |
2225 | } |
2226 | if (memrefType.hasStaticShape()) { |
2227 | return emitOpError() << "result memref type must be memref<?xi8, " |
2228 | "#gpu.address_space<workgroup>>"; |
2229 | } |
2230 | return success(); |
2231 | } |
2232 | |
2233 | //===----------------------------------------------------------------------===// |
2234 | // GPU target options |
2235 | //===----------------------------------------------------------------------===// |
2236 | |
2237 | TargetOptions::TargetOptions( |
2238 | StringRef toolkitPath, ArrayRef<std::string> linkFiles, |
2239 | StringRef cmdOptions, CompilationTarget compilationTarget, |
2240 | function_ref<SymbolTable *()> getSymbolTableCallback) |
2241 | : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles, |
2242 | cmdOptions, compilationTarget, getSymbolTableCallback) {} |
2243 | |
2244 | TargetOptions::TargetOptions( |
2245 | TypeID typeID, StringRef toolkitPath, ArrayRef<std::string> linkFiles, |
2246 | StringRef cmdOptions, CompilationTarget compilationTarget, |
2247 | function_ref<SymbolTable *()> getSymbolTableCallback) |
2248 | : toolkitPath(toolkitPath.str()), linkFiles(linkFiles), |
2249 | cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget), |
2250 | getSymbolTableCallback(getSymbolTableCallback), typeID(typeID) {} |
2251 | |
2252 | TypeID TargetOptions::getTypeID() const { return typeID; } |
2253 | |
2254 | StringRef TargetOptions::getToolkitPath() const { return toolkitPath; } |
2255 | |
2256 | ArrayRef<std::string> TargetOptions::getLinkFiles() const { return linkFiles; } |
2257 | |
2258 | StringRef TargetOptions::getCmdOptions() const { return cmdOptions; } |
2259 | |
2260 | SymbolTable *TargetOptions::getSymbolTable() const { |
2261 | return getSymbolTableCallback ? getSymbolTableCallback() : nullptr; |
2262 | } |
2263 | |
2264 | CompilationTarget TargetOptions::getCompilationTarget() const { |
2265 | return compilationTarget; |
2266 | } |
2267 | |
2268 | CompilationTarget TargetOptions::getDefaultCompilationTarget() { |
2269 | return CompilationTarget::Fatbin; |
2270 | } |
2271 | |
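 | /// Tokenizes `cmdOptions` with the platform's command-line tokenizer. As a |
 | /// sketch, a quoted string such as "-O3 -v" is unquoted first and then split |
 | /// into the tokens `-O3` and `-v`. |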
2272 | std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> |
2273 | TargetOptions::tokenizeCmdOptions() const { |
2274 | std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options; |
2275 | llvm::StringSaver stringSaver(options.first); |
2276 | StringRef opts = cmdOptions; |
2277 | // For a correct tokenization, the command line options in `opts` must be |
2278 | // unquoted; otherwise the tokenization function returns the whole quoted |
2279 | // string as a single token, which is not the desired behavior. |
2280 | // Remove the quotes if they surround the string: |
2281 | if (!opts.empty() && opts.front() == '"' && opts.back() == '"') |
2282 | opts.consume_front("\""), opts.consume_back("\""); |
2283 | if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'') |
2284 | opts.consume_front("'"), opts.consume_back("'"); |
2285 | #ifdef _WIN32 |
2286 | llvm::cl::TokenizeWindowsCommandLine(opts, stringSaver, options.second, |
2287 | /*MarkEOLs=*/false); |
2288 | #else |
2289 | llvm::cl::TokenizeGNUCommandLine(opts, stringSaver, options.second, |
2290 | /*MarkEOLs=*/false); |
2291 | #endif // _WIN32 |
2292 | return options; |
2293 | } |
2294 | |
2295 | MLIR_DEFINE_EXPLICIT_TYPE_ID(::mlir::gpu::TargetOptions) |
2296 | |
2297 | #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.cpp.inc" |
2298 | #include "mlir/Dialect/GPU/IR/GPUOpsEnums.cpp.inc" |
2299 | |
2300 | #define GET_ATTRDEF_CLASSES |
2301 | #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.cpp.inc" |
2302 | |
2303 | #define GET_OP_CLASSES |
2304 | #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc" |
2305 | |
2306 | #include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.cpp.inc" |
2307 |