1//===- Async.cpp - MLIR Async Operations ----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Async/IR/Async.h"
10
11#include "mlir/IR/DialectImplementation.h"
12#include "mlir/IR/IRMapping.h"
13#include "mlir/Interfaces/FunctionImplementation.h"
14#include "llvm/ADT/MapVector.h"
15#include "llvm/ADT/TypeSwitch.h"
16
17using namespace mlir;
18using namespace mlir::async;
19
20#include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"
21
// Namespace-scope definition of the constexpr static member
// `AsyncDialect::kAllowedToBlockAttrName`. No initializer is given here, so
// the value presumably comes from the in-class declaration (not visible in
// this file) -- TODO confirm against the dialect header.
constexpr StringRef AsyncDialect::kAllowedToBlockAttrName;
23
/// Registers the operations and types of the Async dialect.
///
/// Both lists are supplied by the included `.cpp.inc` files; the
/// `GET_OP_LIST` / `GET_TYPEDEF_LIST` macros select the list form of those
/// files before each include.
void AsyncDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
      >();
}
34
35//===----------------------------------------------------------------------===//
36/// ExecuteOp
37//===----------------------------------------------------------------------===//
38
// Name of the attribute that stores the sizes of the two operand groups of
// `async.execute`: the dependency count followed by the body operand count.
// It is derived by `build`/`parse` and elided by `print`.
constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes";
40
41OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) {
42 assert(point == getBodyRegion() && "invalid region index");
43 return getBodyOperands();
44}
45
46bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
47 const auto getValueOrTokenType = [](Type type) {
48 if (auto value = llvm::dyn_cast<ValueType>(type))
49 return value.getValueType();
50 return type;
51 };
52 return getValueOrTokenType(lhs) == getValueOrTokenType(rhs);
53}
54
55void ExecuteOp::getSuccessorRegions(RegionBranchPoint point,
56 SmallVectorImpl<RegionSuccessor> &regions) {
57 // The `body` region branch back to the parent operation.
58 if (point == getBodyRegion()) {
59 regions.push_back(RegionSuccessor(getBodyResults()));
60 return;
61 }
62
63 // Otherwise the successor is the body region.
64 regions.push_back(
65 RegionSuccessor(&getBodyRegion(), getBodyRegion().getArguments()));
66}
67
68void ExecuteOp::build(OpBuilder &builder, OperationState &result,
69 TypeRange resultTypes, ValueRange dependencies,
70 ValueRange operands, BodyBuilderFn bodyBuilder) {
71 OpBuilder::InsertionGuard guard(builder);
72 result.addOperands(dependencies);
73 result.addOperands(operands);
74
75 // Add derived `operandSegmentSizes` attribute based on parsed operands.
76 int32_t numDependencies = dependencies.size();
77 int32_t numOperands = operands.size();
78 auto operandSegmentSizes =
79 builder.getDenseI32ArrayAttr({numDependencies, numOperands});
80 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
81
82 // First result is always a token, and then `resultTypes` wrapped into
83 // `async.value`.
84 result.addTypes({TokenType::get(result.getContext())});
85 for (Type type : resultTypes)
86 result.addTypes(ValueType::get(type));
87
88 // Add a body region with block arguments as unwrapped async value operands.
89 Region *bodyRegion = result.addRegion();
90 Block *bodyBlock = builder.createBlock(bodyRegion);
91 for (Value operand : operands) {
92 auto valueType = llvm::dyn_cast<ValueType>(operand.getType());
93 bodyBlock->addArgument(valueType ? valueType.getValueType()
94 : operand.getType(),
95 operand.getLoc());
96 }
97
98 // Create the default terminator if the builder is not provided and if the
99 // expected result is empty. Otherwise, leave this to the caller
100 // because we don't know which values to return from the execute op.
101 if (resultTypes.empty() && !bodyBuilder) {
102 builder.create<async::YieldOp>(result.location, ValueRange());
103 } else if (bodyBuilder) {
104 bodyBuilder(builder, result.location, bodyBlock->getArguments());
105 }
106}
107
108void ExecuteOp::print(OpAsmPrinter &p) {
109 // [%tokens,...]
110 if (!getDependencies().empty())
111 p << " [" << getDependencies() << "]";
112
113 // (%value as %unwrapped: !async.value<!arg.type>, ...)
114 if (!getBodyOperands().empty()) {
115 p << " (";
116 Block *entry = getBodyRegion().empty() ? nullptr : &getBodyRegion().front();
117 llvm::interleaveComma(
118 getBodyOperands(), p, [&, n = 0](Value operand) mutable {
119 Value argument = entry ? entry->getArgument(n++) : Value();
120 p << operand << " as " << argument << ": " << operand.getType();
121 });
122 p << ")";
123 }
124
125 // -> (!async.value<!return.type>, ...)
126 p.printOptionalArrowTypeList(llvm::drop_begin(getResultTypes()));
127 p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
128 {kOperandSegmentSizesAttr});
129 p << ' ';
130 p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
131}
132
133ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) {
134 MLIRContext *ctx = result.getContext();
135
136 // Sizes of parsed variadic operands, will be updated below after parsing.
137 int32_t numDependencies = 0;
138
139 auto tokenTy = TokenType::get(ctx);
140
141 // Parse dependency tokens.
142 if (succeeded(parser.parseOptionalLSquare())) {
143 SmallVector<OpAsmParser::UnresolvedOperand, 4> tokenArgs;
144 if (parser.parseOperandList(tokenArgs) ||
145 parser.resolveOperands(tokenArgs, tokenTy, result.operands) ||
146 parser.parseRSquare())
147 return failure();
148
149 numDependencies = tokenArgs.size();
150 }
151
152 // Parse async value operands (%value as %unwrapped : !async.value<!type>).
153 SmallVector<OpAsmParser::UnresolvedOperand, 4> valueArgs;
154 SmallVector<OpAsmParser::Argument, 4> unwrappedArgs;
155 SmallVector<Type, 4> valueTypes;
156
157 // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
158 auto parseAsyncValueArg = [&]() -> ParseResult {
159 if (parser.parseOperand(valueArgs.emplace_back()) ||
160 parser.parseKeyword("as") ||
161 parser.parseArgument(unwrappedArgs.emplace_back()) ||
162 parser.parseColonType(valueTypes.emplace_back()))
163 return failure();
164
165 auto valueTy = llvm::dyn_cast<ValueType>(valueTypes.back());
166 unwrappedArgs.back().type = valueTy ? valueTy.getValueType() : Type();
167 return success();
168 };
169
170 auto argsLoc = parser.getCurrentLocation();
171 if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::OptionalParen,
172 parseAsyncValueArg) ||
173 parser.resolveOperands(valueArgs, valueTypes, argsLoc, result.operands))
174 return failure();
175
176 int32_t numOperands = valueArgs.size();
177
178 // Add derived `operandSegmentSizes` attribute based on parsed operands.
179 auto operandSegmentSizes =
180 parser.getBuilder().getDenseI32ArrayAttr({numDependencies, numOperands});
181 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
182
183 // Parse the types of results returned from the async execute op.
184 SmallVector<Type, 4> resultTypes;
185 NamedAttrList attrs;
186 if (parser.parseOptionalArrowTypeList(resultTypes) ||
187 // Async execute first result is always a completion token.
188 parser.addTypeToList(tokenTy, result.types) ||
189 parser.addTypesToList(resultTypes, result.types) ||
190 // Parse operation attributes.
191 parser.parseOptionalAttrDictWithKeyword(attrs))
192 return failure();
193
194 result.addAttributes(attrs);
195
196 // Parse asynchronous region.
197 Region *body = result.addRegion();
198 return parser.parseRegion(*body, /*arguments=*/unwrappedArgs);
199}
200
201LogicalResult ExecuteOp::verifyRegions() {
202 // Unwrap async.execute value operands types.
203 auto unwrappedTypes = llvm::map_range(getBodyOperands(), [](Value operand) {
204 return llvm::cast<ValueType>(operand.getType()).getValueType();
205 });
206
207 // Verify that unwrapped argument types matches the body region arguments.
208 if (getBodyRegion().getArgumentTypes() != unwrappedTypes)
209 return emitOpError("async body region argument types do not match the "
210 "execute operation arguments types");
211
212 return success();
213}
214
215//===----------------------------------------------------------------------===//
216/// CreateGroupOp
217//===----------------------------------------------------------------------===//
218
219LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op,
220 PatternRewriter &rewriter) {
221 // Find all `await_all` users of the group.
222 llvm::SmallVector<AwaitAllOp> awaitAllUsers;
223
224 auto isAwaitAll = [&](Operation *op) -> bool {
225 if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(op)) {
226 awaitAllUsers.push_back(awaitAll);
227 return true;
228 }
229 return false;
230 };
231
232 // Check if all users of the group are `await_all` operations.
233 if (!llvm::all_of(op->getUsers(), isAwaitAll))
234 return failure();
235
236 // If group is only awaited without adding anything to it, we can safely erase
237 // the create operation and all users.
238 for (AwaitAllOp awaitAll : awaitAllUsers)
239 rewriter.eraseOp(awaitAll);
240 rewriter.eraseOp(op);
241
242 return success();
243}
244
245//===----------------------------------------------------------------------===//
246/// AwaitOp
247//===----------------------------------------------------------------------===//
248
249void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
250 ArrayRef<NamedAttribute> attrs) {
251 result.addOperands({operand});
252 result.attributes.append(attrs.begin(), attrs.end());
253
254 // Add unwrapped async.value type to the returned values types.
255 if (auto valueType = llvm::dyn_cast<ValueType>(operand.getType()))
256 result.addTypes(valueType.getValueType());
257}
258
259static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
260 Type &resultType) {
261 if (parser.parseType(result&: operandType))
262 return failure();
263
264 // Add unwrapped async.value type to the returned values types.
265 if (auto valueType = llvm::dyn_cast<ValueType>(operandType))
266 resultType = valueType.getValueType();
267
268 return success();
269}
270
271static void printAwaitResultType(OpAsmPrinter &p, Operation *op,
272 Type operandType, Type resultType) {
273 p << operandType;
274}
275
276LogicalResult AwaitOp::verify() {
277 Type argType = getOperand().getType();
278
279 // Awaiting on a token does not have any results.
280 if (llvm::isa<TokenType>(argType) && !getResultTypes().empty())
281 return emitOpError("awaiting on a token must have empty result");
282
283 // Awaiting on a value unwraps the async value type.
284 if (auto value = llvm::dyn_cast<ValueType>(argType)) {
285 if (*getResultType() != value.getValueType())
286 return emitOpError() << "result type " << *getResultType()
287 << " does not match async value type "
288 << value.getValueType();
289 }
290
291 return success();
292}
293
294//===----------------------------------------------------------------------===//
295// FuncOp
296//===----------------------------------------------------------------------===//
297
298void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
299 FunctionType type, ArrayRef<NamedAttribute> attrs,
300 ArrayRef<DictionaryAttr> argAttrs) {
301 state.addAttribute(SymbolTable::getSymbolAttrName(),
302 builder.getStringAttr(name));
303 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
304
305 state.attributes.append(attrs.begin(), attrs.end());
306 state.addRegion();
307
308 if (argAttrs.empty())
309 return;
310 assert(type.getNumInputs() == argAttrs.size());
311 function_interface_impl::addArgAndResultAttrs(
312 builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
313 getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
314}
315
316ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
317 auto buildFuncType =
318 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
319 function_interface_impl::VariadicFlag,
320 std::string &) { return builder.getFunctionType(argTypes, results); };
321
322 return function_interface_impl::parseFunctionOp(
323 parser, result, /*allowVariadic=*/false,
324 getFunctionTypeAttrName(result.name), buildFuncType,
325 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
326}
327
328void FuncOp::print(OpAsmPrinter &p) {
329 function_interface_impl::printFunctionOp(
330 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
331 getArgAttrsAttrName(), getResAttrsAttrName());
332}
333
334/// Check that the result type of async.func is not void and must be
335/// some async token or async values.
336LogicalResult FuncOp::verify() {
337 auto resultTypes = getResultTypes();
338 if (resultTypes.empty())
339 return emitOpError()
340 << "result is expected to be at least of size 1, but got "
341 << resultTypes.size();
342
343 for (unsigned i = 0, e = resultTypes.size(); i != e; ++i) {
344 auto type = resultTypes[i];
345 if (!llvm::isa<TokenType>(type) && !llvm::isa<ValueType>(type))
346 return emitOpError() << "result type must be async value type or async "
347 "token type, but got "
348 << type;
349 // We only allow AsyncToken appear as the first return value
350 if (llvm::isa<TokenType>(type) && i != 0) {
351 return emitOpError()
352 << " results' (optional) async token type is expected "
353 "to appear as the 1st return value, but got "
354 << i + 1;
355 }
356 }
357
358 return success();
359}
360
361//===----------------------------------------------------------------------===//
362/// CallOp
363//===----------------------------------------------------------------------===//
364
365LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
366 // Check that the callee attribute was specified.
367 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
368 if (!fnAttr)
369 return emitOpError("requires a 'callee' symbol reference attribute");
370 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
371 if (!fn)
372 return emitOpError() << "'" << fnAttr.getValue()
373 << "' does not reference a valid async function";
374
375 // Verify that the operand and result types match the callee.
376 auto fnType = fn.getFunctionType();
377 if (fnType.getNumInputs() != getNumOperands())
378 return emitOpError("incorrect number of operands for callee");
379
380 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
381 if (getOperand(i).getType() != fnType.getInput(i))
382 return emitOpError("operand type mismatch: expected operand type ")
383 << fnType.getInput(i) << ", but provided "
384 << getOperand(i).getType() << " for operand number " << i;
385
386 if (fnType.getNumResults() != getNumResults())
387 return emitOpError("incorrect number of results for callee");
388
389 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
390 if (getResult(i).getType() != fnType.getResult(i)) {
391 auto diag = emitOpError("result type mismatch at index ") << i;
392 diag.attachNote() << " op result types: " << getResultTypes();
393 diag.attachNote() << "function result types: " << fnType.getResults();
394 return diag;
395 }
396
397 return success();
398}
399
400FunctionType CallOp::getCalleeType() {
401 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
402}
403
404//===----------------------------------------------------------------------===//
405/// ReturnOp
406//===----------------------------------------------------------------------===//
407
408LogicalResult ReturnOp::verify() {
409 auto funcOp = (*this)->getParentOfType<FuncOp>();
410 ArrayRef<Type> resultTypes = funcOp.isStateful()
411 ? funcOp.getResultTypes().drop_front()
412 : funcOp.getResultTypes();
413 // Get the underlying value types from async types returned from the
414 // parent `async.func` operation.
415 auto types = llvm::map_range(resultTypes, [](const Type &result) {
416 return llvm::cast<ValueType>(result).getValueType();
417 });
418
419 if (getOperandTypes() != types)
420 return emitOpError("operand types do not match the types returned from "
421 "the parent FuncOp");
422
423 return success();
424}
425
426//===----------------------------------------------------------------------===//
427// TableGen'd op method definitions
428//===----------------------------------------------------------------------===//
429
430#define GET_OP_CLASSES
431#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
432
433//===----------------------------------------------------------------------===//
434// TableGen'd type method definitions
435//===----------------------------------------------------------------------===//
436
437#define GET_TYPEDEF_CLASSES
438#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
439
440void ValueType::print(AsmPrinter &printer) const {
441 printer << "<";
442 printer.printType(getValueType());
443 printer << '>';
444}
445
446Type ValueType::parse(mlir::AsmParser &parser) {
447 Type ty;
448 if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
449 parser.emitError(parser.getNameLoc(), "failed to parse async value type");
450 return Type();
451 }
452 return ValueType::get(ty);
453}
454

// Source: mlir/lib/Dialect/Async/IR/Async.cpp