1 | //===- Async.cpp - MLIR Async Operations ----------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Async/IR/Async.h" |
10 | |
11 | #include "mlir/IR/DialectImplementation.h" |
12 | #include "mlir/IR/IRMapping.h" |
13 | #include "mlir/Interfaces/FunctionImplementation.h" |
14 | #include "llvm/ADT/MapVector.h" |
15 | #include "llvm/ADT/TypeSwitch.h" |
16 | |
17 | using namespace mlir; |
18 | using namespace mlir::async; |
19 | |
20 | #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc" |
21 | |
// Out-of-line definition of the dialect's `kAllowedToBlockAttrName` constant.
// NOTE(review): presumably declared `static constexpr` inside the dialect
// class; under C++17 such members are implicitly inline, which would make this
// definition redundant (though harmless). Confirm before removing.
constexpr StringRef AsyncDialect::kAllowedToBlockAttrName;

/// Registers every Async dialect operation and type with this dialect. The
/// class lists are expanded from the TableGen-generated .inc files through the
/// GET_OP_LIST / GET_TYPEDEF_LIST selector macros.
void AsyncDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
      >();
}
34 | |
35 | //===----------------------------------------------------------------------===// |
36 | /// ExecuteOp |
37 | //===----------------------------------------------------------------------===// |
38 | |
// Name of the attribute that records how many operands belong to each variadic
// operand group of `async.execute` (dependencies first, then body operands).
// The custom builder and parser attach it; the custom printer elides it.
static constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes";
40 | |
41 | OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) { |
42 | assert(point == getBodyRegion() && "invalid region index" ); |
43 | return getBodyOperands(); |
44 | } |
45 | |
46 | bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) { |
47 | const auto getValueOrTokenType = [](Type type) { |
48 | if (auto value = llvm::dyn_cast<ValueType>(type)) |
49 | return value.getValueType(); |
50 | return type; |
51 | }; |
52 | return getValueOrTokenType(lhs) == getValueOrTokenType(rhs); |
53 | } |
54 | |
55 | void ExecuteOp::getSuccessorRegions(RegionBranchPoint point, |
56 | SmallVectorImpl<RegionSuccessor> ®ions) { |
57 | // The `body` region branch back to the parent operation. |
58 | if (point == getBodyRegion()) { |
59 | regions.push_back(RegionSuccessor(getBodyResults())); |
60 | return; |
61 | } |
62 | |
63 | // Otherwise the successor is the body region. |
64 | regions.push_back( |
65 | RegionSuccessor(&getBodyRegion(), getBodyRegion().getArguments())); |
66 | } |
67 | |
68 | void ExecuteOp::build(OpBuilder &builder, OperationState &result, |
69 | TypeRange resultTypes, ValueRange dependencies, |
70 | ValueRange operands, BodyBuilderFn bodyBuilder) { |
71 | OpBuilder::InsertionGuard guard(builder); |
72 | result.addOperands(dependencies); |
73 | result.addOperands(operands); |
74 | |
75 | // Add derived `operandSegmentSizes` attribute based on parsed operands. |
76 | int32_t numDependencies = dependencies.size(); |
77 | int32_t numOperands = operands.size(); |
78 | auto operandSegmentSizes = |
79 | builder.getDenseI32ArrayAttr({numDependencies, numOperands}); |
80 | result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); |
81 | |
82 | // First result is always a token, and then `resultTypes` wrapped into |
83 | // `async.value`. |
84 | result.addTypes({TokenType::get(result.getContext())}); |
85 | for (Type type : resultTypes) |
86 | result.addTypes(ValueType::get(type)); |
87 | |
88 | // Add a body region with block arguments as unwrapped async value operands. |
89 | Region *bodyRegion = result.addRegion(); |
90 | Block *bodyBlock = builder.createBlock(bodyRegion); |
91 | for (Value operand : operands) { |
92 | auto valueType = llvm::dyn_cast<ValueType>(operand.getType()); |
93 | bodyBlock->addArgument(valueType ? valueType.getValueType() |
94 | : operand.getType(), |
95 | operand.getLoc()); |
96 | } |
97 | |
98 | // Create the default terminator if the builder is not provided and if the |
99 | // expected result is empty. Otherwise, leave this to the caller |
100 | // because we don't know which values to return from the execute op. |
101 | if (resultTypes.empty() && !bodyBuilder) { |
102 | builder.create<async::YieldOp>(result.location, ValueRange()); |
103 | } else if (bodyBuilder) { |
104 | bodyBuilder(builder, result.location, bodyBlock->getArguments()); |
105 | } |
106 | } |
107 | |
108 | void ExecuteOp::print(OpAsmPrinter &p) { |
109 | // [%tokens,...] |
110 | if (!getDependencies().empty()) |
111 | p << " [" << getDependencies() << "]" ; |
112 | |
113 | // (%value as %unwrapped: !async.value<!arg.type>, ...) |
114 | if (!getBodyOperands().empty()) { |
115 | p << " (" ; |
116 | Block *entry = getBodyRegion().empty() ? nullptr : &getBodyRegion().front(); |
117 | llvm::interleaveComma( |
118 | getBodyOperands(), p, [&, n = 0](Value operand) mutable { |
119 | Value argument = entry ? entry->getArgument(n++) : Value(); |
120 | p << operand << " as " << argument << ": " << operand.getType(); |
121 | }); |
122 | p << ")" ; |
123 | } |
124 | |
125 | // -> (!async.value<!return.type>, ...) |
126 | p.printOptionalArrowTypeList(llvm::drop_begin(getResultTypes())); |
127 | p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), |
128 | {kOperandSegmentSizesAttr}); |
129 | p << ' '; |
130 | p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false); |
131 | } |
132 | |
133 | ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) { |
134 | MLIRContext *ctx = result.getContext(); |
135 | |
136 | // Sizes of parsed variadic operands, will be updated below after parsing. |
137 | int32_t numDependencies = 0; |
138 | |
139 | auto tokenTy = TokenType::get(ctx); |
140 | |
141 | // Parse dependency tokens. |
142 | if (succeeded(parser.parseOptionalLSquare())) { |
143 | SmallVector<OpAsmParser::UnresolvedOperand, 4> tokenArgs; |
144 | if (parser.parseOperandList(tokenArgs) || |
145 | parser.resolveOperands(tokenArgs, tokenTy, result.operands) || |
146 | parser.parseRSquare()) |
147 | return failure(); |
148 | |
149 | numDependencies = tokenArgs.size(); |
150 | } |
151 | |
152 | // Parse async value operands (%value as %unwrapped : !async.value<!type>). |
153 | SmallVector<OpAsmParser::UnresolvedOperand, 4> valueArgs; |
154 | SmallVector<OpAsmParser::Argument, 4> unwrappedArgs; |
155 | SmallVector<Type, 4> valueTypes; |
156 | |
157 | // Parse a single instance of `%value as %unwrapped : !async.value<!type>`. |
158 | auto parseAsyncValueArg = [&]() -> ParseResult { |
159 | if (parser.parseOperand(valueArgs.emplace_back()) || |
160 | parser.parseKeyword("as" ) || |
161 | parser.parseArgument(unwrappedArgs.emplace_back()) || |
162 | parser.parseColonType(valueTypes.emplace_back())) |
163 | return failure(); |
164 | |
165 | auto valueTy = llvm::dyn_cast<ValueType>(valueTypes.back()); |
166 | unwrappedArgs.back().type = valueTy ? valueTy.getValueType() : Type(); |
167 | return success(); |
168 | }; |
169 | |
170 | auto argsLoc = parser.getCurrentLocation(); |
171 | if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::OptionalParen, |
172 | parseAsyncValueArg) || |
173 | parser.resolveOperands(valueArgs, valueTypes, argsLoc, result.operands)) |
174 | return failure(); |
175 | |
176 | int32_t numOperands = valueArgs.size(); |
177 | |
178 | // Add derived `operandSegmentSizes` attribute based on parsed operands. |
179 | auto operandSegmentSizes = |
180 | parser.getBuilder().getDenseI32ArrayAttr({numDependencies, numOperands}); |
181 | result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); |
182 | |
183 | // Parse the types of results returned from the async execute op. |
184 | SmallVector<Type, 4> resultTypes; |
185 | NamedAttrList attrs; |
186 | if (parser.parseOptionalArrowTypeList(resultTypes) || |
187 | // Async execute first result is always a completion token. |
188 | parser.addTypeToList(tokenTy, result.types) || |
189 | parser.addTypesToList(resultTypes, result.types) || |
190 | // Parse operation attributes. |
191 | parser.parseOptionalAttrDictWithKeyword(attrs)) |
192 | return failure(); |
193 | |
194 | result.addAttributes(attrs); |
195 | |
196 | // Parse asynchronous region. |
197 | Region *body = result.addRegion(); |
198 | return parser.parseRegion(*body, /*arguments=*/unwrappedArgs); |
199 | } |
200 | |
201 | LogicalResult ExecuteOp::verifyRegions() { |
202 | // Unwrap async.execute value operands types. |
203 | auto unwrappedTypes = llvm::map_range(getBodyOperands(), [](Value operand) { |
204 | return llvm::cast<ValueType>(operand.getType()).getValueType(); |
205 | }); |
206 | |
207 | // Verify that unwrapped argument types matches the body region arguments. |
208 | if (getBodyRegion().getArgumentTypes() != unwrappedTypes) |
209 | return emitOpError("async body region argument types do not match the " |
210 | "execute operation arguments types" ); |
211 | |
212 | return success(); |
213 | } |
214 | |
215 | //===----------------------------------------------------------------------===// |
216 | /// CreateGroupOp |
217 | //===----------------------------------------------------------------------===// |
218 | |
219 | LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op, |
220 | PatternRewriter &rewriter) { |
221 | // Find all `await_all` users of the group. |
222 | llvm::SmallVector<AwaitAllOp> awaitAllUsers; |
223 | |
224 | auto isAwaitAll = [&](Operation *op) -> bool { |
225 | if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(op)) { |
226 | awaitAllUsers.push_back(awaitAll); |
227 | return true; |
228 | } |
229 | return false; |
230 | }; |
231 | |
232 | // Check if all users of the group are `await_all` operations. |
233 | if (!llvm::all_of(op->getUsers(), isAwaitAll)) |
234 | return failure(); |
235 | |
236 | // If group is only awaited without adding anything to it, we can safely erase |
237 | // the create operation and all users. |
238 | for (AwaitAllOp awaitAll : awaitAllUsers) |
239 | rewriter.eraseOp(awaitAll); |
240 | rewriter.eraseOp(op); |
241 | |
242 | return success(); |
243 | } |
244 | |
245 | //===----------------------------------------------------------------------===// |
246 | /// AwaitOp |
247 | //===----------------------------------------------------------------------===// |
248 | |
249 | void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand, |
250 | ArrayRef<NamedAttribute> attrs) { |
251 | result.addOperands({operand}); |
252 | result.attributes.append(attrs.begin(), attrs.end()); |
253 | |
254 | // Add unwrapped async.value type to the returned values types. |
255 | if (auto valueType = llvm::dyn_cast<ValueType>(operand.getType())) |
256 | result.addTypes(valueType.getValueType()); |
257 | } |
258 | |
259 | static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType, |
260 | Type &resultType) { |
261 | if (parser.parseType(result&: operandType)) |
262 | return failure(); |
263 | |
264 | // Add unwrapped async.value type to the returned values types. |
265 | if (auto valueType = llvm::dyn_cast<ValueType>(operandType)) |
266 | resultType = valueType.getValueType(); |
267 | |
268 | return success(); |
269 | } |
270 | |
271 | static void printAwaitResultType(OpAsmPrinter &p, Operation *op, |
272 | Type operandType, Type resultType) { |
273 | p << operandType; |
274 | } |
275 | |
276 | LogicalResult AwaitOp::verify() { |
277 | Type argType = getOperand().getType(); |
278 | |
279 | // Awaiting on a token does not have any results. |
280 | if (llvm::isa<TokenType>(argType) && !getResultTypes().empty()) |
281 | return emitOpError("awaiting on a token must have empty result" ); |
282 | |
283 | // Awaiting on a value unwraps the async value type. |
284 | if (auto value = llvm::dyn_cast<ValueType>(argType)) { |
285 | if (*getResultType() != value.getValueType()) |
286 | return emitOpError() << "result type " << *getResultType() |
287 | << " does not match async value type " |
288 | << value.getValueType(); |
289 | } |
290 | |
291 | return success(); |
292 | } |
293 | |
294 | //===----------------------------------------------------------------------===// |
295 | // FuncOp |
296 | //===----------------------------------------------------------------------===// |
297 | |
298 | void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, |
299 | FunctionType type, ArrayRef<NamedAttribute> attrs, |
300 | ArrayRef<DictionaryAttr> argAttrs) { |
301 | state.addAttribute(SymbolTable::getSymbolAttrName(), |
302 | builder.getStringAttr(name)); |
303 | state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); |
304 | |
305 | state.attributes.append(attrs.begin(), attrs.end()); |
306 | state.addRegion(); |
307 | |
308 | if (argAttrs.empty()) |
309 | return; |
310 | assert(type.getNumInputs() == argAttrs.size()); |
311 | function_interface_impl::addArgAndResultAttrs( |
312 | builder, state, argAttrs, /*resultAttrs=*/std::nullopt, |
313 | getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); |
314 | } |
315 | |
316 | ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { |
317 | auto buildFuncType = |
318 | [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, |
319 | function_interface_impl::VariadicFlag, |
320 | std::string &) { return builder.getFunctionType(argTypes, results); }; |
321 | |
322 | return function_interface_impl::parseFunctionOp( |
323 | parser, result, /*allowVariadic=*/false, |
324 | getFunctionTypeAttrName(result.name), buildFuncType, |
325 | getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); |
326 | } |
327 | |
328 | void FuncOp::print(OpAsmPrinter &p) { |
329 | function_interface_impl::printFunctionOp( |
330 | p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), |
331 | getArgAttrsAttrName(), getResAttrsAttrName()); |
332 | } |
333 | |
334 | /// Check that the result type of async.func is not void and must be |
335 | /// some async token or async values. |
336 | LogicalResult FuncOp::verify() { |
337 | auto resultTypes = getResultTypes(); |
338 | if (resultTypes.empty()) |
339 | return emitOpError() |
340 | << "result is expected to be at least of size 1, but got " |
341 | << resultTypes.size(); |
342 | |
343 | for (unsigned i = 0, e = resultTypes.size(); i != e; ++i) { |
344 | auto type = resultTypes[i]; |
345 | if (!llvm::isa<TokenType>(type) && !llvm::isa<ValueType>(type)) |
346 | return emitOpError() << "result type must be async value type or async " |
347 | "token type, but got " |
348 | << type; |
349 | // We only allow AsyncToken appear as the first return value |
350 | if (llvm::isa<TokenType>(type) && i != 0) { |
351 | return emitOpError() |
352 | << " results' (optional) async token type is expected " |
353 | "to appear as the 1st return value, but got " |
354 | << i + 1; |
355 | } |
356 | } |
357 | |
358 | return success(); |
359 | } |
360 | |
361 | //===----------------------------------------------------------------------===// |
362 | /// CallOp |
363 | //===----------------------------------------------------------------------===// |
364 | |
365 | LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
366 | // Check that the callee attribute was specified. |
367 | auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee" ); |
368 | if (!fnAttr) |
369 | return emitOpError("requires a 'callee' symbol reference attribute" ); |
370 | FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr); |
371 | if (!fn) |
372 | return emitOpError() << "'" << fnAttr.getValue() |
373 | << "' does not reference a valid async function" ; |
374 | |
375 | // Verify that the operand and result types match the callee. |
376 | auto fnType = fn.getFunctionType(); |
377 | if (fnType.getNumInputs() != getNumOperands()) |
378 | return emitOpError("incorrect number of operands for callee" ); |
379 | |
380 | for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) |
381 | if (getOperand(i).getType() != fnType.getInput(i)) |
382 | return emitOpError("operand type mismatch: expected operand type " ) |
383 | << fnType.getInput(i) << ", but provided " |
384 | << getOperand(i).getType() << " for operand number " << i; |
385 | |
386 | if (fnType.getNumResults() != getNumResults()) |
387 | return emitOpError("incorrect number of results for callee" ); |
388 | |
389 | for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) |
390 | if (getResult(i).getType() != fnType.getResult(i)) { |
391 | auto diag = emitOpError("result type mismatch at index " ) << i; |
392 | diag.attachNote() << " op result types: " << getResultTypes(); |
393 | diag.attachNote() << "function result types: " << fnType.getResults(); |
394 | return diag; |
395 | } |
396 | |
397 | return success(); |
398 | } |
399 | |
400 | FunctionType CallOp::getCalleeType() { |
401 | return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); |
402 | } |
403 | |
404 | //===----------------------------------------------------------------------===// |
405 | /// ReturnOp |
406 | //===----------------------------------------------------------------------===// |
407 | |
408 | LogicalResult ReturnOp::verify() { |
409 | auto funcOp = (*this)->getParentOfType<FuncOp>(); |
410 | ArrayRef<Type> resultTypes = funcOp.isStateful() |
411 | ? funcOp.getResultTypes().drop_front() |
412 | : funcOp.getResultTypes(); |
413 | // Get the underlying value types from async types returned from the |
414 | // parent `async.func` operation. |
415 | auto types = llvm::map_range(resultTypes, [](const Type &result) { |
416 | return llvm::cast<ValueType>(result).getValueType(); |
417 | }); |
418 | |
419 | if (getOperandTypes() != types) |
420 | return emitOpError("operand types do not match the types returned from " |
421 | "the parent FuncOp" ); |
422 | |
423 | return success(); |
424 | } |
425 | |
426 | //===----------------------------------------------------------------------===// |
427 | // TableGen'd op method definitions |
428 | //===----------------------------------------------------------------------===// |
429 | |
430 | #define GET_OP_CLASSES |
431 | #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" |
432 | |
433 | //===----------------------------------------------------------------------===// |
434 | // TableGen'd type method definitions |
435 | //===----------------------------------------------------------------------===// |
436 | |
437 | #define GET_TYPEDEF_CLASSES |
438 | #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc" |
439 | |
440 | void ValueType::print(AsmPrinter &printer) const { |
441 | printer << "<" ; |
442 | printer.printType(getValueType()); |
443 | printer << '>'; |
444 | } |
445 | |
446 | Type ValueType::parse(mlir::AsmParser &parser) { |
447 | Type ty; |
448 | if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) { |
449 | parser.emitError(parser.getNameLoc(), "failed to parse async value type" ); |
450 | return Type(); |
451 | } |
452 | return ValueType::get(ty); |
453 | } |
454 | |