1//===- Async.cpp - MLIR Async Operations ----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Async/IR/Async.h"
10
11#include "mlir/IR/DialectImplementation.h"
12#include "mlir/Interfaces/FunctionImplementation.h"
13#include "llvm/ADT/TypeSwitch.h"
14
15using namespace mlir;
16using namespace mlir::async;
17
18#include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"
19
// Out-of-line definition of the dialect's attribute-name constant.
// NOTE(review): since C++17 an in-class `static constexpr` member is
// implicitly inline, so this redeclaration is likely redundant; kept as-is.
constexpr StringRef AsyncDialect::kAllowedToBlockAttrName;

/// Registers every Async dialect operation and type with the dialect. The
/// lists themselves come from the TableGen-generated `.cpp.inc` files,
/// selected via the GET_OP_LIST / GET_TYPEDEF_LIST sections.
void AsyncDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
      >();
}
32
33//===----------------------------------------------------------------------===//
34/// ExecuteOp
35//===----------------------------------------------------------------------===//
36
/// Name of the attribute recording how many operands belong to each variadic
/// operand group of `async.execute`: dependency tokens first, then body
/// operands. The builder and parser below populate it; the printer elides it.
constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes";
38
39OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) {
40 assert(point == getBodyRegion() && "invalid region index");
41 return getBodyOperands();
42}
43
44bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
45 const auto getValueOrTokenType = [](Type type) {
46 if (auto value = llvm::dyn_cast<ValueType>(Val&: type))
47 return value.getValueType();
48 return type;
49 };
50 return getValueOrTokenType(lhs) == getValueOrTokenType(rhs);
51}
52
53void ExecuteOp::getSuccessorRegions(RegionBranchPoint point,
54 SmallVectorImpl<RegionSuccessor> &regions) {
55 // The `body` region branch back to the parent operation.
56 if (point == getBodyRegion()) {
57 regions.push_back(Elt: RegionSuccessor(getBodyResults()));
58 return;
59 }
60
61 // Otherwise the successor is the body region.
62 regions.push_back(
63 Elt: RegionSuccessor(&getBodyRegion(), getBodyRegion().getArguments()));
64}
65
66void ExecuteOp::build(OpBuilder &builder, OperationState &result,
67 TypeRange resultTypes, ValueRange dependencies,
68 ValueRange operands, BodyBuilderFn bodyBuilder) {
69 OpBuilder::InsertionGuard guard(builder);
70 result.addOperands(newOperands: dependencies);
71 result.addOperands(newOperands: operands);
72
73 // Add derived `operandSegmentSizes` attribute based on parsed operands.
74 int32_t numDependencies = dependencies.size();
75 int32_t numOperands = operands.size();
76 auto operandSegmentSizes =
77 builder.getDenseI32ArrayAttr(values: {numDependencies, numOperands});
78 result.addAttribute(name: kOperandSegmentSizesAttr, attr: operandSegmentSizes);
79
80 // First result is always a token, and then `resultTypes` wrapped into
81 // `async.value`.
82 result.addTypes(newTypes: {TokenType::get(ctx: result.getContext())});
83 for (Type type : resultTypes)
84 result.addTypes(newTypes: ValueType::get(valueType: type));
85
86 // Add a body region with block arguments as unwrapped async value operands.
87 Region *bodyRegion = result.addRegion();
88 Block *bodyBlock = builder.createBlock(parent: bodyRegion);
89 for (Value operand : operands) {
90 auto valueType = llvm::dyn_cast<ValueType>(Val: operand.getType());
91 bodyBlock->addArgument(type: valueType ? valueType.getValueType()
92 : operand.getType(),
93 loc: operand.getLoc());
94 }
95
96 // Create the default terminator if the builder is not provided and if the
97 // expected result is empty. Otherwise, leave this to the caller
98 // because we don't know which values to return from the execute op.
99 if (resultTypes.empty() && !bodyBuilder) {
100 builder.create<async::YieldOp>(location: result.location, args: ValueRange());
101 } else if (bodyBuilder) {
102 bodyBuilder(builder, result.location, bodyBlock->getArguments());
103 }
104}
105
106void ExecuteOp::print(OpAsmPrinter &p) {
107 // [%tokens,...]
108 if (!getDependencies().empty())
109 p << " [" << getDependencies() << "]";
110
111 // (%value as %unwrapped: !async.value<!arg.type>, ...)
112 if (!getBodyOperands().empty()) {
113 p << " (";
114 Block *entry = getBodyRegion().empty() ? nullptr : &getBodyRegion().front();
115 llvm::interleaveComma(
116 c: getBodyOperands(), os&: p, each_fn: [&, n = 0](Value operand) mutable {
117 Value argument = entry ? entry->getArgument(i: n++) : Value();
118 p << operand << " as " << argument << ": " << operand.getType();
119 });
120 p << ")";
121 }
122
123 // -> (!async.value<!return.type>, ...)
124 p.printOptionalArrowTypeList(types: llvm::drop_begin(RangeOrContainer: getResultTypes()));
125 p.printOptionalAttrDictWithKeyword(attrs: (*this)->getAttrs(),
126 elidedAttrs: {kOperandSegmentSizesAttr});
127 p << ' ';
128 p.printRegion(blocks&: getBodyRegion(), /*printEntryBlockArgs=*/false);
129}
130
131ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) {
132 MLIRContext *ctx = result.getContext();
133
134 // Sizes of parsed variadic operands, will be updated below after parsing.
135 int32_t numDependencies = 0;
136
137 auto tokenTy = TokenType::get(ctx);
138
139 // Parse dependency tokens.
140 if (succeeded(Result: parser.parseOptionalLSquare())) {
141 SmallVector<OpAsmParser::UnresolvedOperand, 4> tokenArgs;
142 if (parser.parseOperandList(result&: tokenArgs) ||
143 parser.resolveOperands(operands&: tokenArgs, type: tokenTy, result&: result.operands) ||
144 parser.parseRSquare())
145 return failure();
146
147 numDependencies = tokenArgs.size();
148 }
149
150 // Parse async value operands (%value as %unwrapped : !async.value<!type>).
151 SmallVector<OpAsmParser::UnresolvedOperand, 4> valueArgs;
152 SmallVector<OpAsmParser::Argument, 4> unwrappedArgs;
153 SmallVector<Type, 4> valueTypes;
154
155 // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
156 auto parseAsyncValueArg = [&]() -> ParseResult {
157 if (parser.parseOperand(result&: valueArgs.emplace_back()) ||
158 parser.parseKeyword(keyword: "as") ||
159 parser.parseArgument(result&: unwrappedArgs.emplace_back()) ||
160 parser.parseColonType(result&: valueTypes.emplace_back()))
161 return failure();
162
163 auto valueTy = llvm::dyn_cast<ValueType>(Val&: valueTypes.back());
164 unwrappedArgs.back().type = valueTy ? valueTy.getValueType() : Type();
165 return success();
166 };
167
168 auto argsLoc = parser.getCurrentLocation();
169 if (parser.parseCommaSeparatedList(delimiter: OpAsmParser::Delimiter::OptionalParen,
170 parseElementFn: parseAsyncValueArg) ||
171 parser.resolveOperands(operands&: valueArgs, types&: valueTypes, loc: argsLoc, result&: result.operands))
172 return failure();
173
174 int32_t numOperands = valueArgs.size();
175
176 // Add derived `operandSegmentSizes` attribute based on parsed operands.
177 auto operandSegmentSizes =
178 parser.getBuilder().getDenseI32ArrayAttr(values: {numDependencies, numOperands});
179 result.addAttribute(name: kOperandSegmentSizesAttr, attr: operandSegmentSizes);
180
181 // Parse the types of results returned from the async execute op.
182 SmallVector<Type, 4> resultTypes;
183 NamedAttrList attrs;
184 if (parser.parseOptionalArrowTypeList(result&: resultTypes) ||
185 // Async execute first result is always a completion token.
186 parser.addTypeToList(type: tokenTy, result&: result.types) ||
187 parser.addTypesToList(types: resultTypes, result&: result.types) ||
188 // Parse operation attributes.
189 parser.parseOptionalAttrDictWithKeyword(result&: attrs))
190 return failure();
191
192 result.addAttributes(newAttributes: attrs);
193
194 // Parse asynchronous region.
195 Region *body = result.addRegion();
196 return parser.parseRegion(region&: *body, /*arguments=*/unwrappedArgs);
197}
198
199LogicalResult ExecuteOp::verifyRegions() {
200 // Unwrap async.execute value operands types.
201 auto unwrappedTypes = llvm::map_range(C: getBodyOperands(), F: [](Value operand) {
202 return llvm::cast<ValueType>(Val: operand.getType()).getValueType();
203 });
204
205 // Verify that unwrapped argument types matches the body region arguments.
206 if (getBodyRegion().getArgumentTypes() != unwrappedTypes)
207 return emitOpError(message: "async body region argument types do not match the "
208 "execute operation arguments types");
209
210 return success();
211}
212
213//===----------------------------------------------------------------------===//
214/// CreateGroupOp
215//===----------------------------------------------------------------------===//
216
217LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op,
218 PatternRewriter &rewriter) {
219 // Find all `await_all` users of the group.
220 llvm::SmallVector<AwaitAllOp> awaitAllUsers;
221
222 auto isAwaitAll = [&](Operation *op) -> bool {
223 if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(Val: op)) {
224 awaitAllUsers.push_back(Elt: awaitAll);
225 return true;
226 }
227 return false;
228 };
229
230 // Check if all users of the group are `await_all` operations.
231 if (!llvm::all_of(Range: op->getUsers(), P: isAwaitAll))
232 return failure();
233
234 // If group is only awaited without adding anything to it, we can safely erase
235 // the create operation and all users.
236 for (AwaitAllOp awaitAll : awaitAllUsers)
237 rewriter.eraseOp(op: awaitAll);
238 rewriter.eraseOp(op);
239
240 return success();
241}
242
243//===----------------------------------------------------------------------===//
244/// AwaitOp
245//===----------------------------------------------------------------------===//
246
247void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
248 ArrayRef<NamedAttribute> attrs) {
249 result.addOperands(newOperands: {operand});
250 result.attributes.append(inStart: attrs.begin(), inEnd: attrs.end());
251
252 // Add unwrapped async.value type to the returned values types.
253 if (auto valueType = llvm::dyn_cast<ValueType>(Val: operand.getType()))
254 result.addTypes(newTypes: valueType.getValueType());
255}
256
257static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
258 Type &resultType) {
259 if (parser.parseType(result&: operandType))
260 return failure();
261
262 // Add unwrapped async.value type to the returned values types.
263 if (auto valueType = llvm::dyn_cast<ValueType>(Val&: operandType))
264 resultType = valueType.getValueType();
265
266 return success();
267}
268
269static void printAwaitResultType(OpAsmPrinter &p, Operation *op,
270 Type operandType, Type resultType) {
271 p << operandType;
272}
273
274LogicalResult AwaitOp::verify() {
275 Type argType = getOperand().getType();
276
277 // Awaiting on a token does not have any results.
278 if (llvm::isa<TokenType>(Val: argType) && !getResultTypes().empty())
279 return emitOpError(message: "awaiting on a token must have empty result");
280
281 // Awaiting on a value unwraps the async value type.
282 if (auto value = llvm::dyn_cast<ValueType>(Val&: argType)) {
283 if (*getResultType() != value.getValueType())
284 return emitOpError() << "result type " << *getResultType()
285 << " does not match async value type "
286 << value.getValueType();
287 }
288
289 return success();
290}
291
292//===----------------------------------------------------------------------===//
293// FuncOp
294//===----------------------------------------------------------------------===//
295
296void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
297 FunctionType type, ArrayRef<NamedAttribute> attrs,
298 ArrayRef<DictionaryAttr> argAttrs) {
299 state.addAttribute(name: SymbolTable::getSymbolAttrName(),
300 attr: builder.getStringAttr(bytes: name));
301 state.addAttribute(name: getFunctionTypeAttrName(name: state.name), attr: TypeAttr::get(type));
302
303 state.attributes.append(inStart: attrs.begin(), inEnd: attrs.end());
304 state.addRegion();
305
306 if (argAttrs.empty())
307 return;
308 assert(type.getNumInputs() == argAttrs.size());
309 call_interface_impl::addArgAndResultAttrs(
310 builder, result&: state, argAttrs, /*resultAttrs=*/{},
311 argAttrsName: getArgAttrsAttrName(name: state.name), resAttrsName: getResAttrsAttrName(name: state.name));
312}
313
314ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
315 auto buildFuncType =
316 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
317 function_interface_impl::VariadicFlag,
318 std::string &) { return builder.getFunctionType(inputs: argTypes, results); };
319
320 return function_interface_impl::parseFunctionOp(
321 parser, result, /*allowVariadic=*/false,
322 typeAttrName: getFunctionTypeAttrName(name: result.name), funcTypeBuilder: buildFuncType,
323 argAttrsName: getArgAttrsAttrName(name: result.name), resAttrsName: getResAttrsAttrName(name: result.name));
324}
325
326void FuncOp::print(OpAsmPrinter &p) {
327 function_interface_impl::printFunctionOp(
328 p, op: *this, /*isVariadic=*/false, typeAttrName: getFunctionTypeAttrName(),
329 argAttrsName: getArgAttrsAttrName(), resAttrsName: getResAttrsAttrName());
330}
331
332/// Check that the result type of async.func is not void and must be
333/// some async token or async values.
334LogicalResult FuncOp::verify() {
335 auto resultTypes = getResultTypes();
336 if (resultTypes.empty())
337 return emitOpError()
338 << "result is expected to be at least of size 1, but got "
339 << resultTypes.size();
340
341 for (unsigned i = 0, e = resultTypes.size(); i != e; ++i) {
342 auto type = resultTypes[i];
343 if (!llvm::isa<TokenType>(Val: type) && !llvm::isa<ValueType>(Val: type))
344 return emitOpError() << "result type must be async value type or async "
345 "token type, but got "
346 << type;
347 // We only allow AsyncToken appear as the first return value
348 if (llvm::isa<TokenType>(Val: type) && i != 0) {
349 return emitOpError()
350 << " results' (optional) async token type is expected "
351 "to appear as the 1st return value, but got "
352 << i + 1;
353 }
354 }
355
356 return success();
357}
358
359//===----------------------------------------------------------------------===//
360/// CallOp
361//===----------------------------------------------------------------------===//
362
363LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
364 // Check that the callee attribute was specified.
365 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>(name: "callee");
366 if (!fnAttr)
367 return emitOpError(message: "requires a 'callee' symbol reference attribute");
368 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(from: *this, symbol: fnAttr);
369 if (!fn)
370 return emitOpError() << "'" << fnAttr.getValue()
371 << "' does not reference a valid async function";
372
373 // Verify that the operand and result types match the callee.
374 auto fnType = fn.getFunctionType();
375 if (fnType.getNumInputs() != getNumOperands())
376 return emitOpError(message: "incorrect number of operands for callee");
377
378 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
379 if (getOperand(i).getType() != fnType.getInput(i))
380 return emitOpError(message: "operand type mismatch: expected operand type ")
381 << fnType.getInput(i) << ", but provided "
382 << getOperand(i).getType() << " for operand number " << i;
383
384 if (fnType.getNumResults() != getNumResults())
385 return emitOpError(message: "incorrect number of results for callee");
386
387 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
388 if (getResult(i).getType() != fnType.getResult(i)) {
389 auto diag = emitOpError(message: "result type mismatch at index ") << i;
390 diag.attachNote() << " op result types: " << getResultTypes();
391 diag.attachNote() << "function result types: " << fnType.getResults();
392 return diag;
393 }
394
395 return success();
396}
397
398FunctionType CallOp::getCalleeType() {
399 return FunctionType::get(context: getContext(), inputs: getOperandTypes(), results: getResultTypes());
400}
401
402//===----------------------------------------------------------------------===//
403/// ReturnOp
404//===----------------------------------------------------------------------===//
405
406LogicalResult ReturnOp::verify() {
407 auto funcOp = (*this)->getParentOfType<FuncOp>();
408 ArrayRef<Type> resultTypes = funcOp.isStateful()
409 ? funcOp.getResultTypes().drop_front()
410 : funcOp.getResultTypes();
411 // Get the underlying value types from async types returned from the
412 // parent `async.func` operation.
413 auto types = llvm::map_range(C&: resultTypes, F: [](const Type &result) {
414 return llvm::cast<ValueType>(Val: result).getValueType();
415 });
416
417 if (getOperandTypes() != types)
418 return emitOpError(message: "operand types do not match the types returned from "
419 "the parent FuncOp");
420
421 return success();
422}
423
424//===----------------------------------------------------------------------===//
425// TableGen'd op method definitions
426//===----------------------------------------------------------------------===//
427
428#define GET_OP_CLASSES
429#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
430
431//===----------------------------------------------------------------------===//
432// TableGen'd type method definitions
433//===----------------------------------------------------------------------===//
434
435#define GET_TYPEDEF_CLASSES
436#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
437
438void ValueType::print(AsmPrinter &printer) const {
439 printer << "<";
440 printer.printType(type: getValueType());
441 printer << '>';
442}
443
444Type ValueType::parse(mlir::AsmParser &parser) {
445 Type ty;
446 if (parser.parseLess() || parser.parseType(result&: ty) || parser.parseGreater()) {
447 parser.emitError(loc: parser.getNameLoc(), message: "failed to parse async value type");
448 return Type();
449 }
450 return ValueType::get(valueType: ty);
451}
452

source code of mlir/lib/Dialect/Async/IR/Async.cpp